diff --git a/docs/api_reference.md b/docs/api_reference.md index 7e747318..7c710189 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -32,6 +32,12 @@ and :members: ``` +### Grid Overrides + +```{eval-rst} +.. autopydantic_model:: mdio.GridOverrides +``` + ## Core Functionality ### Dimensions diff --git a/docs/guides/grid_overrides.md b/docs/guides/grid_overrides.md index f44d5654..783afbbd 100644 --- a/docs/guides/grid_overrides.md +++ b/docs/guides/grid_overrides.md @@ -10,6 +10,30 @@ Grid overrides are transformations applied during SEG-Y import that modify how t When importing SEG-Y data, MDIO maps trace header fields to dataset dimensions. However, real-world seismic data often has complexities that require additional processing. Grid overrides address these issues by transforming header values before indexing. +## Configuring grid overrides + +Grid overrides are passed to {func}`mdio.segy_to_mdio` via the `grid_overrides` argument as an +{class}`mdio.GridOverrides` instance: + +```python +from mdio import GridOverrides +from mdio import segy_to_mdio + +segy_to_mdio( + ..., + grid_overrides=GridOverrides(calculate_shot_index=True), +) +``` + +Both modern `snake_case` field names and the legacy `CamelCase` aliases are accepted, so +`GridOverrides(CalculateShotIndex=True)` is equivalent to the example above. Unknown keys +are rejected at construction with a `pydantic.ValidationError`. + +```{deprecated} 1.2 +Passing `grid_overrides` as a `dict` still works but logs a deprecation warning and will be +removed in a future release. Switch to `mdio.GridOverrides`. +``` + ## CalculateShotIndex Calculates a dense `shot_index` dimension from sparse or interleaved `shot_point` values. Required for the `ObnReceiverGathers3D` template. @@ -37,12 +61,15 @@ The override detects the geometry type and only applies the transformation when **Usage:** ```python +from mdio import GridOverrides +from mdio import segy_to_mdio + segy_to_mdio( input_path="obn_data.sgy", output_path="obn_data.mdio", segy_spec=obn_spec, mdio_template=get_template("ObnReceiverGathers3D"), - grid_overrides={"CalculateShotIndex": True}, + grid_overrides=GridOverrides(calculate_shot_index=True), ) ``` diff --git a/docs/guides/obn_data_import.md b/docs/guides/obn_data_import.md index fbb14ed6..94499785 100644 --- a/docs/guides/obn_data_import.md +++ b/docs/guides/obn_data_import.md @@ -66,6 +66,7 @@ A warning is logged when component is synthesized: from segy.schema import HeaderField from segy.standards import get_segy_standard +from mdio import GridOverrides from mdio import segy_to_mdio from mdio.builder.template_registry import get_template @@ -91,7 +92,7 @@ segy_to_mdio( output_path="obn_data.mdio", segy_spec=obn_spec, mdio_template=get_template("ObnReceiverGathers3D"), - grid_overrides={"CalculateShotIndex": True}, + grid_overrides=GridOverrides(calculate_shot_index=True), overwrite=True, ) ``` diff --git a/src/mdio/__init__.py b/src/mdio/__init__.py index 857fb806..7e1851c0 100644 --- a/src/mdio/__init__.py +++ b/src/mdio/__init__.py @@ -10,6 +10,7 @@ from mdio.converters import segy_to_mdio from mdio.optimize.access_pattern import OptimizedAccessPatternConfig from mdio.optimize.access_pattern import optimize_access_patterns +from mdio.segy.geometry import GridOverrides try: __version__ = metadata.version("multidimio") @@ -19,6 +20,7 @@ __all__ = [ "__version__", + "GridOverrides", "open_mdio", "to_mdio", "mdio_to_segy", diff --git a/src/mdio/converters/segy.py b/src/mdio/converters/segy.py index 59545f3b..f00f7bbf 100644 --- a/src/mdio/converters/segy.py +++ b/src/mdio/converters/segy.py @@ -37,6 +37,7 @@ from mdio.ingestion.segy.validation import _validate_spec_in_template from mdio.segy import blocked_io from mdio.segy.file import get_segy_file_info +from mdio.segy.geometry import GridOverrides from mdio.segy.utilities import get_grid_plan if TYPE_CHECKING: @@ -128,7 +129,7 @@ def filtered_add_coordinate( # noqa: ANN202 def _update_template_from_grid_overrides( template: AbstractDatasetTemplate, - grid_overrides: dict[str, Any] | None, + grid_overrides: GridOverrides | None, segy_dimensions: list[Dimension], full_chunk_shape: tuple[int, ...], chunk_size: tuple[int, ...], @@ -178,30 +179,29 @@ def _update_template_from_grid_overrides( # If using NonBinned override, expose non-binned dims as logical coordinates on the template instance # and patch _add_coordinates to skip adding them as 1D dimension coordinates - if grid_overrides and "NonBinned" in grid_overrides and "non_binned_dims" in grid_overrides: - non_binned_dims = tuple(grid_overrides["non_binned_dims"]) - if non_binned_dims: - logger.debug( - "NonBinned grid override: exposing non-binned dims as coordinates: %s", - non_binned_dims, - ) - # Append any missing names; keep existing order and avoid duplicates - existing = set(template.coordinate_names) - to_add = tuple(n for n in non_binned_dims if n not in existing) - if to_add: - template._logical_coord_names = template._logical_coord_names + to_add - - # Patch _add_coordinates to skip adding non-binned dims as 1D dimension coordinates - # This prevents them from being added with wrong dimensions (e.g., just "trace") - # They will be added later by build_dataset with full spatial_dimension_names - _patch_add_coordinates_for_non_binned(template, set(non_binned_dims)) + if grid_overrides is not None and grid_overrides.non_binned and grid_overrides.non_binned_dims: + non_binned_dims = tuple(grid_overrides.non_binned_dims) + logger.debug( + "NonBinned grid override: exposing non-binned dims as coordinates: %s", + non_binned_dims, + ) + # Append any missing names; keep existing order and avoid duplicates + existing = set(template.coordinate_names) + to_add = tuple(n for n in non_binned_dims if n not in existing) + if to_add: + template._logical_coord_names = template._logical_coord_names + to_add + + # Patch _add_coordinates to skip adding non-binned dims as 1D dimension coordinates + # This prevents them from being added with wrong dimensions (e.g., just "trace") + # They will be added later by build_dataset with full spatial_dimension_names + _patch_add_coordinates_for_non_binned(template, set(non_binned_dims)) def _scan_for_headers( segy_file_kwargs: SegyFileArguments, segy_file_info: SegyFileInfo, template: AbstractDatasetTemplate, - grid_overrides: dict[str, Any] | None = None, + grid_overrides: GridOverrides | None = None, ) -> tuple[list[Dimension], SegyHeaderArray]: """Extract trace dimensions and index headers from the SEG-Y file. @@ -346,13 +346,34 @@ def determine_target_size(var_type: str) -> int: ds.variables[index].metadata.chunk_grid = chunk_grid +def _coerce_grid_overrides( + grid_overrides: GridOverrides | dict[str, Any] | None, +) -> GridOverrides | None: + """Normalize public ``grid_overrides`` input into a :class:`GridOverrides` model. + + The internal ingestion pipeline only accepts the typed model. A legacy ``dict`` is + converted via :meth:`GridOverrides.from_legacy_dict` and a deprecation message is logged. + """ + if grid_overrides is None: + return None + + if isinstance(grid_overrides, GridOverrides): + return grid_overrides + + logger.warning( + "Passing `grid_overrides` as a dict is deprecated and will be removed in a " + "future release; pass a `mdio.GridOverrides` instance instead." + ) + return GridOverrides.model_validate(grid_overrides) + + def segy_to_mdio( # noqa PLR0913 segy_spec: SegySpec, mdio_template: AbstractDatasetTemplate, input_path: UPath | Path | str, output_path: UPath | Path | str, overwrite: bool = False, - grid_overrides: dict[str, Any] | None = None, + grid_overrides: GridOverrides | dict[str, Any] | None = None, segy_header_overrides: SegyHeaderOverrides | None = None, ) -> None: """A function that converts a SEG-Y file to an MDIO v1 file. @@ -365,12 +386,15 @@ def segy_to_mdio( # noqa PLR0913 input_path: The universal path of the input SEG-Y file. output_path: The universal path for the output MDIO v1 file. overwrite: Whether to overwrite the output file if it already exists. Defaults to False. - grid_overrides: Option to add grid overrides. + grid_overrides: Option to add grid overrides. Prefer a :class:`mdio.GridOverrides` + instance; ``dict`` is still accepted but emits a :class:`DeprecationWarning`. segy_header_overrides: Option to override specific SEG-Y headers during ingestion. Raises: FileExistsError: If the output location already exists and overwrite is False. """ + typed_grid_overrides = _coerce_grid_overrides(grid_overrides) + settings = MDIOSettings() _validate_spec_in_template(segy_spec, mdio_template) @@ -395,7 +419,7 @@ def segy_to_mdio( # noqa PLR0913 segy_file_kwargs, segy_file_info, template=mdio_template, - grid_overrides=grid_overrides, + grid_overrides=typed_grid_overrides, ) grid = _build_and_check_grid(segy_dimensions, segy_file_info, segy_headers) @@ -417,7 +441,7 @@ def segy_to_mdio( # noqa PLR0913 mdio_template = _update_template_units(mdio_template, spatial_unit) mdio_ds: Dataset = mdio_template.build_dataset(name=mdio_template.name, sizes=grid.shape, header_dtype=header_dtype) - _add_grid_override_to_metadata(dataset=mdio_ds, grid_overrides=grid_overrides) + _add_grid_override_to_metadata(dataset=mdio_ds, grid_overrides=typed_grid_overrides) # Dynamically chunk the variables based on their type _chunk_variable(ds=mdio_ds, target_variable_name="trace_mask") # trace_mask is a Variable and not a Coordinate diff --git a/src/mdio/ingestion/metadata.py b/src/mdio/ingestion/metadata.py index 674e91eb..4cc617f0 100644 --- a/src/mdio/ingestion/metadata.py +++ b/src/mdio/ingestion/metadata.py @@ -3,16 +3,16 @@ from __future__ import annotations from typing import TYPE_CHECKING -from typing import Any if TYPE_CHECKING: from mdio.builder.schemas import Dataset + from mdio.segy.geometry import GridOverrides -def _add_grid_override_to_metadata(dataset: Dataset, grid_overrides: dict[str, Any] | None) -> None: +def _add_grid_override_to_metadata(dataset: Dataset, grid_overrides: GridOverrides | None) -> None: """Add grid override to Dataset metadata if needed.""" if dataset.metadata.attributes is None: dataset.metadata.attributes = {} if grid_overrides is not None: - dataset.metadata.attributes["gridOverrides"] = grid_overrides + dataset.metadata.attributes["gridOverrides"] = grid_overrides.to_legacy_dict() diff --git a/src/mdio/segy/geometry.py b/src/mdio/segy/geometry.py index cbf02be0..eaa28fb4 100644 --- a/src/mdio/segy/geometry.py +++ b/src/mdio/segy/geometry.py @@ -6,9 +6,13 @@ from abc import ABC from abc import abstractmethod from typing import TYPE_CHECKING +from typing import Any import numpy as np from numpy.lib import recfunctions as rfn +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field from mdio.ingestion.segy.header_analysis import ShotGunGeometryType from mdio.ingestion.segy.header_analysis import StreamerShotGeometryType @@ -31,6 +35,61 @@ logger = logging.getLogger(__name__) +class GridOverrides(BaseModel): + """Type-safe configuration for grid override operations during SEG-Y ingestion.""" + + model_config = ConfigDict(extra="forbid", validate_by_name=True) + + auto_channel_wrap: bool = Field( + default=False, + alias="AutoChannelWrap", + description="Streamer: auto-detect channel-wrap geometry (Type A vs B).", + ) + auto_shot_wrap: bool = Field( + default=False, + alias="AutoShotWrap", + description="Streamer: derive dense shot_index from interleaved shot_point values.", + ) + calculate_shot_index: bool = Field( + default=False, + alias="CalculateShotIndex", + description="OBN: derive dense shot_index from sparse shot_point values per shot_line.", + ) + non_binned: bool = Field( + default=False, + alias="NonBinned", + description="Collapse selected dims into a single trace dimension without spatial binning.", + ) + has_duplicates: bool = Field( + default=False, + alias="HasDuplicates", + description="Add a trace dimension (chunksize 1) to disambiguate duplicate trace indices.", + ) + chunksize: int | None = Field( + default=None, + gt=0, + description="Chunk size for the trace dimension when `non_binned` is True.", + ) + non_binned_dims: list[str] | None = Field( + default=None, + description="Dimension names to collapse into the trace dimension when `non_binned` is True.", + ) + + def __bool__(self) -> bool: + """Return True if any override flag is enabled.""" + return ( + self.auto_channel_wrap + or self.auto_shot_wrap + or self.calculate_shot_index + or self.non_binned + or self.has_duplicates + ) + + def to_legacy_dict(self) -> dict[str, Any]: + """Dump to the legacy ``CamelCase`` dict shape consumed by :class:`GridOverrider`.""" + return self.model_dump(by_alias=True, exclude_defaults=True) + + class GridOverrideCommand(ABC): """Abstract base class for grid override commands.""" diff --git a/src/mdio/segy/utilities.py b/src/mdio/segy/utilities.py index f5e32660..64c89584 100644 --- a/src/mdio/segy/utilities.py +++ b/src/mdio/segy/utilities.py @@ -5,7 +5,6 @@ import itertools import logging from typing import TYPE_CHECKING -from typing import Any import numpy as np from dask.array.core import normalize_chunks @@ -24,6 +23,7 @@ from mdio.builder.templates.base import AbstractDatasetTemplate from mdio.segy.file import SegyFileArguments from mdio.segy.file import SegyFileInfo + from mdio.segy.geometry import GridOverrides logger = logging.getLogger(__name__) @@ -34,7 +34,7 @@ def get_grid_plan( # noqa: C901, PLR0912, PLR0913, PLR0915 chunksize: tuple[int, ...] | None, template: AbstractDatasetTemplate, return_headers: bool = False, - grid_overrides: dict[str, Any] | None = None, + grid_overrides: GridOverrides | None = None, ) -> tuple[list[Dimension], tuple[int, ...]] | tuple[list[Dimension], tuple[int, ...], HeaderArray]: """Infer dimension ranges, and increments. @@ -50,7 +50,7 @@ def get_grid_plan( # noqa: C901, PLR0912, PLR0913, PLR0915 chunksize: Chunk sizes to be used in grid plan. template: MDIO template where coordinate names and domain will be taken. return_headers: Option to return parsed headers with `Dimension` objects. Default is False. - grid_overrides: Option to add grid overrides. See main documentation. + grid_overrides: Typed grid override configuration, or ``None`` for no overrides. Returns: All index dimensions and chunksize or dimensions and chunksize together with header values. @@ -58,9 +58,6 @@ def get_grid_plan( # noqa: C901, PLR0912, PLR0913, PLR0915 Raises: ValueError: If computed fields are not found after grid overrides. """ - if grid_overrides is None: - grid_overrides = {} - # Keep only dimension and non-dimension coordinates excluding the vertical axis horizontal_dimensions = template.spatial_dimension_names horizontal_coordinates = horizontal_dimensions + template.coordinate_names @@ -72,8 +69,8 @@ def get_grid_plan( # noqa: C901, PLR0912, PLR0913, PLR0915 horizontal_coordinates = tuple(c for c in horizontal_coordinates if c not in computed_fields) # Ensure non_binned_dims are included in the headers to parse, even if not in template - if grid_overrides and "non_binned_dims" in grid_overrides: - for dim in grid_overrides["non_binned_dims"]: + if grid_overrides is not None and grid_overrides.non_binned_dims: + for dim in grid_overrides.non_binned_dims: if dim not in horizontal_coordinates: horizontal_coordinates = horizontal_coordinates + (dim,) @@ -94,20 +91,24 @@ def get_grid_plan( # noqa: C901, PLR0912, PLR0913, PLR0915 subset=tuple(c for c in horizontal_coordinates if c not in fields_to_skip), ) - # Handle grid overrides. + # The legacy GridOverrider still consumes the dict shape; dump only at this boundary. + # Future PR will replace GridOverrider. override_handler = GridOverrider() headers_subset, horizontal_coordinates, chunksize = override_handler.run( headers_subset, horizontal_coordinates, chunksize=chunksize, - grid_overrides=grid_overrides, + grid_overrides=grid_overrides.to_legacy_dict() if grid_overrides is not None else {}, template=template, ) # After grid overrides, determine final spatial dimensions and their chunk sizes - non_binned_dims = set() - if "NonBinned" in grid_overrides and "non_binned_dims" in grid_overrides: - non_binned_dims = set(grid_overrides["non_binned_dims"]) + non_binned_active = grid_overrides is not None and grid_overrides.non_binned + non_binned_dims: set[str] = ( + set(grid_overrides.non_binned_dims) + if non_binned_active and grid_overrides is not None and grid_overrides.non_binned_dims + else set() + ) # Create mapping from dimension name to original chunk size for easy lookup original_spatial_dims = list(template.spatial_dimension_names) @@ -121,8 +122,11 @@ def get_grid_plan( # noqa: C901, PLR0912, PLR0913, PLR0915 if name in non_binned_dims: continue # Skip dimensions that became coordinates if name == "trace": - # Special handling for trace dimension - chunk_val = int(grid_overrides.get("chunksize", 1)) if "NonBinned" in grid_overrides else 1 + chunk_val = ( + int(grid_overrides.chunksize) + if non_binned_active and grid_overrides is not None and grid_overrides.chunksize is not None + else 1 + ) final_spatial_dims.append(name) final_spatial_chunks.append(chunk_val) elif name in dim_to_chunk: diff --git a/tests/unit/ingestion/test_metadata.py b/tests/unit/ingestion/test_metadata.py index 40d1eb81..2ba7d4c8 100644 --- a/tests/unit/ingestion/test_metadata.py +++ b/tests/unit/ingestion/test_metadata.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from mdio.ingestion.metadata import _add_grid_override_to_metadata +from mdio.segy.geometry import GridOverrides def _make_dataset(attributes: dict | None) -> SimpleNamespace: @@ -22,18 +23,23 @@ def test_initializes_attributes_dict_when_none(self) -> None: assert dataset.metadata.attributes == {} def test_adds_grid_overrides_when_provided(self) -> None: - """Grid overrides should land under the ``gridOverrides`` key.""" + """Active grid overrides should serialize under the ``gridOverrides`` key.""" dataset = _make_dataset(attributes=None) - overrides = {"HasDuplicates": True, "chunksize": 4} + overrides = GridOverrides(has_duplicates=True, chunksize=4) _add_grid_override_to_metadata(dataset, grid_overrides=overrides) - assert dataset.metadata.attributes == {"gridOverrides": overrides} + assert dataset.metadata.attributes == { + "gridOverrides": {"HasDuplicates": True, "chunksize": 4}, + } def test_preserves_existing_attributes(self) -> None: """Existing attribute keys should be preserved when adding overrides.""" dataset = _make_dataset(attributes={"existing": "value"}) - overrides = {"NonBinned": True} + overrides = GridOverrides(non_binned=True) _add_grid_override_to_metadata(dataset, grid_overrides=overrides) - assert dataset.metadata.attributes == {"existing": "value", "gridOverrides": overrides} + assert dataset.metadata.attributes == { + "existing": "value", + "gridOverrides": {"NonBinned": True}, + } def test_no_overrides_leaves_attributes_untouched(self) -> None: """Passing ``None`` overrides must not introduce a ``gridOverrides`` key.""" diff --git a/tests/unit/test_grid_overrides_pydantic.py b/tests/unit/test_grid_overrides_pydantic.py new file mode 100644 index 00000000..4060c94c --- /dev/null +++ b/tests/unit/test_grid_overrides_pydantic.py @@ -0,0 +1,122 @@ +"""Unit tests for the typed :class:`mdio.GridOverrides` Pydantic model.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import pytest +from pydantic import ValidationError + +from mdio.converters.segy import _coerce_grid_overrides +from mdio.segy.geometry import GridOverrides + +if TYPE_CHECKING: + from _pytest.logging import LogCaptureFixture + + +def test_grid_overrides_defaults() -> None: + """Default instance has every flag off and is falsy.""" + overrides = GridOverrides() + assert not overrides.auto_channel_wrap + assert not overrides.auto_shot_wrap + assert not overrides.calculate_shot_index + assert not overrides.non_binned + assert not overrides.has_duplicates + assert overrides.chunksize is None + assert overrides.non_binned_dims is None + assert not bool(overrides) + + +def test_grid_overrides_aliases() -> None: + """Legacy CamelCase aliases populate the snake_case fields.""" + overrides = GridOverrides(AutoChannelWrap=True, chunksize=64) + assert overrides.auto_channel_wrap is True + assert overrides.chunksize == 64 + assert bool(overrides) is True + + +def test_grid_overrides_calculate_shot_index_alias() -> None: + """OBN-specific ``CalculateShotIndex`` survives the typed shape.""" + overrides = GridOverrides(CalculateShotIndex=True) + assert overrides.calculate_shot_index is True + assert bool(overrides) is True + + +def test_grid_overrides_validation() -> None: + """``chunksize`` must be strictly positive.""" + with pytest.raises(ValidationError): + GridOverrides(chunksize=0) + + with pytest.raises(ValidationError): + GridOverrides(chunksize=-1) + + +def test_grid_overrides_rejects_unknown_keys() -> None: + """Unknown keys are rejected at construction by ``extra='forbid'``.""" + with pytest.raises(ValidationError): + GridOverrides.model_validate({"FutureFlag": True}) + + +def test_grid_overrides_serialization() -> None: + """``model_dump`` round-trips both legacy and modern key shapes.""" + overrides = GridOverrides(AutoChannelWrap=True, chunksize=64) + + dumped_legacy = overrides.model_dump(by_alias=True, exclude_defaults=True) + assert dumped_legacy == {"AutoChannelWrap": True, "chunksize": 64} + + dumped_modern = overrides.model_dump(exclude_defaults=True) + assert dumped_modern == {"auto_channel_wrap": True, "chunksize": 64} + + +def test_grid_overrides_to_legacy_dict() -> None: + """``to_legacy_dict`` produces the dict shape consumed by ``GridOverrider``.""" + overrides = GridOverrides(non_binned=True, chunksize=128, non_binned_dims=["offset", "azimuth"]) + assert overrides.to_legacy_dict() == { + "NonBinned": True, + "chunksize": 128, + "non_binned_dims": ["offset", "azimuth"], + } + + +def test_grid_overrides_to_legacy_dict_default_is_empty() -> None: + """Default instance dumps to an empty dict.""" + assert GridOverrides().to_legacy_dict() == {} + + +def test_grid_overrides_legacy_dict_roundtrip() -> None: + """A legacy dict survives ``model_validate``/``to_legacy_dict`` unchanged.""" + legacy = { + "CalculateShotIndex": True, + "NonBinned": True, + "chunksize": 64, + "non_binned_dims": ["offset"], + } + assert GridOverrides.model_validate(legacy).to_legacy_dict() == legacy + + +def test_coerce_grid_overrides_converts_dict_with_log(caplog: LogCaptureFixture) -> None: + """A dict input is coerced to :class:`GridOverrides` and a deprecation is logged.""" + legacy = {"CalculateShotIndex": True} + with caplog.at_level(logging.WARNING, logger="mdio.converters.segy"): + result = _coerce_grid_overrides(legacy) + assert isinstance(result, GridOverrides) + assert result.calculate_shot_index is True + assert any("deprecated" in record.message for record in caplog.records) + + +def test_coerce_grid_overrides_rejects_unknown_dict_keys() -> None: + """Dict inputs with unknown keys fail loudly instead of silently dropping them.""" + with pytest.raises(ValidationError): + _coerce_grid_overrides({"FutureFlag": True}) + + +def test_coerce_grid_overrides_passes_pydantic_model_through() -> None: + """A :class:`GridOverrides` instance is returned unchanged.""" + overrides = GridOverrides(auto_channel_wrap=True) + assert _coerce_grid_overrides(overrides) is overrides + + +def test_coerce_grid_overrides_none_returns_none() -> None: + """``None`` round-trips to ``None``.""" + assert _coerce_grid_overrides(None) is None