@@ -111,6 +111,36 @@ def _transform_literal(func: Callable[[L], L], lit: Literal[L]) -> Literal[L]:
111111 return literal (func (lit .value ))
112112
113113
114+ def _pyiceberg_transform_wrapper (
115+ transform_func : Callable [["ArrayLike" , Any ], "ArrayLike" ],
116+ * args : Any ,
117+ expected_type : Optional ["pa.DataType" ] = None ,
118+ ) -> Callable [["ArrayLike" ], "ArrayLike" ]:
119+ try :
120+ import pyarrow as pa
121+ except ModuleNotFoundError as e :
122+ raise ModuleNotFoundError ("For partition transforms, PyArrow needs to be installed" ) from e
123+
124+ def _transform (array : "ArrayLike" ) -> "ArrayLike" :
125+ def _cast_if_needed (arr : "ArrayLike" ) -> "ArrayLike" :
126+ if expected_type is not None :
127+ return arr .cast (expected_type )
128+ else :
129+ return arr
130+
131+ if isinstance (array , pa .Array ):
132+ return _cast_if_needed (transform_func (array , * args ))
133+ elif isinstance (array , pa .ChunkedArray ):
134+ result_chunks = []
135+ for arr in array .iterchunks ():
136+ result_chunks .append (_cast_if_needed (transform_func (arr , * args )))
137+ return pa .chunked_array (result_chunks )
138+ else :
139+ raise ValueError (f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found { type (array )} " )
140+
141+ return _transform
142+
143+
114144class Transform (IcebergRootModel [str ], ABC , Generic [S , T ]):
115145 """Transform base class for concrete transforms.
116146
@@ -175,27 +205,6 @@ def supports_pyarrow_transform(self) -> bool:
@abstractmethod
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a callable that applies this transform to a PyArrow array.

    Args:
        source: Iceberg type of the values being transformed.

    Returns:
        A function taking a PyArrow array and returning a PyArrow array.
        Concrete subclasses must provide the implementation.
    """
    ...
177207
178- def _pyiceberg_transform_wrapper (
179- self , transform_func : Callable [["ArrayLike" , Any ], "ArrayLike" ], * args : Any
180- ) -> Callable [["ArrayLike" ], "ArrayLike" ]:
181- try :
182- import pyarrow as pa
183- except ModuleNotFoundError as e :
184- raise ModuleNotFoundError ("For bucket/truncate transforms, PyArrow needs to be installed" ) from e
185-
186- def _transform (array : "ArrayLike" ) -> "ArrayLike" :
187- if isinstance (array , pa .Array ):
188- return transform_func (array , * args )
189- elif isinstance (array , pa .ChunkedArray ):
190- result_chunks = []
191- for arr in array .iterchunks ():
192- result_chunks .append (transform_func (arr , * args ))
193- return pa .chunked_array (result_chunks )
194- else :
195- raise ValueError (f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found { type (array )} " )
196-
197- return _transform
198-
199208
200209def parse_transform (v : Any ) -> Transform [Any , Any ]:
201210 if isinstance (v , str ):
@@ -375,7 +384,7 @@ def __repr__(self) -> str:
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a callable that bucket-transforms PyArrow arrays.

    The work is delegated to ``pyiceberg_core``'s ``bucket`` transform, called
    with ``self._num_buckets`` as its extra argument.

    Args:
        source: Iceberg source type; not inspected by this implementation.

    Returns:
        A function accepting a ``pa.Array`` or ``pa.ChunkedArray``.
    """
    # Imported lazily so pyiceberg_core is only needed when the transform is used.
    from pyiceberg_core import transform as pyiceberg_core_transform

    return _pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)
379388
380389 @property
381390 def supports_pyarrow_transform (self ) -> bool :
@@ -501,22 +510,9 @@ def __repr__(self) -> str:
501510
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a callable that year-transforms PyArrow arrays.

    The work is delegated to ``pyiceberg_core``'s ``year`` transform, and each
    result is cast to ``int32``.

    Args:
        source: Iceberg source type; not inspected by this implementation.

    Returns:
        A function accepting a ``pa.Array`` or ``pa.ChunkedArray``.
    """
    # Imported lazily so the optional dependencies are only needed on use.
    import pyarrow as pa
    from pyiceberg_core import transform as pyiceberg_core_transform

    return _pyiceberg_transform_wrapper(pyiceberg_core_transform.year, expected_type=pa.int32())
520516
521517
522518class MonthTransform (TimeTransform [S ]):
@@ -575,28 +571,9 @@ def __repr__(self) -> str:
575571
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a callable that month-transforms PyArrow arrays.

    The work is delegated to ``pyiceberg_core``'s ``month`` transform, and each
    result is cast to ``int32``.

    Args:
        source: Iceberg source type; not inspected by this implementation.

    Returns:
        A function accepting a ``pa.Array`` or ``pa.ChunkedArray``.
    """
    # Imported lazily so the optional dependencies are only needed on use.
    import pyarrow as pa
    from pyiceberg_core import transform as pyiceberg_core_transform

    return _pyiceberg_transform_wrapper(pyiceberg_core_transform.month, expected_type=pa.int32())
600577
601578
602579class DayTransform (TimeTransform [S ]):
@@ -663,22 +640,9 @@ def __repr__(self) -> str:
663640
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a callable that day-transforms PyArrow arrays.

    The work is delegated to ``pyiceberg_core``'s ``day`` transform, and each
    result is cast to ``int32``.

    Args:
        source: Iceberg source type; not inspected by this implementation.

    Returns:
        A function accepting a ``pa.Array`` or ``pa.ChunkedArray``.
    """
    # Imported lazily so the optional dependencies are only needed on use.
    import pyarrow as pa
    from pyiceberg_core import transform as pyiceberg_core_transform

    return _pyiceberg_transform_wrapper(pyiceberg_core_transform.day, expected_type=pa.int32())
682646
683647
684648class HourTransform (TimeTransform [S ]):
@@ -728,21 +692,9 @@ def __repr__(self) -> str:
728692 return "HourTransform()"
729693
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a callable that hour-transforms PyArrow arrays.

    The work is delegated to ``pyiceberg_core``'s ``hour`` transform. No
    ``expected_type`` is passed, so the result keeps whatever type that
    transform produces.

    Args:
        source: Iceberg source type; not inspected by this implementation.

    Returns:
        A function accepting a ``pa.Array`` or ``pa.ChunkedArray``.
    """
    # Imported lazily so pyiceberg_core is only needed when the transform is used.
    from pyiceberg_core import transform as pyiceberg_core_transform

    return _pyiceberg_transform_wrapper(pyiceberg_core_transform.hour)
746698
747699
748700def _base64encode (buffer : bytes ) -> str :
@@ -965,7 +917,7 @@ def __repr__(self) -> str:
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a callable that truncate-transforms PyArrow arrays.

    The work is delegated to ``pyiceberg_core``'s ``truncate`` transform, called
    with ``self._width`` as its extra argument.

    Args:
        source: Iceberg source type; not inspected by this implementation.

    Returns:
        A function accepting a ``pa.Array`` or ``pa.ChunkedArray``.
    """
    # Imported lazily so pyiceberg_core is only needed when the transform is used.
    from pyiceberg_core import transform as pyiceberg_core_transform

    return _pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)
969921
970922 @property
971923 def supports_pyarrow_transform (self ) -> bool :
0 commit comments