Commit 5d3a6aa

Fix for shutting down the pool after doing a map

1 parent 119d92f

1 file changed: pyiceberg/io/pyarrow.py

Lines changed: 17 additions & 24 deletions

@@ -25,7 +25,6 @@
 
 from __future__ import annotations
 
-import concurrent.futures
 import fnmatch
 import functools
 import itertools
@@ -36,7 +35,6 @@
 import uuid
 import warnings
 from abc import ABC, abstractmethod
-from concurrent.futures import Future
 from copy import copy
 from dataclasses import dataclass
 from enum import Enum
@@ -71,7 +69,6 @@
     FileType,
     FSSpecHandler,
 )
-from sortedcontainers import SortedList
 
 from pyiceberg.conversions import to_bytes
 from pyiceberg.exceptions import ResolveError
@@ -1570,7 +1567,6 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
             ResolveError: When a required field cannot be found in the file
             ValueError: When a field type in the file cannot be projected to the schema type
         """
-
         arrow_schema = schema_to_pyarrow(self._projected_schema, include_field_ids=False)
 
         batches = self.to_record_batches(tasks)
@@ -1592,9 +1588,7 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
 
         return result
 
-    def to_record_batches(
-        self, tasks: Iterable[FileScanTask]
-    ) -> Iterator[pa.RecordBatch]:
+    def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
         """Scan the Iceberg table and return an Iterator[pa.RecordBatch].
 
         Returns an Iterator of pa.RecordBatch with data from the Iceberg table
@@ -1617,26 +1611,25 @@ def to_record_batches(
         total_row_count = 0
         executor = ExecutorFactory.get_or_create()
 
-        with executor as pool:
-            should_stop = False
-            for batches in pool.map(
-                lambda task: list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)), tasks
-            ):
-                for batch in batches:
-                    current_batch_size = len(batch)
-                    if self._limit is not None:
-                        if total_row_count + current_batch_size >= self._limit:
-                            yield batch.slice(0, self._limit - total_row_count)
+        limit_reached = False
+        for batches in executor.map(
+            lambda task: list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)), tasks
+        ):
+            for batch in batches:
+                current_batch_size = len(batch)
+                if self._limit is not None:
+                    if total_row_count + current_batch_size >= self._limit:
+                        yield batch.slice(0, self._limit - total_row_count)
 
-                            # This break will also cancel all tasks in the Pool
-                            should_stop = True
-                            break
+                        # This break will also cancel all tasks in the Pool
+                        limit_reached = True
+                        break
 
-                    yield batch
-                    total_row_count += current_batch_size
+                yield batch
+                total_row_count += current_batch_size
 
-                if should_stop:
-                    break
+            if limit_reached:
+                break
 
     def _record_batches_from_scan_tasks_and_deletes(
         self, tasks: Iterable[FileScanTask], deletes_per_file: Dict[str, List[ChunkedArray]]
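
Why the fix works: ThreadPoolExecutor's context manager calls shutdown(wait=True) on exit, so wrapping the shared executor returned by ExecutorFactory.get_or_create() in a `with` block tears down the process-wide pool as soon as one scan finishes, and the next scan fails with "cannot schedule new futures after shutdown". The sketch below reproduces the failure mode and shows the fixed call pattern. It is illustrative only: this ExecutorFactory is a simplified stand-in for pyiceberg's, and the doubling lambda stands in for _record_batches_from_scan_tasks_and_deletes.

    from concurrent.futures import ThreadPoolExecutor
    from typing import Optional


    class ExecutorFactory:
        """Simplified stand-in for the singleton factory used in the diff."""

        _instance: Optional[ThreadPoolExecutor] = None

        @staticmethod
        def get_or_create() -> ThreadPoolExecutor:
            # Create the shared pool on first use, then hand out the same instance.
            if ExecutorFactory._instance is None:
                ExecutorFactory._instance = ThreadPoolExecutor()
            return ExecutorFactory._instance


    def scan_before_fix(tasks):
        executor = ExecutorFactory.get_or_create()
        # Bug: leaving the `with` block calls executor.shutdown(wait=True),
        # permanently closing the pool that every later scan also relies on.
        with executor as pool:
            return list(pool.map(lambda t: t * 2, tasks))


    def scan_after_fix(tasks):
        executor = ExecutorFactory.get_or_create()
        # Fix: call map() on the shared executor directly; the factory,
        # not an individual scan, owns the pool's lifecycle.
        return list(executor.map(lambda t: t * 2, tasks))


    print(scan_before_fix([1, 2, 3]))  # first scan works: [2, 4, 6]
    try:
        scan_before_fix([4, 5, 6])     # second scan hits the shut-down pool
    except RuntimeError as exc:
        print(exc)                     # cannot schedule new futures after shutdown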
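
The limit handling in the final hunk only dedents with the removal of the `with` block: when the next batch would cross self._limit, only the rows still needed are yielded via RecordBatch.slice. A quick illustration of that arithmetic with pyarrow, using made-up numbers rather than anything from the commit:

    import pyarrow as pa

    batch = pa.RecordBatch.from_pydict({"id": list(range(10))})  # 10-row batch
    limit, total_row_count = 7, 5  # 5 rows already yielded, limit of 7

    # total_row_count + len(batch) >= limit, so yield only the remaining rows
    remaining = batch.slice(0, limit - total_row_count)
    print(remaining.num_rows)  # 2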
