|
21 | 21 | from dataclasses import dataclass |
22 | 22 | from datetime import date, datetime, time |
23 | 23 | from functools import cached_property, singledispatch |
24 | | -from typing import Annotated, Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union |
| 24 | +from typing import Annotated, Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union |
25 | 25 | from urllib.parse import quote_plus |
26 | 26 |
|
27 | 27 | from pydantic import ( |
@@ -272,6 +272,56 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre |
272 | 272 | T = TypeVar("T") |
273 | 273 |
|
274 | 274 |
|
| 275 | +class PartitionMap(Generic[T]): |
| 276 | + _specs: dict[int, PartitionSpec] |
| 277 | + _partition_maps: dict[int, dict[Record, T]] |
| 278 | + |
| 279 | + def __init__(self, specs: dict[int, PartitionSpec]): |
| 280 | + self._specs = specs |
| 281 | + self._partition_maps = {} |
| 282 | + |
| 283 | + def __len__(self) -> int: |
| 284 | + """Return the length of the partition map. |
| 285 | +
|
| 286 | + Returns: |
| 287 | + length of _partition_maps |
| 288 | + """ |
| 289 | + return len(self._partition_maps.values()) |
| 290 | + |
| 291 | + def is_empty(self) -> bool: |
| 292 | + return len(self._partition_maps.values()) == 0 |
| 293 | + |
| 294 | + def contains_key(self, spec_id: int, struct: Record) -> bool: |
| 295 | + try: |
| 296 | + return struct in self._partition_maps[spec_id] |
| 297 | + except KeyError as _: |
| 298 | + return False |
| 299 | + |
| 300 | + def contains_value(self, value: T) -> bool: |
| 301 | + return value in self._partition_maps.values() |
| 302 | + |
| 303 | + def get(self, spec_id: int, struct: Record) -> Optional[T]: |
| 304 | + if partition_map := self._partition_maps.get(spec_id): |
| 305 | + return partition_map.get(struct) |
| 306 | + return None |
| 307 | + |
| 308 | + def put(self, spec_id: int, struct: Record, value: T) -> None: |
| 309 | + if _ := self._specs.get(spec_id): |
| 310 | + self._partition_maps[spec_id] = {struct: value} |
| 311 | + |
| 312 | + def compute_if_absent(self, spec_id: int, struct: Record, value: T, value_factory: Callable[[], T]) -> T: |
| 313 | + if partition_map := self._partition_maps.get(spec_id): |
| 314 | + if val := partition_map.get(struct): |
| 315 | + return val |
| 316 | + return value_factory() |
| 317 | + |
| 318 | + def values(self) -> list[T]: |
| 319 | + result: list[T] = [] |
| 320 | + for partition_map in self._partition_maps.values(): |
| 321 | + result.extend(partition_map.values()) |
| 322 | + return result |
| 323 | + |
| 324 | + |
275 | 325 | class PartitionSpecVisitor(Generic[T], ABC): |
276 | 326 | @abstractmethod |
277 | 327 | def identity(self, field_id: int, source_name: str, source_id: int) -> T: |
|
0 commit comments