
Commit 119d92f

Simplify to_arrow to use the optimized to_record_batches

1 parent: f8acdb0

3 files changed (31 additions & 56 deletions)

pyiceberg/io/pyarrow.py (29 additions & 50 deletions)

@@ -1570,47 +1570,17 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
             ResolveError: When a required field cannot be found in the file
             ValueError: When a field type in the file cannot be projected to the schema type
         """
-        deletes_per_file = _read_all_delete_files(self._io, tasks)
-        executor = ExecutorFactory.get_or_create()
-
-        def _table_from_scan_task(task: FileScanTask) -> pa.Table:
-            batches = list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file))
-            if len(batches) > 0:
-                return pa.Table.from_batches(batches)
-            else:
-                return None
-
-        futures = [
-            executor.submit(
-                _table_from_scan_task,
-                task,
-            )
-            for task in tasks
-        ]
-        total_row_count = 0
-        # for consistent ordering, we need to maintain future order
-        futures_index = {f: i for i, f in enumerate(futures)}
-        completed_futures: SortedList[Future[pa.Table]] = SortedList(iterable=[], key=lambda f: futures_index[f])
-        for future in concurrent.futures.as_completed(futures):
-            completed_futures.add(future)
-            if table_result := future.result():
-                total_row_count += len(table_result)
-            # stop early if limit is satisfied
-            if self._limit is not None and total_row_count >= self._limit:
-                break
-
-        # by now, we've either completed all tasks or satisfied the limit
-        if self._limit is not None:
-            _ = [f.cancel() for f in futures if not f.done()]
-
-        tables = [f.result() for f in completed_futures if f.result()]

         arrow_schema = schema_to_pyarrow(self._projected_schema, include_field_ids=False)

-        if len(tables) < 1:
+        batches = self.to_record_batches(tasks)
+        try:
+            first_batch = next(batches)
+        except StopIteration:
+            # Empty
             return pa.Table.from_batches([], schema=arrow_schema)

-        result = pa.concat_tables(tables, promote_options="permissive")
+        result = pa.Table.from_batches(itertools.chain([first_batch], batches))

         if property_as_bool(self._io.properties, PYARROW_USE_LARGE_TYPES_ON_READ, False):
             deprecation_message(
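The rewritten to_table drops the futures plumbing entirely: it drains to_record_batches, peeking one batch ahead so that an empty scan can still return a table with the projected schema. A minimal sketch of that peek-and-chain pattern in plain PyArrow (the make_batches helper and the schema are illustrative, not pyiceberg code):

```python
import itertools
from typing import Iterator

import pyarrow as pa

arrow_schema = pa.schema([("id", pa.int64())])

def make_batches() -> Iterator[pa.RecordBatch]:
    # Stand-in for an ArrowScan's batch iterator.
    for start in (0, 3):
        yield pa.RecordBatch.from_pylist(
            [{"id": i} for i in range(start, start + 3)], schema=arrow_schema
        )

batches = make_batches()
try:
    first_batch = next(batches)
except StopIteration:
    # Empty scan: still return a zero-row table with the right schema.
    result = pa.Table.from_batches([], schema=arrow_schema)
else:
    # Re-attach the peeked batch so no rows are lost.
    result = pa.Table.from_batches(itertools.chain([first_batch], batches))

print(result.num_rows)  # 6
```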
@@ -1620,13 +1590,10 @@ def _table_from_scan_task(task: FileScanTask) -> pa.Table:
             )
             result = result.cast(arrow_schema)

-        if self._limit is not None:
-            return result.slice(0, self._limit)
-
         return result

     def to_record_batches(
-        self, tasks: Iterable[FileScanTask], concurrent_tasks: Optional[int] = None
+        self, tasks: Iterable[FileScanTask]
     ) -> Iterator[pa.RecordBatch]:
         """Scan the Iceberg table and return an Iterator[pa.RecordBatch].

@@ -1636,7 +1603,6 @@ def to_record_batches(

         Args:
             tasks: FileScanTasks representing the data files and delete files to read from.
-            concurrent_tasks: number of concurrent tasks

         Returns:
             An Iterator of PyArrow RecordBatches.
@@ -1648,16 +1614,29 @@
         """
         deletes_per_file = _read_all_delete_files(self._io, tasks)

-        if concurrent_tasks is not None:
-            with ExecutorFactory.create(max_workers=concurrent_tasks) as pool:
-                for batches in pool.map(
-                    lambda task: list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)), tasks
-                ):
-                    for batch in batches:
-                        yield batch
+        total_row_count = 0
+        executor = ExecutorFactory.get_or_create()

-        else:
-            return self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file)
+        with executor as pool:
+            should_stop = False
+            for batches in pool.map(
+                lambda task: list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file)), tasks
+            ):
+                for batch in batches:
+                    current_batch_size = len(batch)
+                    if self._limit is not None:
+                        if total_row_count + current_batch_size >= self._limit:
+                            yield batch.slice(0, self._limit - total_row_count)
+
+                            # This break will also cancel all tasks in the Pool
+                            should_stop = True
+                            break
+
+                    yield batch
+                    total_row_count += current_batch_size
+
+                if should_stop:
+                    break

     def _record_batches_from_scan_tasks_and_deletes(
         self, tasks: Iterable[FileScanTask], deletes_per_file: Dict[str, List[ChunkedArray]]
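Limit handling moves with it: instead of slicing the final table, to_record_batches now slices the batch that crosses the limit and stops reading. The same streaming-truncation logic can be sketched as a standalone generator (the limited wrapper is a hypothetical name, not pyiceberg API):

```python
from typing import Iterator, Optional

import pyarrow as pa

def limited(batches: Iterator[pa.RecordBatch], limit: Optional[int]) -> Iterator[pa.RecordBatch]:
    """Yield batches until `limit` rows have been produced, slicing the last one."""
    total_row_count = 0
    for batch in batches:
        current_batch_size = len(batch)
        if limit is not None and total_row_count + current_batch_size >= limit:
            # This batch reaches or crosses the limit: truncate and stop.
            yield batch.slice(0, limit - total_row_count)
            return
        yield batch
        total_row_count += current_batch_size

schema = pa.schema([("id", pa.int64())])
source = (
    pa.RecordBatch.from_pylist([{"id": i} for i in range(n, n + 4)], schema=schema)
    for n in (0, 4, 8)
)
print(sum(len(b) for b in limited(source, limit=6)))  # 6
```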

pyiceberg/table/__init__.py (2 additions & 2 deletions)

@@ -1864,7 +1864,7 @@ def to_arrow(self) -> pa.Table:
             self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
         ).to_table(self.plan_files())

-    def to_arrow_batch_reader(self, concurrent_tasks: Optional[int] = None) -> pa.RecordBatchReader:
+    def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
         """Return an Arrow RecordBatchReader from this DataScan.

         For large results, using a RecordBatchReader requires less memory than

@@ -1882,7 +1882,7 @@ def to_arrow_batch_reader(self, concurrent_tasks: Optional[int] = None) -> pa.RecordBatchReader:
         target_schema = schema_to_pyarrow(self.projection())
         batches = ArrowScan(
             self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
-        ).to_record_batches(self.plan_files(), concurrent_tasks=concurrent_tasks)
+        ).to_record_batches(self.plan_files())

         return pa.RecordBatchReader.from_batches(
             target_schema,
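From the caller's side the only change is the dropped concurrent_tasks keyword; parallelism comes from the shared executor and the scan's limit is enforced inside the batch stream. A hedged usage sketch, assuming a configured catalog and a hypothetical table name:

```python
import pyarrow as pa

from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")        # assumes a configured "default" catalog
table = catalog.load_table("db.events")  # hypothetical table identifier

# No concurrent_tasks argument anymore; the limit is applied batch-by-batch.
reader: pa.RecordBatchReader = table.scan(limit=1_000).to_arrow_batch_reader()
for batch in reader:
    ...  # process each pyarrow.RecordBatch incrementally
```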

pyiceberg/utils/concurrent.py (0 additions & 4 deletions)

@@ -38,7 +38,3 @@ def get_or_create() -> Executor:
     def max_workers() -> Optional[int]:
         """Return the max number of workers configured."""
         return Config().get_int("max-workers")
-
-    @staticmethod
-    def create(max_workers: int) -> Executor:
-        return ThreadPoolExecutor(max_workers=max_workers)
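With the ad-hoc create(max_workers=...) constructor gone, every read path shares the executor returned by get_or_create(), sized by the max-workers config key. A rough sketch of that singleton shape, simplified from the remaining ExecutorFactory methods (the config lookup is stubbed out, so this is not the verbatim pyiceberg source):

```python
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import Optional

class ExecutorFactory:
    _instance: Optional[Executor] = None

    @staticmethod
    def max_workers() -> Optional[int]:
        # Stub for Config().get_int("max-workers"); None lets the pool pick a default.
        return None

    @staticmethod
    def get_or_create() -> Executor:
        """Return the shared executor, creating it on first use."""
        if ExecutorFactory._instance is None:
            ExecutorFactory._instance = ThreadPoolExecutor(max_workers=ExecutorFactory.max_workers())
        return ExecutorFactory._instance
```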
