Skip to content

Commit f3babd3

Browse files
committed
small fixes and tests for PartitionMap
1 parent 8457086 commit f3babd3

3 files changed

Lines changed: 81 additions & 19 deletions

File tree

pyiceberg/partitioning.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre
274274

275275
class PartitionMap(Generic[T]):
276276
_specs: dict[int, PartitionSpec]
277-
_partition_maps: dict[int, dict[Record, T]]
277+
_partition_maps: dict[int, dict[Record | None, T]]
278278

279279
def __init__(self, specs: dict[int, PartitionSpec]):
280280
self._specs = specs
@@ -286,34 +286,38 @@ def __len__(self) -> int:
286286
Returns:
287287
length of _partition_maps
288288
"""
289-
return len(self._partition_maps.values())
289+
return len(self.values())
290290

291291
def is_empty(self) -> bool:
292292
return len(self._partition_maps.values()) == 0
293293

294294
def contains_key(self, spec_id: int, struct: Record) -> bool:
295-
try:
296-
return struct in self._partition_maps[spec_id]
297-
except KeyError as _:
298-
return False
295+
return self._partition_maps.get(spec_id) is not None
299296

300297
def contains_value(self, value: T) -> bool:
301298
return value in self._partition_maps.values()
302299

303-
def get(self, spec_id: int, struct: Record) -> Optional[T]:
300+
def get(self, spec_id: int, struct: Record | None) -> Optional[T]:
304301
if partition_map := self._partition_maps.get(spec_id):
305-
return partition_map.get(struct)
302+
if result := partition_map.get(struct):
303+
return result
306304
return None
307305

308-
def put(self, spec_id: int, struct: Record, value: T) -> None:
306+
def put(self, spec_id: int, struct: Record | None, value: T) -> None:
309307
if _ := self._specs.get(spec_id):
310-
self._partition_maps[spec_id] = {struct: value}
308+
if spec_id not in self._partition_maps:
309+
self._partition_maps[spec_id] = {struct: value}
310+
else:
311+
self._partition_maps[spec_id][struct] = value
311312

312313
def compute_if_absent(self, spec_id: int, struct: Record, value: T, value_factory: Callable[[], T]) -> T:
313-
if partition_map := self._partition_maps.get(spec_id):
314-
if val := partition_map.get(struct):
315-
return val
316-
return value_factory()
314+
partition_map = self._partition_maps.setdefault(spec_id, {})
315+
if struct in partition_map:
316+
return partition_map[struct]
317+
318+
value = value_factory()
319+
partition_map[struct] = value
320+
return value
317321

318322
def values(self) -> list[T]:
319323
result: list[T] = []

pyiceberg/table/update/validate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def _index_if_needed(self) -> None:
5656
self._indexed = True
5757

5858
def add_entry(self, entry: ManifestEntry) -> None:
59-
if self._buffer is None:
59+
if self._indexed:
6060
raise Exception("Can't add files upon indexing.")
6161
self._buffer.append(entry)
6262

@@ -99,7 +99,7 @@ def _index_if_needed(self) -> None:
9999

100100
def add_entry(self, spec: PartitionSpec, entry: ManifestEntry) -> None:
101101
# TODO: Equality deletes should consider the spec to get the equality fields
102-
if self._buffer is None:
102+
if self._indexed:
103103
raise Exception("Can't add files upon indexing.")
104104
self._buffer.append(entry)
105105

tests/table/test_partitioning.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
import pytest
2323

24-
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
24+
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionMap, PartitionSpec
2525
from pyiceberg.schema import Schema
2626
from pyiceberg.transforms import (
2727
BucketTransform,
@@ -225,5 +225,63 @@ def test_deserialize_partition_field_v3() -> None:
225225
assert field == PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=19), name="str_truncate")
226226

227227

228-
def test_partition_map() -> None:
229-
pass
228+
@pytest.fixture
229+
def specs_set() -> dict[int, PartitionSpec]:
230+
return {
231+
0: UNPARTITIONED_PARTITION_SPEC,
232+
1: PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=DayTransform(), name="dayPartition"), spec_id=1),
233+
2: PartitionSpec(
234+
PartitionField(source_id=1, field_id=1000, transform=DayTransform(), name="dayPartition"),
235+
PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="identityPartition"),
236+
spec_id=2,
237+
),
238+
}
239+
240+
241+
def test_empty_partition_map() -> None:
242+
specs: dict[int, PartitionSpec] = {UNPARTITIONED_PARTITION_SPEC.spec_id: UNPARTITIONED_PARTITION_SPEC}
243+
partition_map: PartitionMap[str] = PartitionMap(specs)
244+
assert partition_map.is_empty()
245+
assert len(partition_map) == 0
246+
assert not partition_map.contains_key(1, Record(1))
247+
assert len(partition_map.values()) == 0
248+
249+
250+
def test_size_partition_map(specs_set: dict[int, PartitionSpec]) -> None:
251+
partition_map: PartitionMap[str] = PartitionMap(specs_set)
252+
partition_map.put(UNPARTITIONED_PARTITION_SPEC.spec_id, None, "v1")
253+
partition_map.put(specs_set[1].spec_id, Record("aaa"), "v2")
254+
partition_map.put(specs_set[1].spec_id, Record("bbb"), "v3")
255+
partition_map.put(specs_set[2].spec_id, Record("ccc", 2), "v4")
256+
assert not partition_map.is_empty()
257+
assert len(partition_map) == 4
258+
# assert partition_map.get(UNPARTITIONED_PARTITION_SPEC.spec_id, None) == "v1"
259+
260+
261+
def test_put_and_get_partition_map(specs_set: dict[int, PartitionSpec]) -> None:
262+
partition_map: PartitionMap[str] = PartitionMap(specs_set)
263+
partition_map.put(UNPARTITIONED_PARTITION_SPEC.spec_id, None, "v1")
264+
partition_map.put(specs_set[1].spec_id, Record("aaa", 1), "v2")
265+
assert partition_map.get(UNPARTITIONED_PARTITION_SPEC.spec_id, None) == "v1"
266+
assert partition_map.get(specs_set[1].spec_id, Record("aaa", 1)) == "v2"
267+
268+
269+
def test_values_partition_map(specs_set: dict[int, PartitionSpec]) -> None:
270+
partition_map: PartitionMap[str] = PartitionMap(specs_set)
271+
partition_map.put(UNPARTITIONED_PARTITION_SPEC.spec_id, None, "v1")
272+
partition_map.put(specs_set[1].spec_id, Record("aaa"), "v2")
273+
partition_map.put(specs_set[1].spec_id, Record("bbb"), "v3")
274+
partition_map.put(specs_set[2].spec_id, Record("ccc", 2), "v4")
275+
assert partition_map.values() == ["v1", "v2", "v3", "v4"]
276+
277+
278+
def test_compute_if_absent_partition_map(specs_set: dict[int, PartitionSpec]) -> None:
279+
partition_map: PartitionMap[str] = PartitionMap(specs_set)
280+
281+
result1 = partition_map.compute_if_absent(specs_set[1].spec_id, Record("a"), "v1", lambda: "v1")
282+
assert result1 == "v1"
283+
assert partition_map.get(specs_set[1].spec_id, Record("a")) == "v1"
284+
285+
result2 = partition_map.compute_if_absent(specs_set[1].spec_id, Record("a"), "v2", lambda: "v2")
286+
assert result2 == "v1"
287+
assert partition_map.get(specs_set[1].spec_id, Record("a")) == "v1"

0 commit comments

Comments
 (0)