
Commit 88a4ad2

Accept concurrent_tasks when fetching record_batches
1 parent: 4614543

2 files changed: 19 additions & 4 deletions


pyiceberg/io/pyarrow.py

Lines changed: 17 additions & 2 deletions
@@ -1625,7 +1625,9 @@ def _table_from_scan_task(task: FileScanTask) -> pa.Table:
 
             return result
 
-    def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
+    def to_record_batches(
+        self, tasks: Iterable[FileScanTask], concurrent_tasks: Optional[int] = None
+    ) -> Iterator[pa.RecordBatch]:
         """Scan the Iceberg table and return an Iterator[pa.RecordBatch].
 
         Returns an Iterator of pa.RecordBatch with data from the Iceberg table
@@ -1634,6 +1636,7 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
 
         Args:
             tasks: FileScanTasks representing the data files and delete files to read from.
+            concurrent_tasks: Number of tasks to read concurrently, or None to read sequentially.
 
         Returns:
             An Iterator of PyArrow RecordBatches.
@@ -1643,8 +1646,20 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
             ResolveError: When a required field cannot be found in the file
             ValueError: When a field type in the file cannot be projected to the schema type
         """
+        from concurrent.futures import ThreadPoolExecutor
+
         deletes_per_file = _read_all_delete_files(self._io, tasks)
-        return self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file)
+
+        if concurrent_tasks is not None:
+            with ThreadPoolExecutor(max_workers=concurrent_tasks) as pool:
+                for batches in pool.map(
+                    lambda task: list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)), tasks
+                ):
+                    for batch in batches:
+                        yield batch
+
+        else:
+            yield from self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file)
 
     def _record_batches_from_scan_tasks_and_deletes(
         self, tasks: Iterable[FileScanTask], deletes_per_file: Dict[str, List[ChunkedArray]]
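
The new concurrent path fans each FileScanTask out to a thread pool and stitches the per-task batch lists back into one ordered iterator. Below is a minimal standalone sketch of that pattern; read_task is a placeholder standing in for _record_batches_from_scan_tasks_and_deletes, and none of the names here are PyIceberg APIs:

from concurrent.futures import ThreadPoolExecutor

def read_task(task):
    # Stand-in for _record_batches_from_scan_tasks_and_deletes([task], ...).
    # Each task's batches are materialized into a list so the pool can hand
    # them back as a single unit per task.
    return [f"batch-{task}-{n}" for n in range(2)]

def to_record_batches(tasks, concurrent_tasks=None):
    if concurrent_tasks is not None:
        with ThreadPoolExecutor(max_workers=concurrent_tasks) as pool:
            # pool.map preserves input order, so batches are still yielded
            # in plan-file order even though tasks run concurrently.
            for task_batches in pool.map(read_task, tasks):
                yield from task_batches
    else:
        for task in tasks:
            yield from read_task(task)

print(list(to_record_batches(range(3), concurrent_tasks=2)))

The trade-off in this design: wrapping the per-task read in list(...) buffers an entire task's batches in memory before any are yielded, which is the cost of letting pool.map return one result per task.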

pyiceberg/table/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1864,7 +1864,7 @@ def to_arrow(self) -> pa.Table:
             self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
         ).to_table(self.plan_files())
 
-    def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
+    def to_arrow_batch_reader(self, concurrent_tasks: Optional[int] = None) -> pa.RecordBatchReader:
         """Return an Arrow RecordBatchReader from this DataScan.
 
         For large results, using a RecordBatchReader requires less memory than
@@ -1882,7 +1882,7 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
         target_schema = schema_to_pyarrow(self.projection())
         batches = ArrowScan(
             self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
-        ).to_record_batches(self.plan_files())
+        ).to_record_batches(self.plan_files(), concurrent_tasks=concurrent_tasks)
 
         return pa.RecordBatchReader.from_batches(
             target_schema,
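
A hedged sketch of how a caller might use the updated entry point, assuming a table already registered in a catalog; the catalog name "default", the identifier "db.events", and process_batch are illustrative placeholders, not part of this commit:

from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")           # illustrative catalog name
table = catalog.load_table("db.events")     # illustrative table identifier
reader = table.scan().to_arrow_batch_reader(concurrent_tasks=8)  # new keyword from this commit
for batch in reader:                        # pa.RecordBatchReader is iterable
    process_batch(batch)                    # placeholder for user code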
