@@ -1136,6 +1136,13 @@ def _read_deletes(io: FileIO, data_file: DataFile) -> dict[str, pa.ChunkedArray]
         raise ValueError(f"Delete file format not supported: {data_file.file_format}")
 
 
+def _read_equality_deletes(io: FileIO, delete_file: DataFile) -> pa.Table:
+    arrow_format = _get_file_format(delete_file.file_format, pre_buffer=True, buffer_size=ONE_MEGABYTE)
+    with io.new_input(delete_file.file_path).open() as fi:
+        fragment = arrow_format.make_fragment(fi)
+        return ds.Scanner.from_fragment(fragment=fragment).to_table()
+
+
 def _combine_positional_deletes(positional_deletes: list[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array:
     if len(positional_deletes) == 1:
         all_chunks = positional_deletes[0]
@@ -1609,6 +1616,7 @@ def _task_to_record_batches(
     table_schema: Schema,
     projected_field_ids: set[int],
     positional_deletes: list[ChunkedArray] | None,
+    equality_deletes: list[tuple[set[int], pa.Table]] | None,
     case_sensitive: bool,
     name_mapping: NameMapping | None = None,
     partition_spec: PartitionSpec | None = None,
@@ -1643,14 +1651,20 @@ def _task_to_record_batches(
             bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
             pyarrow_filter = expression_to_pyarrow(bound_file_filter, file_schema)
 
-        file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
+        # Ensure equality delete columns are also projected
+        all_projected_field_ids = projected_field_ids.copy()
+        if equality_deletes:
+            for eq_ids, _ in equality_deletes:
+                all_projected_field_ids.update(eq_ids)
+
+        file_project_schema = prune_columns(file_schema, all_projected_field_ids, select_full_types=False)
 
         fragment_scanner = ds.Scanner.from_fragment(
             fragment=fragment,
             schema=physical_schema,
             # This will push down the query to Arrow.
-            # But in case there are positional deletes, we have to apply them first
-            filter=pyarrow_filter if not positional_deletes else None,
+            # But in case there are positional or equality deletes, we have to apply them first
+            filter=pyarrow_filter if not positional_deletes and not equality_deletes else None,
             columns=[col.name for col in file_project_schema.columns],
         )
 
@@ -1666,6 +1680,38 @@ def _task_to_record_batches(
                 indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
                 current_batch = current_batch.take(indices)
 
+            if current_batch.num_rows > 0 and equality_deletes:
+                for eq_ids, eq_delete_table in equality_deletes:
+                    try:
+                        eq_file_schema = pyarrow_to_schema(
+                            eq_delete_table.schema,
+                            name_mapping=name_mapping,
+                            format_version=format_version,
+                        )
+
+                        # Rename delete-file columns whose names differ from the current table schema
+                        rename_map = {}
+                        for field_id in eq_ids:
+                            file_name = eq_file_schema.find_column_name(field_id)
+                            current_name = table_schema.find_column_name(field_id)
+                            if file_name != current_name:
+                                rename_map[file_name] = current_name
+
+                        if rename_map:
+                            eq_delete_table = eq_delete_table.rename_columns(
+                                [rename_map.get(name, name) for name in eq_delete_table.column_names]
+                            )
+
+                        # Drop matching rows with a left anti-join on the equality columns
+                        join_keys = [table_schema.find_column_name(field_id) for field_id in eq_ids]
+                        current_table = pa.Table.from_batches([current_batch])
+                        current_table = current_table.join(eq_delete_table, keys=join_keys, join_type="left anti")
+
+                        if current_table.num_rows == 0:
+                            # to_batches() can return an empty list for an empty table, so build an empty batch instead
+                            current_batch = current_batch.slice(0, 0)
+                            break
+                        # combine_chunks() guarantees a single batch, so no rows are dropped here
+                        current_batch = current_table.combine_chunks().to_batches()[0]
+                    except (ValueError, ResolveError):
+                        continue
+
             # skip empty batches
             if current_batch.num_rows == 0:
                 continue
@@ -1691,23 +1737,57 @@ def _task_to_record_batches(
             )
 
 
-def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> dict[str, list[ChunkedArray]]:
-    deletes_per_file: dict[str, list[ChunkedArray]] = {}
+def _read_all_delete_files(
+    io: FileIO, tasks: Iterable[FileScanTask]
+) -> tuple[dict[str, list[ChunkedArray]], dict[str, list[tuple[set[int], pa.Table]]]]:
+    pos_deletes_per_file: dict[str, list[ChunkedArray]] = {}
+    eq_deletes_per_file: dict[str, list[tuple[set[int], pa.Table]]] = {}
+
     unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks]))
     if len(unique_deletes) > 0:
+        unique_pos_deletes = {d for d in unique_deletes if d.content == DataFileContent.POSITION_DELETES}
+        unique_eq_deletes = {d for d in unique_deletes if d.content == DataFileContent.EQUALITY_DELETES}
+
         executor = ExecutorFactory.get_or_create()
-        deletes_per_files: Iterator[dict[str, ChunkedArray]] = executor.map(
-            lambda args: _read_deletes(*args),
-            [(io, delete_file) for delete_file in unique_deletes],
-        )
-        for delete in deletes_per_files:
-            for file, arr in delete.items():
-                if file in deletes_per_file:
-                    deletes_per_file[file].append(arr)
-                else:
-                    deletes_per_file[file] = [arr]
 
-    return deletes_per_file
+        if len(unique_pos_deletes) > 0:
+            pos_deletes: Iterator[dict[str, ChunkedArray]] = executor.map(
+                lambda args: _read_deletes(*args),
+                [(io, delete_file) for delete_file in unique_pos_deletes],
+            )
+            for delete in pos_deletes:
+                for file, arr in delete.items():
+                    if file in pos_deletes_per_file:
+                        pos_deletes_per_file[file].append(arr)
+                    else:
+                        pos_deletes_per_file[file] = [arr]
+
+        if len(unique_eq_deletes) > 0:
+            # We map each unique eq delete file location to its equality IDs and its loaded table
+            eq_delete_files = list(unique_eq_deletes)
+            loaded_eq_tables = executor.map(
+                lambda args: _read_equality_deletes(*args),
+                [(io, d) for d in eq_delete_files],
+            )
+            eq_deletes_tables: dict[str, tuple[set[int], pa.Table]] = {
+                d.file_path: (set(d.equality_ids) if d.equality_ids else set(), table)
+                for d, table in zip(eq_delete_files, loaded_eq_tables, strict=True)
+            }
+
+            # Map eq deletes to each task's data file path
+            for task in tasks:
+                eq_deletes_for_task = [
+                    eq_deletes_tables[d.file_path]
+                    for d in task.delete_files
+                    if d.content == DataFileContent.EQUALITY_DELETES
+                ]
+                if eq_deletes_for_task:
+                    eq_deletes_per_file[task.file.file_path] = eq_deletes_for_task
+
+    return pos_deletes_per_file, eq_deletes_per_file
 
 
 class ArrowScan:
@@ -1807,7 +1887,7 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.Record
             ResolveError: When a required field cannot be found in the file
             ValueError: When a field type in the file cannot be projected to the schema type
         """
-        deletes_per_file = _read_all_delete_files(self._io, tasks)
+        pos_deletes_per_file, eq_deletes_per_file = _read_all_delete_files(self._io, tasks)
 
         total_row_count = 0
         executor = ExecutorFactory.get_or_create()
@@ -1816,7 +1896,7 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
             # Materialize the iterator here to ensure execution happens within the executor.
             # Otherwise, the iterator would be lazily consumed later (in the main thread),
             # defeating the purpose of using executor.map.
-            return list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file))
+            return list(self._record_batches_from_scan_tasks_and_deletes([task], pos_deletes_per_file, eq_deletes_per_file))
 
         limit_reached = False
         for batches in executor.map(batches_for_task, tasks):
@@ -1836,7 +1916,10 @@ def batches_for_task(task: FileScanTask) -> list[pa.RecordBatch]:
                     break
 
     def _record_batches_from_scan_tasks_and_deletes(
-        self, tasks: Iterable[FileScanTask], deletes_per_file: dict[str, list[ChunkedArray]]
+        self,
+        tasks: Iterable[FileScanTask],
+        pos_deletes_per_file: dict[str, list[ChunkedArray]],
+        eq_deletes_per_file: dict[str, list[tuple[set[int], pa.Table]]],
     ) -> Iterator[pa.RecordBatch]:
         total_row_count = 0
         for task in tasks:
@@ -1849,7 +1932,8 @@ def _record_batches_from_scan_tasks_and_deletes(
                 self._projected_schema,
                 self._table_metadata.schema(),
                 self._projected_field_ids,
-                deletes_per_file.get(task.file.file_path),
+                eq_deletes_per_file.get(task.file.file_path),
+                pos_deletes_per_file.get(task.file.file_path),
                 self._case_sensitive,
                 self._table_metadata.name_mapping(),
                 self._table_metadata.specs().get(task.file.spec_id),
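
For reference, the row-level filtering added in _task_to_record_batches comes down to a PyArrow "left anti" join between each data batch and the equality-delete table. A minimal standalone sketch of that idea (the column names and sample values below are hypothetical, not taken from the patch):

import pyarrow as pa

# Hypothetical data batch and equality-delete table, keyed on the "id" column.
data_batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3, 4], "val": ["a", "b", "c", "d"]})
eq_deletes = pa.table({"id": [2, 4]})

# Rows whose key appears in the delete table are dropped by the left anti-join,
# mirroring what the patch does per record batch.
remaining = pa.Table.from_batches([data_batch]).join(eq_deletes, keys=["id"], join_type="left anti")
print(remaining.sort_by("id").to_pydict())  # {'id': [1, 3], 'val': ['a', 'c']}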