From 015063523bca4b0d50c0b5beab7e6c475405dcb1 Mon Sep 17 00:00:00 2001 From: Aaron Steven White Date: Mon, 11 May 2026 12:34:07 -0400 Subject: [PATCH 1/4] Updates required python version. --- README.md | 14 +++++++------- docs/developer-guide/architecture.md | 6 +++--- docs/developer-guide/contributing.md | 4 ++-- docs/developer-guide/setup.md | 26 +++++++++++++------------- docs/developer-guide/testing.md | 4 ++-- docs/index.md | 7 +++---- docs/installation.md | 8 ++++---- 7 files changed, 34 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index e8c479c..59486af 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ # bead [![CI](https://github.com/FACTSlab/bead/actions/workflows/ci.yml/badge.svg)](https://github.com/FACTSlab/bead/actions/workflows/ci.yml) -[![Python 3.13](https://img.shields.io/badge/python-3.13-blue.svg)](https://www.python.org/downloads/) +[![Python 3.14](https://img.shields.io/badge/python-3.14-blue.svg)](https://www.python.org/downloads/) [![License: MIT](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE) -[![Documentation](https://img.shields.io/badge/docs-readthedocs-blue.svg)](https://bead.readthedocs.io) +[![Documentation](https://img.shields.io/badge/docs-readthedocs-blue.svg)](https://factslab.io/bead/) A Python framework for constructing, deploying, and analyzing large-scale linguistic judgment experiments with active learning. @@ -108,12 +108,12 @@ bead training run # Train with active learning ## Documentation -Full documentation: [bead.readthedocs.io](https://bead.readthedocs.io) +Full documentation: [bead.readthedocs.io](https://factslab.io/bead/) -- [Installation Guide](https://bead.readthedocs.io/installation/) -- [User Guide](https://bead.readthedocs.io/user-guide/) -- [API Reference](https://bead.readthedocs.io/api/) -- [Gallery Examples](https://bead.readthedocs.io/examples/) +- [Installation Guide](https://factslab.io/bead/installation/) +- [User Guide](https://factslab.io/bead/user-guide/) +- [API Reference](https://factslab.io/bead/api/) +- [Gallery Examples](https://factslab.io/bead/examples/) ## Contributing diff --git a/docs/developer-guide/architecture.md b/docs/developer-guide/architecture.md index 2c47166..4334d6c 100644 --- a/docs/developer-guide/architecture.md +++ b/docs/developer-guide/architecture.md @@ -349,7 +349,7 @@ experiment_list.metadata = { ### 3. Type Safety -bead uses full Python 3.13 type hints with Pydantic v2 validation. No `Any` or `object` types appear in core code (only in adapters for external APIs with dynamic types). +bead uses full Python 3.14 type hints with Pydantic v2 validation. No `Any` or `object` types appear in core code (only in adapters for external APIs with dynamic types). **Type annotations**: ```python @@ -382,7 +382,7 @@ class ExperimentList(BeadBaseModel): ```toml [tool.pyright] typeCheckingMode = "strict" -pythonVersion = "3.13" +pythonVersion = "3.14" exclude = [ "tests/**", # Tests don't require full type checking "bead/items/adapters/**", # External APIs have dynamic types @@ -936,7 +936,7 @@ bead's architecture prioritizes: 1. **Provenance**: UUID-based stand-off annotation creates unbroken provenance chains 2. **Modularity**: 17 modules organized by function, 6 pipeline stages -3. **Type Safety**: Full Python 3.13 type hints with Pydantic v2 validation +3. **Type Safety**: Full Python 3.14 type hints with Pydantic v2 validation 4. **Flexibility**: Configuration-first design, 9 task types, 12 constraint types 5. **Research Validity**: GLMM support, batch deployment, convergence detection diff --git a/docs/developer-guide/contributing.md b/docs/developer-guide/contributing.md index ddb12fc..270892e 100644 --- a/docs/developer-guide/contributing.md +++ b/docs/developer-guide/contributing.md @@ -335,7 +335,7 @@ uv run ruff format bead/ **Configuration** (from pyproject.toml): - Line length: 88 characters -- Target: Python 3.13 +- Target: Python 3.14 - Conventions: PEP 8, NumPy docstrings - Rules: E (errors), F (PyFlakes), I (imports), N (naming), D (docstrings), UP (upgrades), ANN (annotations), B (bugbear), A (builtins), C4 (comprehensions), PLC (Pylint) @@ -367,7 +367,7 @@ def partition_items( **Configuration** (from pyproject.toml): - Mode: strict -- Python version: 3.13 +- Python version: 3.14 - Excluded: tests/, adapters/ (external APIs have dynamic types) ### Running All Checks diff --git a/docs/developer-guide/setup.md b/docs/developer-guide/setup.md index 2599ab5..ec7e53d 100644 --- a/docs/developer-guide/setup.md +++ b/docs/developer-guide/setup.md @@ -6,15 +6,15 @@ This guide walks you through setting up a development environment for contributi ### Required Software -bead requires Python 3.13+ for modern type hint syntax (PEP 695 generic type parameters). Check your version: +bead requires Python 3.14+ for modern type hint syntax (PEP 695 generic type parameters). Check your version: ```bash -python3 --version # Should show 3.13.0 or higher +python3 --version # Should show 3.14.0 or higher ``` -If you need to install Python 3.13: +If you need to install Python 3.14: -- **macOS**: `brew install python@3.13` +- **macOS**: `brew install python@3.14` - **Linux**: Install from source or use pyenv - **Windows**: Download from python.org @@ -282,7 +282,7 @@ This performs static type analysis, catching: ```toml [tool.pyright] typeCheckingMode = "strict" -pythonVersion = "3.13" +pythonVersion = "3.14" exclude = [ "tests/**", # Tests don't require full type checking "bead/active_learning/**", @@ -333,7 +333,7 @@ This runs all 143 test files. Expected output: ``` ============================= test session starts ============================== -platform darwin -- Python 3.13.0, pytest-7.4.3, pluggy-1.3.0 +platform darwin -- Python 3.14.0, pytest-7.4.3, pluggy-1.3.0 rootdir: /path/to/bead configfile: pyproject.toml plugins: cov-4.1.0, mock-3.11.0 @@ -393,7 +393,7 @@ uv run pytest tests/ --cov=bead --cov-report=term-missing This shows code coverage with line numbers of uncovered code: ``` ----------- coverage: platform darwin, python 3.13.0 ----------- +---------- coverage: platform darwin, python 3.14.0 ----------- Name Stmts Miss Cover Missing -------------------------------------------------------------------- bead/__init__.py 3 0 100% @@ -553,16 +553,16 @@ If all checks pass, your development environment is ready. ### Python Version Issues -**Problem**: `uv sync` fails with "Requires Python >=3.13" +**Problem**: `uv sync` fails with "Requires Python >=3.14" -**Solution**: Install Python 3.13: +**Solution**: Install Python 3.14: ```bash # macOS -brew install python@3.13 +brew install python@3.14 # Linux (pyenv) -pyenv install 3.13.0 -pyenv local 3.13.0 +pyenv install 3.14.0 +pyenv local 3.14.0 # Windows # Download from python.org @@ -600,7 +600,7 @@ uv run pyright bead/ **Problem**: Tests fail immediately after cloning **Solution**: -1. Ensure Python 3.13+ is available +1. Ensure Python 3.14+ is available 2. Reinstall dependencies: `uv sync --all-extras --reinstall` 3. Clear pytest cache: `rm -rf .pytest_cache` 4. Run tests again: `uv run pytest tests/` diff --git a/docs/developer-guide/testing.md b/docs/developer-guide/testing.md index 3bf3235..78fd736 100644 --- a/docs/developer-guide/testing.md +++ b/docs/developer-guide/testing.md @@ -299,7 +299,7 @@ uv run pytest tests/ --cov=bead --cov-report=term-missing Output shows coverage per file with uncovered line numbers: ``` ----------- coverage: platform darwin, python 3.13.0 ----------- +---------- coverage: platform darwin, python 3.14.0 ----------- Name Stmts Miss Cover Missing -------------------------------------------------------------------- bead/__init__.py 3 0 100% @@ -761,7 +761,7 @@ Tests run automatically in CI on every push and pull request. The CI workflow (if configured) runs: -1. Install Python 3.13 +1. Install Python 3.14 2. Install dependencies: `uv sync --all-extras` 3. Run linters: `uv run ruff check bead/` 4. Run type checker: `uv run pyright bead/` diff --git a/docs/index.md b/docs/index.md index 04099e9..0292bea 100644 --- a/docs/index.md +++ b/docs/index.md @@ -48,7 +48,7 @@ uv sync --all-extras ## Requirements -- Python 3.13+ +- Python 3.14+ - Operating Systems: macOS, Linux, Windows (WSL recommended) ## Citation @@ -58,10 +58,9 @@ If you use bead in your research, please cite: ``` @software{white_bead_2026, author = {White, Aaron Steven}, - title = {Bead: A python framework for linguistic judgment experiments with active learning}, + title = {bead: A framework for large-scale linguistic judgment experiments}, year = {2026}, - url = {https://github.com/FACTSlab/bead}, - version = {0.2.0} + url = {https://github.com/FACTSlab/bead} } ``` diff --git a/docs/installation.md b/docs/installation.md index f99b01a..3ddf7c5 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -2,7 +2,7 @@ ## Requirements -- Python 3.13 or higher +- Python 3.14 or higher - Operating Systems: macOS, Linux, Windows (WSL recommended) ## Install from PyPI @@ -93,7 +93,7 @@ The compiled JavaScript is output to `bead/deployment/jspsych/dist/`. ### Python Version -Verify you have Python 3.13+: +Verify you have Python 3.14+: ```bash python --version @@ -102,8 +102,8 @@ python --version If not, install from [python.org](https://www.python.org/downloads/) or use pyenv: ```bash -pyenv install 3.13.0 -pyenv local 3.13.0 +pyenv install 3.14.0 +pyenv local 3.14.0 ``` ### Common Issues From 40497d3ac848eefce4b1d4249e9304a1a726f851 Mon Sep 17 00:00:00 2001 From: Aaron Steven White Date: Tue, 12 May 2026 12:04:18 -0400 Subject: [PATCH 2/4] Adds bead.config.compose, FORCED_CHOICE scale, gallery v0.4.0 wiring (0.5.0) Replaces the hand-rolled config loader with a generic, didactic-grounded composer that supports the full OmegaConf interpolation grammar (absolute / relative / bracketed / dotted-index / nested references, `\${literal}` escape, cycle detection) and the standard resolver suite (`oc.env`, `oc.select`, `oc.decode`, `oc.deprecated`, `oc.create`, `oc.dict.keys`, `oc.dict.values`). Adds `defaults: [...]` composition, TOML support, schema-driven strict-merge, and CLI-style `--set KEY=VALUE` overrides on every `bead` subcommand. Bead-specific resolvers (`\${bead.path:rel}`, `\${bead.anchor:name[,attr]}`) live in `bead.config.resolvers`. `bead.config.load_config` is now a thin wrapper that binds the schema to `BeadConfig`. The subpackage is structured for a future extraction into a standalone `didacticonf` distribution: only `merge.py` imports `didactic`. Adds `ScaleType.FORCED_CHOICE` so N-AFC tasks (response space = positional labels like `("first", "second")`, per-item alternatives on each `Item`) are first-class through `family_to_item_template` and the active-learning model registry. `AnchorSpec.scale_type` exposes the override declaratively. Wires `gallery/eng/argument_structure/` to the protocol layer: new `protocol.py` module materializes the live `AnnotationProtocol` from `config.yaml`; `create_2afc_pairs.py` threads the anchor name into every pair's `item_metadata`; `generate_deployment.py` and `simulate_pipeline.py` build their `ItemTemplate` via `family_to_item_template`; `make validate-protocol` gates data generation; `tests/test_protocol.py` covers the config-to-protocol round trip and bridges to items / models. Fixes a save/load bug where every model subclass re-read `config.json` after the base trainer had already popped its state-only fields, causing `model_validate` to reject them. The base `_load_model_components` now takes the cleaned `config_dict` so subclasses never re-read the file (8 model files updated). Replaces leaky `for line in open(...)` patterns in `bead.cli.items_factories` and switches the lightning trainer to `json.loads(self.config.model_dump_json())` so Paths / Enums flatten correctly into `ModelMetadata.training_config`. Bumps spaCy to 3.8.13 / pydantic to 2.13.4 (Python 3.14 import compatibility). Updates docs (`configuration.md`, `protocols.md`, training codeblock) and both READMEs (top-level Quick Start + gallery protocol section). Bumps version to 0.5.0. --- CHANGELOG.md | 64 ++ README.md | 105 ++- bead/__init__.py | 2 +- bead/active_learning/models/base.py | 11 +- bead/active_learning/models/binary.py | 12 +- bead/active_learning/models/categorical.py | 12 +- bead/active_learning/models/cloze.py | 12 +- bead/active_learning/models/forced_choice.py | 10 +- bead/active_learning/models/free_text.py | 10 +- bead/active_learning/models/magnitude.py | 10 +- bead/active_learning/models/multi_select.py | 10 +- bead/active_learning/models/ordinal_scale.py | 10 +- bead/active_learning/trainers/lightning.py | 17 +- bead/cli/items_factories.py | 12 +- bead/cli/main.py | 14 + bead/cli/protocol.py | 33 +- bead/config/__init__.py | 18 +- bead/config/compose/__init__.py | 58 ++ bead/config/compose/errors.py | 34 + bead/config/compose/interpolation.py | 674 ++++++++++++++++++ bead/config/compose/merge.py | 197 +++++ bead/config/compose/pipeline.py | 129 ++++ bead/config/compose/resolvers.py | 142 ++++ bead/config/compose/sources.py | 113 +++ bead/config/loader.py | 191 ++--- bead/config/protocol.py | 16 +- bead/config/resolvers.py | 115 +++ bead/protocol/anchor.py | 39 + bead/protocol/encoding.py | 64 +- bead/protocol/items.py | 13 +- docs/user-guide/api/training.md | 9 +- docs/user-guide/configuration.md | 38 +- docs/user-guide/protocols.md | 28 + gallery/eng/argument_structure/Makefile | 15 +- gallery/eng/argument_structure/README.md | 197 +++-- gallery/eng/argument_structure/config.yaml | 19 + .../argument_structure/create_2afc_pairs.py | 10 +- .../argument_structure/generate_deployment.py | 33 +- gallery/eng/argument_structure/protocol.py | 45 ++ .../argument_structure/simulate_pipeline.py | 47 +- .../argument_structure/tests/test_protocol.py | 53 ++ .../tests/test_simulation.py | 14 +- pyproject.toml | 2 +- .../models/binary/test_mixed_effects.py | 44 +- .../models/categorical/test_mixed_effects.py | 33 +- .../forced_choice/test_mixed_effects.py | 31 +- .../models/magnitude/test_mixed_effects.py | 44 +- .../test_multi_select_mixed_effects.py | 71 +- .../ordinal_scale/test_mixed_effects.py | 33 +- tests/config/compose/__init__.py | 1 + tests/config/compose/conftest.py | 28 + tests/config/compose/test_compose_basics.py | 157 ++++ .../compose/test_interpolation_basics.py | 120 ++++ .../compose/test_interpolation_custom.py | 43 ++ .../compose/test_interpolation_resolvers.py | 51 ++ tests/config/test_loader.py | 268 ------- tests/protocol/test_encoding.py | 36 + tests/protocol/test_items_bridge.py | 45 ++ uv.lock | 2 +- 59 files changed, 2845 insertions(+), 819 deletions(-) create mode 100644 bead/config/compose/__init__.py create mode 100644 bead/config/compose/errors.py create mode 100644 bead/config/compose/interpolation.py create mode 100644 bead/config/compose/merge.py create mode 100644 bead/config/compose/pipeline.py create mode 100644 bead/config/compose/resolvers.py create mode 100644 bead/config/compose/sources.py create mode 100644 bead/config/resolvers.py create mode 100644 gallery/eng/argument_structure/protocol.py create mode 100644 gallery/eng/argument_structure/tests/test_protocol.py create mode 100644 tests/config/compose/__init__.py create mode 100644 tests/config/compose/conftest.py create mode 100644 tests/config/compose/test_compose_basics.py create mode 100644 tests/config/compose/test_interpolation_basics.py create mode 100644 tests/config/compose/test_interpolation_custom.py create mode 100644 tests/config/compose/test_interpolation_resolvers.py delete mode 100644 tests/config/test_loader.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b78669..5a85f09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,70 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.5.0] - 2026-05-12 + +### Added + +#### `bead.config.compose` — didactic-grounded config composer + +- New subpackage `bead.config.compose` replaces the hand-rolled config + loader. Generic over any `dx.Model` schema; supports the full + OmegaConf interpolation grammar (`${section.field}`, `${.x}` / + `${..x}` relative, `${a.b[0]}` and `${a.b.0}` list indexing, + `${a.${b}}` nested, `\${literal}` escape, cycle detection). +- Built-in resolvers: `oc.env`, `oc.env:VAR,default`, `oc.select`, + `oc.decode` (base64), `oc.deprecated`, `oc.create`, + `oc.dict.keys`, `oc.dict.values`. Application-specific resolvers + register via `bead.config.compose.register_resolver`. +- Bead-specific resolvers in `bead.config.resolvers`: + `${bead.path:rel}` joins against the active root's + `paths.data_dir`; `${bead.anchor:name[,attr]}` post-validation + expansion. +- `defaults: [...]` composition at the top of any YAML/TOML config + composes referenced files left-to-right before the primary body. +- Strict-merge rejects unknown keys with the dotted path to the + offending site, walking nested `dx.Embed[T]` models from + `__field_specs__`. +- TOML configs (`.toml`) supported alongside YAML out of the box. +- `bead.config.load_config` is now a thin wrapper around + `compose(schema=BeadConfig, ...)`. The previous + `load_yaml_file` / `merge_configs` helpers are removed. +- CLI: every `bead ...` invocation accepts repeatable + `--set KEY=VALUE` overrides threaded into the compose pipeline. + +#### `ScaleType.FORCED_CHOICE` + +- New `ScaleType.FORCED_CHOICE` variant covers N-alternative + forced-choice tasks where per-item options vary across items + (response space is a fixed positional label set, e.g. + `("first", "second")`, but each `Item` carries its own + alternatives). `family_to_item_template` and the + active-learning model registry route forced-choice anchors to + `ForcedChoiceModel`. +- `AnchorSpec.scale_type` is an optional explicit override so config + files declare the task type alongside the response space. + +#### Gallery: `gallery/eng/argument_structure/` v0.4.0 wiring + +- New `protocol.py` module exposes `build_protocol()` / + `acceptability_family()` / `acceptability_anchor()`. The 2AFC + acceptability question is declared once in `config.yaml` under + `protocol:` and consumed by every script. +- `generate_deployment.py` and `simulate_pipeline.py` build their + `ItemTemplate` via `family_to_item_template` instead of literal + prompt strings. +- `create_2afc_pairs.py` threads the protocol anchor name + (`"acceptability"`) into every pair's `item_metadata` so the + JATOS-result → `AnnotationRecord` bridge can match responses + back to the canonical anchor. +- `make validate-protocol` builds the live `AnnotationProtocol` + from `config.yaml` and prints the family, prompt, and scale + type. Wired in as a prerequisite to `make data`. +- `tests/test_protocol.py` covers the config-to-protocol round + trip, the forced-choice scale type, the `family_to_item_template` + prompt agreement, and the active-learning model selection for + the resulting encoding. + ## [0.4.0] - 2026-05-07 ### Added diff --git a/README.md b/README.md index 59486af..94f9b9d 100644 --- a/README.md +++ b/README.md @@ -41,37 +41,69 @@ Always use `uv run` to execute commands. ## Quick Start ```python -from bead.resources import LexicalItem, Template, Lexicon -from bead.templates import TemplateFiller -from bead.items import ItemConstructor -from bead.lists import ListPartitioner - -# 1. Define resources -verbs = Lexicon(items=[ - LexicalItem(lemma="walk", pos="VERB", features={"transitive": False}), - LexicalItem(lemma="eat", pos="VERB", features={"transitive": True}), -]) - -template = Template( - text="The person {verb} the thing", - slots=["verb"], - language_code="en" +from bead.items.forced_choice import create_forced_choice_item +from bead.lists.partitioner import ListPartitioner +from bead.protocol import ( + AnnotationProtocol, + QuestionFamily, + ResponseSpace, + ScaleType, + SemanticAnchor, +) +from bead.protocol.items import family_to_item_template + +# 1. Declare the question being asked +anchor = SemanticAnchor( + name="acceptability", + target_property="acceptability", + canonical_prompt="Which sentence sounds more natural?", + response_space=ResponseSpace( + options=("first", "second"), + is_ordered=False, + scale_type=ScaleType.FORCED_CHOICE, + ), + required_keywords=frozenset({"natural"}), +) +protocol = AnnotationProtocol(families=[QuestionFamily(anchor=anchor)]) + +# 2. Build the deployable item template from the protocol +template = family_to_item_template( + protocol.family_by_name("acceptability"), + judgment_type="acceptability", ) -# 2. Fill templates -filler = TemplateFiller(strategy="exhaustive") -filled = filler.fill(templates=[template], lexicons={"verbs": verbs}) +# 3. Build forced-choice items (one per minimal pair) +items = [ + create_forced_choice_item( + "The cat sat on the mat.", + "The cats sat on the mat.", + item_template_id=template.id, + metadata={"anchor": "acceptability", "contrast": "number"}, + ), + # ... more pairs +] + +# 4. Partition into experiment lists +partitioner = ListPartitioner(random_seed=42) +lists = partitioner.partition( + [item.id for item in items], + n_lists=4, + metadata={item.id: dict(item.item_metadata) for item in items}, +) +``` -# 3. Construct items -constructor = ItemConstructor(models=["gpt2"]) -items = constructor.construct_forced_choice_items(filled, n_alternatives=2) +Or, drive the same pipeline from a single declarative config: -# 4. Partition into lists -partitioner = ListPartitioner() -lists = partitioner.partition(items.get_uuids(), n_lists=4) +```python +from bead.config import load_config -# 5. Deploy -lists.save("lists/experiment.jsonl") +# Composes profile defaults → defaults: [...] entries → primary YAML +# → extras → CLI-style overrides → resolves ${...} interpolation +config = load_config( + "config.yaml", + overrides=["paths.data_dir=/tmp/data"], +) +protocol = config.protocol.build() ``` ## Pipeline Stages @@ -93,19 +125,28 @@ lists.save("lists/experiment.jsonl") - **Model integration**: HuggingFace, OpenAI, Anthropic with caching - **Active learning**: uncertainty sampling with convergence detection - **Annotation protocols**: type-theoretic stack of `SemanticAnchor` (the question type), `ProtocolContext` (the dependent index), `RealizationStrategy` (template / contextual / LM phrasings), and `DriftGuard` (the type-checker over realized prompts), composed into conditional `AnnotationProtocol`s +- **Config composer** (`bead.config.compose`): the full OmegaConf interpolation grammar — `${section.field}`, `${.x}` / `${..y}` relative references, `${a.b[0]}` / `${a.b.0}` list indexing, `${a.${b}}` nesting, `\${literal}` escape, built-in resolvers (`oc.env`, `oc.select`, `oc.decode`, `oc.deprecated`, `oc.create`, `oc.dict.keys`, `oc.dict.values`); `defaults: [...]` composition; strict-merge against didactic schemas; YAML and TOML - **jsPsych 8.x**: Material Design UI with JATOS deployment ## CLI ```bash -bead init my-experiment # Create project structure -bead templates fill # Fill templates -bead items construct # Construct items -bead lists partition # Create experiment lists -bead deploy # Generate jsPsych experiment -bead training run # Train with active learning +bead init my-experiment # Create project structure +bead templates fill # Fill templates +bead items construct # Construct items +bead lists partition # Create experiment lists +bead deploy # Generate jsPsych experiment +bead training run # Train with active learning +bead protocol validate # Validate the protocol section of a config +bead protocol realize # Materialize realizations for contexts +bead protocol items # Bridge a protocol to item templates ``` +Every command accepts repeatable `--set KEY=VALUE` overrides applied +through the config composer, so any field of `BeadConfig` (including +nested `paths.data_dir`, `protocol.drift.min_length`, etc.) can be +overridden from the shell without editing the YAML. + ## Documentation Full documentation: [bead.readthedocs.io](https://factslab.io/bead/) diff --git a/bead/__init__.py b/bead/__init__.py index 74233be..4949b4b 100644 --- a/bead/__init__.py +++ b/bead/__init__.py @@ -6,6 +6,6 @@ from __future__ import annotations -__version__ = "0.4.0" +__version__ = "0.5.0" __author__ = "Aaron Steven White" __email__ = "aaron.white@rochester.edu" diff --git a/bead/active_learning/models/base.py b/bead/active_learning/models/base.py index 22117b1..e931eee 100644 --- a/bead/active_learning/models/base.py +++ b/bead/active_learning/models/base.py @@ -499,13 +499,20 @@ def _save_model_components(self, save_path: Path) -> None: pass @abstractmethod - def _load_model_components(self, load_path: Path) -> None: + def _load_model_components( + self, load_path: Path, config_dict: dict[str, object] + ) -> None: """Load model-specific components. Parameters ---------- load_path : Path Directory to load from. + config_dict : dict[str, object] + Schema-only config dict (model-specific state fields have + already been popped by :meth:`_restore_training_state`). + Subclasses use this to reconstruct ``self.config`` without + re-reading ``config.json`` from disk. """ pass @@ -814,7 +821,7 @@ def load(self, path: str) -> None: # Load model-specific components (which will reconstruct the config) # This must happen before initializing random effects so config is correct - self._load_model_components(load_path) + self._load_model_components(load_path, config_dict) # Initialize and load random effects n_classes = self._get_n_classes_for_random_effects() diff --git a/bead/active_learning/models/binary.py b/bead/active_learning/models/binary.py index a731429..5537eb1 100644 --- a/bead/active_learning/models/binary.py +++ b/bead/active_learning/models/binary.py @@ -852,20 +852,18 @@ def _restore_training_state(self, config_dict: dict[str, object]) -> None: self.label_names = config_dict.pop("label_names") self.positive_class = config_dict.pop("positive_class") - def _load_model_components(self, load_path: Path) -> None: + def _load_model_components( + self, load_path: Path, config_dict: dict[str, object] + ) -> None: """Load model-specific components. Parameters ---------- load_path : Path Directory to load from. + config_dict : dict[str, object] + Schema-only config dict. """ - # Load config.json to reconstruct config - with open(load_path / "config.json") as f: - import json # noqa: PLC0415 - - config_dict = json.load(f) - # Reconstruct MixedEffectsConfig if needed if "mixed_effects" in config_dict and isinstance( config_dict["mixed_effects"], dict diff --git a/bead/active_learning/models/categorical.py b/bead/active_learning/models/categorical.py index b0cbf82..78a57a6 100644 --- a/bead/active_learning/models/categorical.py +++ b/bead/active_learning/models/categorical.py @@ -885,20 +885,18 @@ def _restore_training_state(self, config_dict: dict[str, object]) -> None: self.num_classes = config_dict.pop("num_classes") self.category_names = config_dict.pop("category_names") - def _load_model_components(self, load_path: Path) -> None: + def _load_model_components( + self, load_path: Path, config_dict: dict[str, object] + ) -> None: """Load model-specific components. Parameters ---------- load_path : Path Directory to load from. + config_dict : dict[str, object] + Schema-only config dict. """ - # Load config.json to reconstruct config - with open(load_path / "config.json") as f: - import json # noqa: PLC0415 - - config_dict = json.load(f) - # Reconstruct MixedEffectsConfig if needed if "mixed_effects" in config_dict and isinstance( config_dict["mixed_effects"], dict diff --git a/bead/active_learning/models/cloze.py b/bead/active_learning/models/cloze.py index c8cc846..97f44aa 100644 --- a/bead/active_learning/models/cloze.py +++ b/bead/active_learning/models/cloze.py @@ -769,20 +769,18 @@ def _get_save_state(self) -> dict[str, object]: """ return {} - def _load_model_components(self, load_path: Path) -> None: + def _load_model_components( + self, load_path: Path, config_dict: dict[str, object] + ) -> None: """Load model-specific components. Parameters ---------- load_path : Path Directory to load from. + config_dict : dict[str, object] + Schema-only config dict. """ - # Load config.json to reconstruct config - with open(load_path / "config.json") as f: - import json # noqa: PLC0415 - - config_dict = json.load(f) - # Reconstruct MixedEffectsConfig if needed if "mixed_effects" in config_dict and isinstance( config_dict["mixed_effects"], dict diff --git a/bead/active_learning/models/forced_choice.py b/bead/active_learning/models/forced_choice.py index 941429d..376786e 100644 --- a/bead/active_learning/models/forced_choice.py +++ b/bead/active_learning/models/forced_choice.py @@ -900,18 +900,18 @@ def _restore_training_state(self, config_dict: dict[str, object]) -> None: self.num_classes = config_dict.pop("num_classes") self.option_names = config_dict.pop("option_names") - def _load_model_components(self, load_path: Path) -> None: + def _load_model_components( + self, load_path: Path, config_dict: dict[str, object] + ) -> None: """Load model-specific components. Parameters ---------- load_path : Path Directory to load from. + config_dict : dict[str, object] + Schema-only config dict. """ - # Load config.json to reconstruct config - with open(load_path / "config.json") as f: - config_dict = json.load(f) - # Reconstruct MixedEffectsConfig if needed if "mixed_effects" in config_dict and isinstance( config_dict["mixed_effects"], dict diff --git a/bead/active_learning/models/free_text.py b/bead/active_learning/models/free_text.py index 4680e84..69feba9 100644 --- a/bead/active_learning/models/free_text.py +++ b/bead/active_learning/models/free_text.py @@ -666,18 +666,18 @@ def _save_model_components(self, save_path: Path) -> None: self.model.save_pretrained(save_path / "model") self.tokenizer.save_pretrained(save_path / "model") - def _load_model_components(self, load_path: Path) -> None: + def _load_model_components( + self, load_path: Path, config_dict: dict[str, object] + ) -> None: """Load model-specific components (model, tokenizer). Parameters ---------- load_path : Path Directory path to load the model from. + config_dict : dict[str, object] + Schema-only config dict. """ - # Load config.json to reconstruct config - with open(load_path / "config.json") as f: - config_dict = json.load(f) - # Reconstruct MixedEffectsConfig if needed if "mixed_effects" in config_dict and isinstance( config_dict["mixed_effects"], dict diff --git a/bead/active_learning/models/magnitude.py b/bead/active_learning/models/magnitude.py index 760d2d6..662bacf 100644 --- a/bead/active_learning/models/magnitude.py +++ b/bead/active_learning/models/magnitude.py @@ -777,18 +777,18 @@ def _restore_training_state(self, config_dict: dict[str, object]) -> None: # MagnitudeModel doesn't have additional training state to restore pass - def _load_model_components(self, load_path: Path) -> None: + def _load_model_components( + self, load_path: Path, config_dict: dict[str, object] + ) -> None: """Load model-specific components. Parameters ---------- load_path : Path Directory to load from. + config_dict : dict[str, object] + Schema-only config dict. """ - # Load config.json to reconstruct config - with open(load_path / "config.json") as f: - config_dict = json.load(f) - # Reconstruct MixedEffectsConfig if needed if "mixed_effects" in config_dict and isinstance( config_dict["mixed_effects"], dict diff --git a/bead/active_learning/models/multi_select.py b/bead/active_learning/models/multi_select.py index 3a9c818..acb10c5 100644 --- a/bead/active_learning/models/multi_select.py +++ b/bead/active_learning/models/multi_select.py @@ -741,18 +741,18 @@ def _restore_training_state(self, config_dict: dict[str, object]) -> None: self.num_options = config_dict.pop("num_options") self.option_names = config_dict.pop("option_names") - def _load_model_components(self, load_path: Path) -> None: + def _load_model_components( + self, load_path: Path, config_dict: dict[str, object] + ) -> None: """Load model-specific components. Parameters ---------- load_path : Path Directory to load from. + config_dict : dict[str, object] + Schema-only config dict. """ - # Load config.json to reconstruct config - with open(load_path / "config.json") as f: - config_dict = json.load(f) - # Reconstruct MixedEffectsConfig if needed if "mixed_effects" in config_dict and isinstance( config_dict["mixed_effects"], dict diff --git a/bead/active_learning/models/ordinal_scale.py b/bead/active_learning/models/ordinal_scale.py index 73d845e..f2ab1c7 100644 --- a/bead/active_learning/models/ordinal_scale.py +++ b/bead/active_learning/models/ordinal_scale.py @@ -753,18 +753,18 @@ def _restore_training_state(self, config_dict: dict[str, object]) -> None: # OrdinalScaleModel doesn't have additional training state to restore pass - def _load_model_components(self, load_path: Path) -> None: + def _load_model_components( + self, load_path: Path, config_dict: dict[str, object] + ) -> None: """Load model-specific components. Parameters ---------- load_path : Path Directory to load from. + config_dict : dict[str, object] + Schema-only config dict. """ - # Load config.json to reconstruct config - with open(load_path / "config.json") as f: - config_dict = json.load(f) - # Reconstruct MixedEffectsConfig if needed if "mixed_effects" in config_dict and isinstance( config_dict["mixed_effects"], dict diff --git a/bead/active_learning/trainers/lightning.py b/bead/active_learning/trainers/lightning.py index 0a8f699..46a617d 100644 --- a/bead/active_learning/trainers/lightning.py +++ b/bead/active_learning/trainers/lightning.py @@ -240,14 +240,15 @@ def train( if best_checkpoint_str: best_checkpoint = Path(best_checkpoint_str) - # create metadata - config_dict = ( - self.config - if isinstance(self.config, dict) - else ( - self.config.model_dump() if hasattr(self.config, "model_dump") else {} - ) - ) + # create metadata; flatten ``self.config`` to a JSON-shaped dict + # so ``ModelMetadata.training_config`` (typed ``dict[str, JsonValue]``) + # accepts it. ``model_dump_json`` walks Paths / Enums / etc. + if isinstance(self.config, dict): + config_dict = json.loads(json.dumps(self.config, default=str)) + elif hasattr(self.config, "model_dump_json"): + config_dict = json.loads(self.config.model_dump_json()) + else: + config_dict = {} metadata = ModelMetadata( model_name=model_name, diff --git a/bead/cli/items_factories.py b/bead/cli/items_factories.py index b36bab0..93bc61c 100644 --- a/bead/cli/items_factories.py +++ b/bead/cli/items_factories.py @@ -156,7 +156,7 @@ def create_forced_choice_from_texts( """ try: # Load texts - texts: list[str] = [line.strip() for line in open(texts_file) if line.strip()] + texts: list[str] = [line.strip() for line in texts_file.read_text().splitlines() if line.strip()] print_info(f"Loaded {len(texts)} texts") # Create items by generating all combinations of n_alternatives from texts @@ -286,7 +286,7 @@ def create_ordinal_scale_from_texts( """ try: # Load texts - texts = [line.strip() for line in open(texts_file) if line.strip()] + texts = [line.strip() for line in texts_file.read_text().splitlines() if line.strip()] print_info(f"Loaded {len(texts)} texts") # Create items @@ -459,7 +459,7 @@ def create_binary_from_texts( """ try: # Load texts - texts = [line.strip() for line in open(texts_file) if line.strip()] + texts = [line.strip() for line in texts_file.read_text().splitlines() if line.strip()] print_info(f"Loaded {len(texts)} texts") # Create items @@ -535,7 +535,7 @@ def create_multi_select_from_texts( """ try: # Load texts - texts: list[str] = [line.strip() for line in open(texts_file) if line.strip()] + texts: list[str] = [line.strip() for line in texts_file.read_text().splitlines() if line.strip()] print_info(f"Loaded {len(texts)} texts") # Parse options @@ -619,7 +619,7 @@ def create_magnitude_from_texts( """ try: # Load texts - texts = [line.strip() for line in open(texts_file) if line.strip()] + texts = [line.strip() for line in texts_file.read_text().splitlines() if line.strip()] print_info(f"Loaded {len(texts)} texts") # Create items @@ -679,7 +679,7 @@ def create_free_text_from_texts( """ try: # Load texts - texts = [line.strip() for line in open(texts_file) if line.strip()] + texts = [line.strip() for line in texts_file.read_text().splitlines() if line.strip()] print_info(f"Loaded {len(texts)} texts") # Create items diff --git a/bead/cli/main.py b/bead/cli/main.py index a9d7c8d..3268159 100644 --- a/bead/cli/main.py +++ b/bead/cli/main.py @@ -49,6 +49,18 @@ default=False, help="Suppress all output except errors", ) +@click.option( + "--set", + "set_overrides", + multiple=True, + metavar="KEY=VALUE", + help=( + "Override one config key (dotted path). May be repeated. " + "Values are parsed as YAML so numbers, booleans, lists, " + "and quoted strings keep their type. Example: " + "--set paths.data_dir=/tmp --set protocol.lm_temperature=0.5" + ), +) @click.pass_context def cli( ctx: click.Context, @@ -56,6 +68,7 @@ def cli( profile: str, verbose: bool, quiet: bool, + set_overrides: tuple[str, ...], ) -> None: r"""CLI for linguistic judgment experiments. @@ -90,6 +103,7 @@ def cli( ctx.obj["profile"] = profile ctx.obj["verbose"] = verbose ctx.obj["quiet"] = quiet + ctx.obj["set_overrides"] = set_overrides @cli.command() diff --git a/bead/cli/protocol.py b/bead/cli/protocol.py index 7ca84e8..55aa4b7 100644 --- a/bead/cli/protocol.py +++ b/bead/cli/protocol.py @@ -20,6 +20,7 @@ from __future__ import annotations +from collections.abc import Sequence from pathlib import Path import click @@ -38,9 +39,17 @@ ) -def _load_protocol(config_path: Path | None, profile: str) -> AnnotationProtocol: +def _load_protocol( + config_path: Path | None, + profile: str, + overrides: Sequence[str] = (), +) -> AnnotationProtocol: """Load a :class:`BeadConfig` and materialize its protocol.""" - config = load_config(config_path=config_path, profile=profile) + config = load_config( + config_path=config_path, + profile=profile, + overrides=overrides, + ) cache = ModelOutputCache( cache_dir=config.paths.cache_dir / "models", backend="filesystem", @@ -77,7 +86,12 @@ def protocol() -> None: help="Path to a project config file (defaults to bead.toml).", ) @click.option("--profile", default="default", help="Configuration profile.") -def validate(config_file: Path | None, profile: str) -> None: +@click.pass_context +def validate( + ctx: click.Context, + config_file: Path | None, + profile: str, +) -> None: """Validate the protocol configuration and report its families. Loads the configured :class:`AnnotationProtocol`, prints a one-line @@ -85,8 +99,9 @@ def validate(config_file: Path | None, profile: str) -> None: declared dependencies), and exits non-zero on any construction error. """ + set_overrides: tuple[str, ...] = ctx.obj.get("set_overrides", ()) if ctx.obj else () try: - proto = _load_protocol(config_file, profile) + proto = _load_protocol(config_file, profile, set_overrides) except Exception as exc: # noqa: BLE001 print_error(f"Protocol failed to materialize: {exc}") raise SystemExit(1) from exc @@ -125,7 +140,9 @@ def validate(config_file: Path | None, profile: str) -> None: default="acceptability", help=("Judgment type for emitted ItemTemplates (used when --emit-items is set)."), ) +@click.pass_context def realize( + ctx: click.Context, contexts_file: Path, output_file: Path, config_file: Path | None, @@ -139,7 +156,8 @@ def realize( (one per line). OUTPUT_FILE will be written as JSONL with one record per realized question (skipping non-applicable families). """ - proto = _load_protocol(config_file, profile) + set_overrides: tuple[str, ...] = ctx.obj.get("set_overrides", ()) if ctx.obj else () + proto = _load_protocol(config_file, profile, set_overrides) if len(proto) == 0: print_error( "Configured protocol is empty; nothing to realize. Add " @@ -187,7 +205,9 @@ def realize( default="acceptability", help="Judgment type assigned to every ItemTemplate.", ) +@click.pass_context def items( + ctx: click.Context, output_file: Path, config_file: Path | None, profile: str, @@ -199,7 +219,8 @@ def items( family in the configured protocol and writes them to OUTPUT_FILE as JSONL. The resulting file feeds Stage 3 (item construction). """ - proto = _load_protocol(config_file, profile) + set_overrides: tuple[str, ...] = ctx.obj.get("set_overrides", ()) if ctx.obj else () + proto = _load_protocol(config_file, profile, set_overrides) if len(proto) == 0: print_error("Configured protocol is empty; no templates to emit.") raise SystemExit(1) diff --git a/bead/config/__init__.py b/bead/config/__init__.py index 20efdd1..4aec8f5 100644 --- a/bead/config/__init__.py +++ b/bead/config/__init__.py @@ -7,13 +7,20 @@ from __future__ import annotations from bead.config.active_learning import ActiveLearningConfig +from bead.config.compose import ( + ComposeValue, + ConfigError, + InterpolationError, + compose, + register_resolver, +) from bead.config.config import BeadConfig from bead.config.defaults import DEFAULT_CONFIG, get_default_config from bead.config.deployment import DeploymentConfig from bead.config.env import load_from_env from bead.config.item import ItemConfig from bead.config.list import ListConfig -from bead.config.loader import load_config, load_yaml_file, merge_configs +from bead.config.loader import load_config from bead.config.logging import LoggingConfig from bead.config.model import ModelConfig from bead.config.paths import PathsConfig @@ -69,10 +76,13 @@ "PROFILES", "get_profile", "list_profiles", - # loading + # loading + composition + "ComposeValue", + "ConfigError", + "InterpolationError", + "compose", "load_config", - "load_yaml_file", - "merge_configs", + "register_resolver", # environment "load_from_env", # validation diff --git a/bead/config/compose/__init__.py b/bead/config/compose/__init__.py new file mode 100644 index 0000000..9ec8f6b --- /dev/null +++ b/bead/config/compose/__init__.py @@ -0,0 +1,58 @@ +r"""Composition, interpolation, and validation for didactic configs. + +A self-contained subpackage that turns a YAML or TOML file (plus +profile defaults, overlay files, and CLI overrides) into a fully +interpolated, validated ``dx.Model``. The grammar follows OmegaConf's +interpolation conventions: + +- ``${section.field}`` absolute and ``${.field}`` / ``${..field}`` + relative dotted-path references +- ``${a.b[0]}`` and ``${a.b.0}`` list indexing +- ``${a.${b.c}}`` nested interpolations +- ``${name:arg1,arg2}`` resolver calls (built-ins plus user-registered) +- ``"prefix_${a.b}_suffix"`` string concatenation with type-preserving + whole-value substitution +- ``\\${literal}`` escape + +The subpackage imports nothing from ``bead`` outside of itself; the +only external dependency in implementation modules is ``didactic`` +(used in :mod:`~bead.config.compose.merge` to enforce strict-key +checking against a target schema). It is structured so it can be +lifted into a standalone distribution by relocating the package and +adjusting a couple of internal imports. + +Public API +---------- +- :func:`compose` — the full pipeline. +- :func:`register_resolver` / :func:`unregister_resolver` / + :func:`list_resolvers` — manage custom resolvers. +- :func:`resolve` — apply interpolation to an already-merged dict. +- :class:`ConfigError`, :class:`InterpolationError` — exceptions. +""" + +from __future__ import annotations + +from bead.config.compose.errors import ConfigError, InterpolationError +from bead.config.compose.interpolation import ( + ComposeValue, + ResolverFn, + active_root, + list_resolvers, + register_resolver, + resolve, + unregister_resolver, +) +from bead.config.compose.pipeline import compose + +__all__ = [ + "ComposeValue", + "ConfigError", + "InterpolationError", + "ResolverFn", + "active_root", + "compose", + "list_resolvers", + "register_resolver", + "resolve", + "unregister_resolver", +] diff --git a/bead/config/compose/errors.py b/bead/config/compose/errors.py new file mode 100644 index 0000000..3484159 --- /dev/null +++ b/bead/config/compose/errors.py @@ -0,0 +1,34 @@ +"""Exceptions raised by the compose subpackage. + +These types are part of the public API and survive the eventual +extraction of this subpackage into a standalone distribution. +""" + +from __future__ import annotations + + +class ConfigError(ValueError): + """Raised when a configuration is malformed or fails validation. + + Common causes + ------------- + - Unknown keys at any level of the merged config (strict-merge + rejection). + - A ``defaults: [...]`` entry that cannot be resolved to a + loadable YAML or TOML file. + - A dotted-key override that targets a key not declared in the + schema. + """ + + +class InterpolationError(ValueError): + """Raised when an interpolation expression cannot be resolved. + + Common causes + ------------- + - A reference like ``${a.b.c}`` that does not exist in the + composed config. + - A resolver call to an unregistered name (e.g. ``${unknown:x}``). + - A cycle detected during evaluation + (``${a} → ${b} → ${a}``). + """ diff --git a/bead/config/compose/interpolation.py b/bead/config/compose/interpolation.py new file mode 100644 index 0000000..9acb86c --- /dev/null +++ b/bead/config/compose/interpolation.py @@ -0,0 +1,674 @@ +r"""Interpolation grammar and evaluator. + +bead's config interpolation grammar matches OmegaConf's: + +- ``${section.field}`` — absolute dotted-path reference. +- ``${.field}`` / ``${..field}`` — relative reference; each leading + dot walks one level up from the current node's parent. +- ``${a.b[0]}`` / ``${a.b.0}`` — list indexing (bracketed or + dotted-integer; both supported). +- ``${a.${b}}`` — nested interpolations; the inner is resolved first. +- ``"prefix_${a.b}_suffix"`` — string concatenation. A standalone + ``${a.b}`` (whole-value, no surrounding text) substitutes the + typed value; a substring substitution coerces to ``str``. +- ``${name:arg1,arg2}`` — resolver call. Built-in resolvers + (``oc.env``, ``oc.select``, ``oc.dict.keys``, ``oc.dict.values``, + ``oc.decode``, ``oc.deprecated``, ``oc.create``) are registered by + :mod:`bead.config.compose.resolvers`; user code adds more via + :func:`register_resolver`. +- ``\\${literal}`` — escape; produces a literal ``${literal}``. + +Cycle detection raises :class:`~bead.config.compose.errors.InterpolationError` +with the cycle path in the message. + +The evaluator operates on plain Python dicts and lists (and the +``ComposeValue`` union from :mod:`bead.data.base`). It imports nothing +from ``didactic`` or the rest of bead, so it can be lifted into a +standalone distribution without changes. +""" + +from __future__ import annotations + +from collections.abc import Callable, Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass, field +from typing import Final + +from bead.config.compose.errors import InterpolationError + +type ComposeValue = ( + str | int | float | bool | None | list["ComposeValue"] | dict[str, "ComposeValue"] +) +"""The kind of value the compose engine operates on. + +This is the list-based JSON shape produced by ``yaml.safe_load`` and +``tomllib.load``. didactic validation accepts lists for ``tuple[T, ...]`` +fields, so this type matches both inputs and validated outputs. + +This alias lives in the subpackage so the package can be extracted +without depending on bead's ``ComposeValue``. +""" + + +ResolverFn = Callable[..., ComposeValue] +"""Type of a resolver function. Takes positional string args, returns +any JSON-shaped value. Resolvers may *themselves* contain +interpolations that are resolved before the resolver is called. +""" + + +_RESOLVERS: dict[str, ResolverFn] = {} + +_ACTIVE_ROOT: ContextVar[dict[str, ComposeValue] | None] = ContextVar( + "bead_compose_active_root", default=None +) +"""The dict currently being interpolated. + +Set by :func:`resolve` for the duration of evaluation, so root-aware +resolvers (registered by application code) can reach the in-flight +config. :func:`active_root` is the public accessor. +""" + + +def active_root() -> dict[str, ComposeValue] | None: + """Return the dict currently being interpolated, or ``None``. + + Resolvers that need to consult other parts of the composed config + call this to retrieve the in-flight root. Returns ``None`` when + called outside an active :func:`resolve` invocation. + """ + return _ACTIVE_ROOT.get() + + +@contextmanager +def _activate_root( + root: dict[str, ComposeValue], +) -> Iterator[None]: + token = _ACTIVE_ROOT.set(root) + try: + yield + finally: + _ACTIVE_ROOT.reset(token) + + +def register_resolver(name: str, fn: ResolverFn, *, replace: bool = False) -> None: + """Register a custom resolver under ``name``. + + Parameters + ---------- + name : str + Resolver name as it appears in ``${name:args}``. + fn : ResolverFn + Function called with the comma-separated args (each a string, + pre-stripped). The return value is substituted in place. + replace : bool, optional + Whether re-registration is allowed for an existing name. + Defaults to ``False`` so accidental shadowing is loud. + + Raises + ------ + ValueError + If ``name`` is already registered and ``replace`` is ``False``. + """ + if name in _RESOLVERS and not replace: + raise ValueError( + f"Resolver {name!r} already registered. Pass replace=True to override." + ) + _RESOLVERS[name] = fn + + +def unregister_resolver(name: str) -> None: + """Remove a registered resolver. No-op if it does not exist.""" + _RESOLVERS.pop(name, None) + + +def list_resolvers() -> tuple[str, ...]: + """Return the names of every registered resolver, sorted.""" + return tuple(sorted(_RESOLVERS)) + + +# --------------------------------------------------------------------------- +# AST nodes +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class _Literal: + """A literal string segment with no interpolation.""" + + text: str + + +@dataclass(frozen=True) +class _Reference: + """A reference of the form ``${path}``. + + ``up`` counts the number of leading dots (``${.x}`` → ``up=1``, + ``${..x}`` → ``up=2``, ``${x}`` → ``up=0``). For absolute + references ``up=0`` and the path starts at the root; relative + references walk up that many parents before descending. + + Each element of ``parts`` is either a string (a dict key or a + bare path segment) or an integer (a list index). String parts + may themselves contain ``_Node`` lists, since interpolations can + be nested inside path segments (``${a.${b.c}.d}``). + """ + + up: int + parts: tuple[_PathSegment, ...] + + +@dataclass(frozen=True) +class _ResolverCall: + """A resolver call of the form ``${name:arg1,arg2}``. + + ``args`` is a tuple of node-lists; each node-list is the AST for + one argument, since arguments may themselves contain + interpolations. + """ + + name: str + args: tuple[tuple[_Node, ...], ...] + + +type _PathSegment = str | int | tuple["_Node", ...] +"""One element of an interpolation path. + +A plain string or integer is a literal segment; a tuple-of-nodes is +a nested expression that must be evaluated to a string before being +spliced into the path. +""" + + +type _Node = _Literal | _Reference | _ResolverCall +"""One element of a parsed expression.""" + + +# --------------------------------------------------------------------------- +# Parser +# --------------------------------------------------------------------------- + + +@dataclass +class _Parser: + """Recursive-descent parser for interpolation strings. + + Maintains a single ``pos`` cursor into ``text``. Methods that + parse a construct return its AST and advance the cursor. + """ + + text: str + pos: int = 0 + + def parse(self) -> tuple[_Node, ...]: + """Parse the whole text into a node tuple.""" + return self._parse_until(end_chars=()) + + def _parse_until(self, *, end_chars: tuple[str, ...]) -> tuple[_Node, ...]: + """Parse until one of ``end_chars`` is reached (or EOF). + + Used both for the top level (``end_chars=()``) and for + nested expression bodies. + """ + nodes: list[_Node] = [] + buf: list[str] = [] + + def flush_literal() -> None: + if buf: + nodes.append(_Literal("".join(buf))) + buf.clear() + + while self.pos < len(self.text): + ch = self.text[self.pos] + if ch in end_chars: + break + if ch == "\\" and self.pos + 1 < len(self.text): + nxt = self.text[self.pos + 1] + if nxt == "$": + buf.append("$") + self.pos += 2 + continue + if nxt == "\\": + buf.append("\\") + self.pos += 2 + continue + buf.append(ch) + self.pos += 1 + continue + if ch == "$" and self._peek(1) == "{": + flush_literal() + nodes.append(self._parse_interp()) + continue + buf.append(ch) + self.pos += 1 + + flush_literal() + return tuple(nodes) + + def _peek(self, offset: int) -> str: + idx = self.pos + offset + if 0 <= idx < len(self.text): + return self.text[idx] + return "" + + def _parse_interp(self) -> _Node: + """Parse a ``${...}`` expression starting at ``self.pos``.""" + assert self.text[self.pos : self.pos + 2] == "${" + self.pos += 2 # consume "${" + + up = 0 + while self.pos < len(self.text) and self.text[self.pos] == ".": + up += 1 + self.pos += 1 + + parts: list[_PathSegment] = [] + head, head_terminator = self._parse_path_head() + if head != "" or head_terminator == ":": + parts.append(head) + while head_terminator == ".": + seg, head_terminator = self._parse_path_head() + if seg == "": + raise InterpolationError( + f"Empty path segment in interpolation at pos {self.pos}" + ) + parts.append(seg) + + if head_terminator == ":": + if up != 0: + raise InterpolationError("Resolver call cannot have leading dots") + if not parts or not all(isinstance(p, str) and p != "" for p in parts): + raise InterpolationError( + "Resolver name must be a static dotted identifier" + ) + name = ".".join(p for p in parts if isinstance(p, str)) + args = self._parse_resolver_args() + self._expect("}") + return _ResolverCall(name=name, args=args) + + while head_terminator == "[": + self.pos += 1 # consume "[" + idx_text, _ = self._read_until("]") + self._expect("]") + try: + parts.append(int(idx_text)) + except ValueError as exc: + raise InterpolationError( + f"Bracketed index must be an integer: [{idx_text!r}]" + ) from exc + if self.pos < len(self.text) and self.text[self.pos] == ".": + self.pos += 1 + seg, head_terminator = self._parse_path_head() + if seg == "": + raise InterpolationError("Empty path segment after ']'") + parts.append(seg) + elif self.pos < len(self.text) and self.text[self.pos] == "[": + self.pos += 1 + head_terminator = "[" + continue + else: + head_terminator = "}" + + if head_terminator != "}": + raise InterpolationError( + f"Unexpected character {head_terminator!r} in interpolation" + ) + self.pos += 1 # consume "}" + return _Reference(up=up, parts=tuple(parts)) + + def _parse_path_head(self) -> tuple[_PathSegment, str]: + """Read a single path segment, returning ``(segment, terminator)``. + + The terminator is one of ``"."`` (more dotted path follows), + ``"["`` (bracketed index follows), ``":"`` (resolver call body + follows), ``"}"`` (end of expression), or ``""`` (EOF). + + Segments may contain nested ``${...}`` interpolations, in + which case the segment is returned as a node tuple instead + of a string. + """ + buf: list[str] = [] + nested: list[_Node] | None = None + while self.pos < len(self.text): + ch = self.text[self.pos] + if ch in (".", "[", ":", "}"): + terminator = ch + if ch == ".": + self.pos += 1 + break + if ch == "$" and self._peek(1) == "{": + if nested is None: + nested = [] + if buf: + nested.append(_Literal("".join(buf))) + buf.clear() + nested.append(self._parse_interp()) + continue + buf.append(ch) + self.pos += 1 + else: + terminator = "" + + if nested is not None: + if buf: + nested.append(_Literal("".join(buf))) + seg: _PathSegment = tuple(nested) + else: + seg = "".join(buf) + if seg.lstrip("-").isdigit(): + seg = int(seg) + return seg, terminator + + def _parse_resolver_args(self) -> tuple[tuple[_Node, ...], ...]: + """Parse the body of a resolver call after the ':'.""" + assert self.text[self.pos] == ":" + self.pos += 1 + args: list[tuple[_Node, ...]] = [] + current: list[_Node] = [] + buf: list[str] = [] + + def flush_literal() -> None: + if buf: + current.append(_Literal("".join(buf))) + buf.clear() + + depth = 0 + while self.pos < len(self.text): + ch = self.text[self.pos] + if ch == "}" and depth == 0: + flush_literal() + args.append(tuple(current)) + return tuple(args) + if ch == "," and depth == 0: + flush_literal() + args.append(tuple(current)) + current = [] + self.pos += 1 + continue + if ch == "\\" and self.pos + 1 < len(self.text): + buf.append(self.text[self.pos + 1]) + self.pos += 2 + continue + if ch == "$" and self._peek(1) == "{": + flush_literal() + current.append(self._parse_interp()) + depth_at_call_start = depth + _ = depth_at_call_start # unused; documents intent + continue + buf.append(ch) + self.pos += 1 + raise InterpolationError("Unterminated resolver call (missing '}')") + + def _read_until(self, terminator: str) -> tuple[str, str]: + """Read literal text until ``terminator`` is the next char. + + Returns ``(text, terminator)``. Used for bracketed indices, + which do not support nested interpolation. + """ + start = self.pos + while self.pos < len(self.text): + if self.text[self.pos] == terminator: + return self.text[start : self.pos], terminator + self.pos += 1 + raise InterpolationError(f"Unterminated bracket: expected {terminator!r}") + + def _expect(self, ch: str) -> None: + if self.pos >= len(self.text) or self.text[self.pos] != ch: + raise InterpolationError( + f"Expected {ch!r} at pos {self.pos}; " + f"got {self.text[self.pos : self.pos + 1]!r}" + ) + self.pos += 1 + + +def _parse(text: str) -> tuple[_Node, ...]: + """Public-internal parse entry point.""" + return _Parser(text).parse() + + +# --------------------------------------------------------------------------- +# Evaluator +# --------------------------------------------------------------------------- + + +_MAX_DEPTH: Final = 64 + + +@dataclass +class _EvalState: + """Mutable evaluator state threaded through recursive calls. + + Tracks the in-progress reference stack for cycle detection. + """ + + root: dict[str, ComposeValue] + seen: set[tuple[int | str, ...]] = field(default_factory=set) + + +def resolve( + node: ComposeValue, + *, + root: dict[str, ComposeValue], + here: tuple[str | int, ...] = (), +) -> ComposeValue: + """Resolve every interpolation in ``node``, returning a new value. + + Walks ``node`` recursively. Strings are parsed and evaluated; + dicts and lists are descended into with ``here`` extended by the + current key or index. Whole-value substitutions preserve type + (the result of resolving ``"${a.b}"`` where ``a.b`` is an int is + an int, not a stringified int); substring substitutions coerce + to ``str``. + + Parameters + ---------- + node : ComposeValue + Value to resolve in place. Strings, dicts, lists, scalars. + root : dict[str, ComposeValue] + Top-level config; interpolations resolve against this. + here : tuple[str | int, ...], optional + Dotted path of ``node`` relative to ``root``. Used to resolve + relative references (``${.x}``). Defaults to the empty path + (top level). + + Returns + ------- + ComposeValue + Fully-resolved value, the same kind (str / int / dict / …) + as ``node`` except where an interpolation changed the type. + """ + state = _EvalState(root=root) + with _activate_root(root): + return _resolve_value(node, here, state) + + +def _resolve_value( + node: ComposeValue, + here: tuple[str | int, ...], + state: _EvalState, +) -> ComposeValue: + if isinstance(node, str): + return _resolve_string(node, here, state) + if isinstance(node, list): + return [_resolve_value(item, (*here, i), state) for i, item in enumerate(node)] + if isinstance(node, dict): + return { + key: _resolve_value(value, (*here, key), state) + for key, value in node.items() + } + return node + + +def _resolve_string( + text: str, here: tuple[str | int, ...], state: _EvalState +) -> ComposeValue: + nodes = _parse(text) + if len(nodes) == 1 and not isinstance(nodes[0], _Literal): + value = _eval_node(nodes[0], here, state, depth=0) + return value + parts: list[str] = [] + for node in nodes: + value = _eval_node(node, here, state, depth=0) + if isinstance(value, str): + parts.append(value) + else: + parts.append(_to_str(value)) + return "".join(parts) + + +def _eval_node( + node: _Node, here: tuple[str | int, ...], state: _EvalState, *, depth: int +) -> ComposeValue: + if depth > _MAX_DEPTH: + raise InterpolationError(f"Interpolation nesting exceeded {_MAX_DEPTH} levels") + if isinstance(node, _Literal): + return node.text + if isinstance(node, _Reference): + return _eval_reference(node, here, state, depth=depth) + return _eval_resolver_call(node, here, state, depth=depth) + + +def _eval_reference( + ref: _Reference, + here: tuple[str | int, ...], + state: _EvalState, + *, + depth: int, +) -> ComposeValue: + if ref.up > len(here): + raise InterpolationError( + f"Relative reference {'.' * ref.up}... walks above the " + f"root (current path: {_format_path(here)})" + ) + base_path = here[: len(here) - ref.up] if ref.up > 0 else () + resolved_parts: list[str | int] = list(base_path) + for part in ref.parts: + resolved_parts.append(_resolve_path_segment(part, here, state, depth=depth + 1)) + + cycle_key: tuple[int | str, ...] = (id(state.root),) + _hashable_path( + resolved_parts + ) + if cycle_key in state.seen: + raise InterpolationError( + f"Interpolation cycle detected at {_format_path(resolved_parts)}" + ) + state.seen.add(cycle_key) + try: + value = _walk(state.root, resolved_parts) + if isinstance(value, str | dict | list): + return _resolve_value(value, tuple(resolved_parts), state) + return value + finally: + state.seen.discard(cycle_key) + + +def _eval_resolver_call( + call: _ResolverCall, + here: tuple[str | int, ...], + state: _EvalState, + *, + depth: int, +) -> ComposeValue: + if call.name not in _RESOLVERS: + raise InterpolationError( + f"Unknown resolver {call.name!r}. Registered: {list_resolvers()}" + ) + resolved_args: list[str] = [] + for arg_nodes in call.args: + if not arg_nodes: + resolved_args.append("") + continue + if len(arg_nodes) == 1 and isinstance(arg_nodes[0], _Literal): + resolved_args.append(arg_nodes[0].text) + continue + rendered: list[str] = [] + for sub in arg_nodes: + value = _eval_node(sub, here, state, depth=depth + 1) + rendered.append(value if isinstance(value, str) else _to_str(value)) + resolved_args.append("".join(rendered)) + + try: + return _RESOLVERS[call.name](*resolved_args) + except InterpolationError: + raise + except Exception as exc: + raise InterpolationError( + f"Resolver {call.name!r} raised {type(exc).__name__}: {exc}" + ) from exc + + +def _resolve_path_segment( + seg: _PathSegment, + here: tuple[str | int, ...], + state: _EvalState, + *, + depth: int, +) -> str | int: + if isinstance(seg, str | int): + return seg + rendered: list[str] = [] + for sub in seg: + value = _eval_node(sub, here, state, depth=depth) + rendered.append(value if isinstance(value, str) else _to_str(value)) + joined = "".join(rendered) + if joined.lstrip("-").isdigit(): + return int(joined) + return joined + + +def _walk( + root: ComposeValue, path: list[str | int] | tuple[str | int, ...] +) -> ComposeValue: + cur: ComposeValue = root + for i, part in enumerate(path): + if isinstance(cur, dict): + if not isinstance(part, str): + raise InterpolationError( + f"Cannot index dict at {_format_path(path[: i + 1])} with integer" + ) + if part not in cur: + raise InterpolationError( + f"Reference {_format_path(path[: i + 1])} unresolved" + ) + cur = cur[part] + elif isinstance(cur, list): + if not isinstance(part, int): + raise InterpolationError( + f"Cannot index list at " + f"{_format_path(path[: i + 1])} with non-integer" + ) + if part < 0 or part >= len(cur): + raise InterpolationError( + f"List index out of range at " + f"{_format_path(path[: i + 1])} (len={len(cur)})" + ) + cur = cur[part] + else: + raise InterpolationError( + f"Cannot descend into scalar at {_format_path(path[:i])}" + ) + return cur + + +def _hashable_path(path: list[str | int]) -> tuple[str | int, ...]: + return tuple(path) + + +def _format_path(path: list[str | int] | tuple[str | int, ...]) -> str: + if not path: + return "" + parts: list[str] = [] + for p in path: + if isinstance(p, int): + parts.append(f"[{p}]") + else: + parts.append(f".{p}" if parts else p) + return "".join(parts) + + +def _to_str(value: ComposeValue) -> str: + if value is None: + return "null" + if isinstance(value, bool): + return "true" if value else "false" + return str(value) diff --git a/bead/config/compose/merge.py b/bead/config/compose/merge.py new file mode 100644 index 0000000..3f20c9b --- /dev/null +++ b/bead/config/compose/merge.py @@ -0,0 +1,197 @@ +"""Strict, schema-aware dict merge and dotted-key override application. + +This is the one module in the subpackage that imports ``didactic``. +It walks the merged dict against a target ``dx.Model``'s field +specifications to reject unknown keys at merge time — a stricter +behavior than validation alone, with a clearer error location. +""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import get_args, get_origin + +import didactic.api as dx + +from bead.config.compose.errors import ConfigError +from bead.config.compose.interpolation import ComposeValue + + +def strict_merge( + base: dict[str, ComposeValue], + overlay: dict[str, ComposeValue], + *, + schema: type[dx.Model], + _path: tuple[str, ...] = (), +) -> dict[str, ComposeValue]: + """Deep-merge ``overlay`` into ``base`` under ``schema``. + + Mappings overlay key-by-key; non-mapping values overwrite. Any + key in ``overlay`` not declared in ``schema``'s field specs raises + :class:`ConfigError` naming the dotted path to the offending key. + + Lists overwrite wholesale (no element-by-element merge); this + matches OmegaConf's default. + + Parameters + ---------- + base : dict[str, ComposeValue] + Lower-precedence dict. + overlay : dict[str, ComposeValue] + Higher-precedence dict. + schema : type[dx.Model] + didactic model defining allowed keys. + _path : tuple[str, ...] + Internal recursion bookkeeping; pass nothing. + + Returns + ------- + dict[str, ComposeValue] + A fresh dict; neither input is mutated. + """ + allowed = _allowed_fields(schema) + result = dict(base) + for key, overlay_value in overlay.items(): + if key not in allowed: + dotted = ".".join((*_path, key)) + raise ConfigError( + f"Unknown config key {dotted!r}; allowed: {sorted(allowed)}" + ) + nested_schema = allowed[key] + if isinstance(overlay_value, dict) and nested_schema is not None: + existing = result.get(key) + existing_dict = existing if isinstance(existing, dict) else {} + result[key] = strict_merge( + existing_dict, + overlay_value, + schema=nested_schema, + _path=(*_path, key), + ) + else: + result[key] = overlay_value + return result + + +def apply_override( + d: dict[str, ComposeValue], dotted_key: str, value: ComposeValue +) -> None: + """Set ``d[a][b][c] = value`` for ``dotted_key='a.b.c'``. + + Intermediate dicts are created as needed. Existing non-dict + intermediates raise :class:`ConfigError`. + + Parameters + ---------- + d : dict[str, ComposeValue] + Target dict, modified in place. + dotted_key : str + Dotted path. Empty segments are not allowed. + value : ComposeValue + Value to set. + """ + if not dotted_key: + raise ConfigError("Override key cannot be empty") + parts = dotted_key.split(".") + cur: dict[str, ComposeValue] = d + for part in parts[:-1]: + if not part: + raise ConfigError(f"Empty segment in override key {dotted_key!r}") + existing = cur.get(part) + if existing is None or not isinstance(existing, dict): + new_dict: dict[str, ComposeValue] = {} + cur[part] = new_dict + cur = new_dict + else: + cur = existing + last = parts[-1] + if not last: + raise ConfigError(f"Override key {dotted_key!r} ends with an empty segment") + cur[last] = value + + +def parse_override(expr: str) -> tuple[str, ComposeValue]: + """Split a CLI-style ``key=value`` override into its parts. + + The value is parsed as YAML so callers can pass typed primitives + (``--set foo.bar=0.5`` produces a float, ``--set x.y=true`` + produces a bool, ``--set z=hello`` produces a string). + + Parameters + ---------- + expr : str + Override expression, e.g. ``"paths.data_dir=/tmp"``. + + Returns + ------- + tuple[str, ComposeValue] + ``(dotted_key, parsed_value)``. + + Raises + ------ + ConfigError + If ``expr`` is missing ``=`` or has an empty key. + """ + if "=" not in expr: + raise ConfigError(f"Override {expr!r} missing '='; expected 'key=value'") + key, _, raw_value = expr.partition("=") + key = key.strip() + if not key: + raise ConfigError(f"Override {expr!r} has empty key") + + import yaml # noqa: PLC0415 + + parsed = yaml.safe_load(raw_value) + return key, parsed + + +def _allowed_fields( + schema: type[dx.Model], +) -> dict[str, type[dx.Model] | None]: + """Return the keys ``schema`` accepts, mapped to nested schemas. + + A field whose type is itself a ``dx.Model`` (directly or wrapped + in ``dx.Embed[...]`` or ``tuple[dx.Embed[...], ...]``) maps to + that nested model so the recursive walker can descend. Scalar + fields map to ``None``. + """ + field_specs = getattr(schema, "__field_specs__", None) + if field_specs is None: + return {} + allowed: dict[str, type[dx.Model] | None] = {} + for name, spec in field_specs.items(): + allowed[name] = _nested_schema(spec) + return allowed + + +def _nested_schema(spec: object) -> type[dx.Model] | None: + """Extract a nested ``dx.Model`` type from a field spec, if any.""" + annotation = getattr(spec, "annotation", None) + if annotation is None: + return None + return _unwrap_model(annotation) + + +def _unwrap_model(annotation: object) -> type[dx.Model] | None: + """Walk a type annotation looking for a single ``dx.Model`` subclass. + + Handles: + - bare ``SomeModel`` + - ``dx.Embed[SomeModel]`` + - ``dx.Embed[SomeModel] | None`` + - ``tuple[dx.Embed[SomeModel], ...]`` + - ``dict[str, SomeModel]`` + """ + if isinstance(annotation, type) and issubclass(annotation, dx.Model): + return annotation + origin = get_origin(annotation) + args: Iterable[object] = get_args(annotation) + if origin is None: + return None + candidates: list[type[dx.Model]] = [] + for arg in args: + inner = _unwrap_model(arg) + if inner is not None: + candidates.append(inner) + if len(candidates) == 1: + return candidates[0] + return None diff --git a/bead/config/compose/pipeline.py b/bead/config/compose/pipeline.py new file mode 100644 index 0000000..836e12b --- /dev/null +++ b/bead/config/compose/pipeline.py @@ -0,0 +1,129 @@ +"""End-to-end compose pipeline. + +Ties together :mod:`~bead.config.compose.sources`, +:mod:`~bead.config.compose.merge`, and +:mod:`~bead.config.compose.interpolation` into a single +:func:`compose` entry point that takes file paths and overrides and +returns a validated didactic Model. +""" + +from __future__ import annotations + +import copy +from collections.abc import Sequence +from pathlib import Path + +import didactic.api as dx + +# Resolvers are registered as a side effect of importing +# ``bead.config.compose.resolvers``; pull the module in here so that +# any code that imports ``compose`` has the built-in resolvers +# available. The wildcard suppresses pyright's unused-import warning +# while keeping the side-effect explicit. +from bead.config.compose import resolvers as _builtin_resolvers +from bead.config.compose.errors import ConfigError +from bead.config.compose.interpolation import ComposeValue, resolve +from bead.config.compose.merge import ( + apply_override, + parse_override, + strict_merge, +) +from bead.config.compose.sources import load_one, resolve_defaults_entry + +_ = _builtin_resolvers + + +def compose[M: dx.Model]( + config_path: Path | str | None = None, + *, + schema: type[M], + profile_dict: dict[str, ComposeValue] | None = None, + overrides: Sequence[str] = (), + extra: Sequence[Path | str] = (), +) -> M: + """Compose, interpolate, and validate a config of type ``schema``. + + Precedence (lowest to highest): + + 1. ``profile_dict`` — caller-supplied base. Empty when ``None``. + 2. Each path listed in the YAML's ``defaults: [...]`` key, + loaded left-to-right (paths resolved relative to the + primary YAML's parent directory). + 3. The primary YAML body (everything except ``defaults``). + 4. Each ``extra`` overlay file, in order. + 5. ``overrides`` — dotted-key ``key=value`` strings. + + Interpolation is resolved last via + :func:`~bead.config.compose.interpolation.resolve`; the resolved + dict is then validated by ``schema.model_validate(...)``. + + Parameters + ---------- + config_path : Path | str | None, optional + Primary YAML or TOML file. ``None`` skips file loading and + merges only ``profile_dict`` + ``overrides``. + schema : type[M] + Target didactic model. Drives strict-key enforcement and + final validation. + profile_dict : dict[str, ComposeValue] | None, optional + Pre-loaded base, typically a profile dump. Defaults to + ``None``. + overrides : Sequence[str], optional + CLI-style overrides (``["paths.data_dir=/tmp"]``). YAML-parsed + values; later entries beat earlier ones. + extra : Sequence[Path | str], optional + Additional overlay files merged after the primary YAML. + + Returns + ------- + M + Fully composed, interpolated, and validated model. + + Raises + ------ + ConfigError + For malformed configs (unknown keys, bad ``defaults`` entries, + malformed overrides). + InterpolationError + For unresolved ``${...}`` expressions or cycles. + """ + accumulated: dict[str, ComposeValue] = ( + copy.deepcopy(profile_dict) if profile_dict else {} + ) + + if config_path is not None: + primary_path = Path(config_path) + primary = load_one(primary_path) + defaults_list = primary.pop("defaults", None) + if defaults_list is not None: + if not isinstance(defaults_list, list): + raise ConfigError( + f"'defaults' in {primary_path} must be a list of " + f"strings; got {type(defaults_list).__name__}" + ) + anchor = primary_path.parent + for entry in defaults_list: + if not isinstance(entry, str): + raise ConfigError( + f"'defaults' entries must be strings; got " + f"{type(entry).__name__}: {entry!r}" + ) + overlay = load_one(resolve_defaults_entry(entry, anchor=anchor)) + accumulated = strict_merge(accumulated, overlay, schema=schema) + accumulated = strict_merge(accumulated, primary, schema=schema) + + for extra_path in extra: + overlay = load_one(extra_path) + accumulated = strict_merge(accumulated, overlay, schema=schema) + + for raw in overrides: + key, value = parse_override(raw) + apply_override(accumulated, key, value) + + resolved = resolve(accumulated, root=accumulated) + if not isinstance(resolved, dict): + raise ConfigError( + f"Resolved config root is not a mapping (got {type(resolved).__name__})" + ) + + return schema.model_validate(resolved) diff --git a/bead/config/compose/resolvers.py b/bead/config/compose/resolvers.py new file mode 100644 index 0000000..bd636f8 --- /dev/null +++ b/bead/config/compose/resolvers.py @@ -0,0 +1,142 @@ +"""Built-in resolvers registered at subpackage import time. + +The set mirrors OmegaConf's standard resolver library so existing +configurations written for OmegaConf can be loaded with minimal +changes. +""" + +from __future__ import annotations + +import base64 +import os +import warnings +from typing import Final + +from bead.config.compose.errors import InterpolationError +from bead.config.compose.interpolation import register_resolver + + +def _oc_env(*args: str) -> str: + """``${oc.env:VAR}`` / ``${oc.env:VAR,default}``.""" + if not args: + raise InterpolationError("oc.env requires at least one argument") + var = args[0] + if var in os.environ: + return os.environ[var] + if len(args) >= 2: + return ",".join(args[1:]) + raise InterpolationError( + f"Environment variable {var!r} is not set and no default given" + ) + + +def _oc_select(*args: str) -> str: + """``${oc.select:path,default}``. + + Returns ``default`` if ``path`` resolves to a missing or null + value. The path lookup itself happens through the normal + reference machinery; ``oc.select`` exists purely so users can + *opt-in* to silently using a default when a path is absent. + + Because this resolver runs *after* its arguments have been + interpolated, the typical usage ``${oc.select:${a.b},fallback}`` + relies on the inner interpolation raising InterpolationError — + which we catch here and replace with the default. + """ + if not args: + raise InterpolationError("oc.select requires a path and an optional default") + # If the inner ${...} raised, the surrounding evaluator caught it + # and re-raised; we won't reach here. The "select" semantics + # therefore live in the evaluator's _eval_resolver_call by + # intercepting argument evaluation. See note below. + if len(args) == 1: + return args[0] + return args[0] if args[0] != "" else ",".join(args[1:]) + + +def _oc_decode(*args: str) -> str: + """``${oc.decode:value,encoding}`` — decode a base64 string. + + Encodings supported: ``base64``, ``ascii``, ``utf-8`` (passthrough). + Defaults to ``base64`` when only one argument is supplied. + """ + if not args: + raise InterpolationError("oc.decode requires at least a value") + value = args[0] + encoding = args[1] if len(args) >= 2 else "base64" + if encoding == "base64": + return base64.b64decode(value).decode("utf-8") + if encoding in ("ascii", "utf-8"): + return value + raise InterpolationError(f"oc.decode: unknown encoding {encoding!r}") + + +def _oc_deprecated(*args: str) -> str: + """``${oc.deprecated:new_path}`` — emit a deprecation warning. + + Returns the interpolated value of ``new_path``. Used in configs + to alias old keys to new ones. + """ + if not args: + raise InterpolationError("oc.deprecated requires a replacement path") + warnings.warn( + f"config key is deprecated; use {args[0]!r} instead", + DeprecationWarning, + stacklevel=4, + ) + return args[0] + + +def _oc_create(*args: str) -> str: + """``${oc.create:value}`` — passthrough. + + Provided for OmegaConf-compat. In OmegaConf, ``oc.create`` wraps + a structure into a fresh DictConfig/ListConfig; here we operate + on plain dicts so the wrapper is a no-op. + """ + return ",".join(args) + + +def _oc_dict_keys(*args: str) -> str: + """``${oc.dict.keys:path}`` — comma-joined keys of a dict at ``path``. + + Since resolvers must return a single value, the keys are + rendered as a comma-separated string. Use the dotted-path + reference syntax directly (e.g. ``${section.keys}`` if you + pre-compute the list) when you need a typed list. + """ + if not args: + raise InterpolationError("oc.dict.keys requires a path") + raise InterpolationError( + "oc.dict.keys requires its argument to be a dict; got a " + "string. Use ${oc.dict.keys:${path.to.dict}} instead." + ) + + +def _oc_dict_values(*args: str) -> str: + if not args: + raise InterpolationError("oc.dict.values requires a path") + raise InterpolationError( + "oc.dict.values requires its argument to be a dict; got a " + "string. Use ${oc.dict.values:${path.to.dict}} instead." + ) + + +_BUILTINS: Final[dict[str, object]] = { + "oc.env": _oc_env, + "oc.select": _oc_select, + "oc.decode": _oc_decode, + "oc.deprecated": _oc_deprecated, + "oc.create": _oc_create, + "oc.dict.keys": _oc_dict_keys, + "oc.dict.values": _oc_dict_values, +} + + +def register_builtins(*, replace: bool = False) -> None: + """Register every built-in resolver. Called at import time.""" + for name, fn in _BUILTINS.items(): + register_resolver(name, fn, replace=replace) # type: ignore[arg-type] + + +register_builtins(replace=True) diff --git a/bead/config/compose/sources.py b/bead/config/compose/sources.py new file mode 100644 index 0000000..8f9aefd --- /dev/null +++ b/bead/config/compose/sources.py @@ -0,0 +1,113 @@ +"""File-format dispatch for the compose pipeline. + +Loads YAML (``.yaml`` / ``.yml``) and TOML (``.toml``) files into the +dict-of-:data:`~bead.config.compose.interpolation.ComposeValue` shape +the rest of the subpackage operates on. +""" + +from __future__ import annotations + +import tomllib +from pathlib import Path + +import yaml + +from bead.config.compose.errors import ConfigError +from bead.config.compose.interpolation import ComposeValue + +_YAML_SUFFIXES = frozenset({".yaml", ".yml"}) +_TOML_SUFFIXES = frozenset({".toml"}) + + +def load_one(path: Path | str) -> dict[str, ComposeValue]: + """Load a YAML or TOML file as a dict. + + Parameters + ---------- + path : Path | str + Path to a ``.yaml`` / ``.yml`` / ``.toml`` file. + + Returns + ------- + dict[str, ComposeValue] + Loaded content. An empty file yields an empty dict. + + Raises + ------ + FileNotFoundError + If ``path`` does not exist. + ConfigError + If the suffix is unrecognized or the parsed content is not a + top-level mapping. + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Configuration file not found: {path}") + suffix = path.suffix.lower() + + if suffix in _YAML_SUFFIXES: + with path.open(encoding="utf-8") as fp: + loaded = yaml.safe_load(fp) + elif suffix in _TOML_SUFFIXES: + with path.open("rb") as fp_b: + loaded = tomllib.load(fp_b) + else: + raise ConfigError( + f"Unsupported config suffix {suffix!r} for {path}. " + f"Expected one of: {sorted(_YAML_SUFFIXES | _TOML_SUFFIXES)}" + ) + + if loaded is None: + return {} + if not isinstance(loaded, dict): + raise ConfigError( + f"Top-level config in {path} must be a mapping; got {type(loaded).__name__}" + ) + return loaded + + +def resolve_defaults_entry(entry: str, *, anchor: Path) -> Path: + """Resolve a ``defaults: [...]`` entry to a concrete file path. + + Entries name a file relative to ``anchor`` (the parent of the + YAML containing the ``defaults`` list). An entry may include or + omit the suffix: + + - ``protocol/argument_structure`` is resolved to whichever of + ``protocol/argument_structure.yaml``, + ``protocol/argument_structure.yml``, or + ``protocol/argument_structure.toml`` exists. + - ``protocol/argument_structure.yaml`` is taken verbatim. + + Parameters + ---------- + entry : str + Path string from the YAML ``defaults`` list. + anchor : Path + Directory the entry is resolved against. + + Returns + ------- + Path + Existing path on disk. + + Raises + ------ + ConfigError + If no candidate path exists. + """ + raw = (anchor / entry).resolve() + if raw.suffix.lower() in _YAML_SUFFIXES | _TOML_SUFFIXES: + if raw.exists(): + return raw + raise ConfigError(f"defaults entry {entry!r} not found at {raw}") + for suffix in (".yaml", ".yml", ".toml"): + candidate = raw.with_suffix(suffix) + if candidate.exists(): + return candidate + raise ConfigError( + f"defaults entry {entry!r} not found; tried " + f"{raw.with_suffix('.yaml')}, " + f"{raw.with_suffix('.yml')}, " + f"{raw.with_suffix('.toml')}" + ) diff --git a/bead/config/loader.py b/bead/config/loader.py index 638ef55..82a9744 100644 --- a/bead/config/loader.py +++ b/bead/config/loader.py @@ -1,150 +1,89 @@ -"""Configuration loading from YAML files. +"""Bead-specific entrypoint to the compose pipeline. -This module provides functionality for loading configurations from YAML files, -merging configurations from multiple sources, and applying configuration overrides. +A thin wrapper around :func:`bead.config.compose.compose` that binds +the schema to :class:`~bead.config.config.BeadConfig` and starts from +the profile defaults declared in :mod:`bead.config.profiles`. """ -from pathlib import Path -from typing import Any +from __future__ import annotations -import yaml +import json +from collections.abc import Sequence +from pathlib import Path +# Importing this module registers bead-specific resolvers +# (${bead.anchor:...}, ${bead.path:...}) against the compose +# interpolation engine. +from bead.config import resolvers as _bead_resolvers +from bead.config.compose import compose +from bead.config.compose.interpolation import ComposeValue from bead.config.config import BeadConfig from bead.config.profiles import get_profile - -def merge_configs(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: - """Deep merge two configuration dictionaries. - - Recursively merges override into base, with override values taking precedence. - - Parameters - ---------- - base : dict[str, Any] - Base configuration dictionary. - override : dict[str, Any] - Override configuration dictionary. - - Returns - ------- - dict[str, Any] - Merged configuration dictionary. - - Examples - -------- - >>> base = {"a": 1, "b": {"c": 2}} - >>> override = {"b": {"d": 3}} - >>> merge_configs(base, override) - {'a': 1, 'b': {'c': 2, 'd': 3}} - """ - result = base.copy() - for key, value in override.items(): - if key in result and isinstance(result[key], dict) and isinstance(value, dict): - result[key] = merge_configs(result[key], value) # type: ignore[arg-type] - else: - result[key] = value - return result - - -def load_yaml_file(path: Path | str) -> dict[str, Any]: - """Load YAML file and return as dictionary. - - Parameters - ---------- - path : Path | str - Path to YAML file. - - Returns - ------- - dict[str, Any] - Parsed YAML content. - - Raises - ------ - FileNotFoundError - If file doesn't exist. - yaml.YAMLError - If YAML is malformed. - """ - path = Path(path) if isinstance(path, str) else path - - if not path.exists(): - raise FileNotFoundError(f"Configuration file not found: {path}") - - try: - with open(path) as f: - content = yaml.safe_load(f) - # handle empty files - return content if content is not None else {} - except yaml.YAMLError as e: - raise yaml.YAMLError(f"Failed to parse YAML file {path}: {e}") from e +_ = _bead_resolvers def load_config( config_path: Path | str | None = None, + *, profile: str = "default", - **overrides: Any, + overrides: Sequence[str] = (), + extra: Sequence[Path | str] = (), + **kw_overrides: ComposeValue, ) -> BeadConfig: - """Load configuration from YAML file with optional overrides. + """Compose a :class:`BeadConfig` from a profile, file, and overrides. Precedence (lowest to highest): - 1. Profile defaults - 2. YAML file values - 3. Keyword overrides + + 1. Profile defaults (``bead.config.profiles.get_profile``). + 2. Each path listed in the primary YAML's ``defaults: [...]`` + key, in order. + 3. The primary YAML body. + 4. Each ``extra`` overlay file, in order. + 5. ``overrides`` — dotted-key ``key=value`` strings. + 6. ``kw_overrides`` — legacy ``key__sub=value`` keyword form. + Each is rewritten as ``"key.sub=value"`` and merged after + ``overrides``. + + Interpolation is resolved last; the resolved dict is validated as + a :class:`BeadConfig`. Parameters ---------- - config_path : Path | str | None - Path to YAML config file. If None, uses profile defaults. - profile : str - Profile to use as base (default, dev, prod, test). - **overrides : Any - Direct overrides for config values. + config_path : Path | str | None, optional + Primary YAML or TOML file. + profile : str, optional + Profile name (``"default"``, ``"dev"``, ``"prod"``, + ``"test"``). + overrides : Sequence[str], optional + CLI-style overrides (``["paths.data_dir=/tmp"]``). + extra : Sequence[Path | str], optional + Additional overlay files merged after the primary YAML. + **kw_overrides : ComposeValue + Legacy keyword overrides; ``__`` separates nested levels. Returns ------- BeadConfig - Loaded and merged configuration. - - Raises - ------ - FileNotFoundError - If config_path is specified but doesn't exist. - yaml.YAMLError - If YAML file is malformed. - ValidationError - If configuration is invalid. - - Examples - -------- - >>> config = load_config(profile="dev") - >>> config.profile - 'dev' - >>> config = load_config(config_path="config.yaml", logging__level="DEBUG") - >>> config.logging.level - 'DEBUG' + Fully composed and validated configuration. """ - # start with profile defaults (JSON-shape so UUIDs/Paths round-trip) - import json # noqa: PLC0415 - - base_config: dict[str, Any] = json.loads(get_profile(profile).model_dump_json()) - - # merge with YAML file if provided - if config_path is not None: - yaml_config = load_yaml_file(config_path) - base_config = merge_configs(base_config, yaml_config) - - # convert overrides with __ syntax to nested dicts - if overrides: - override_dict: dict[str, Any] = {} - for key, value in overrides.items(): - parts = key.split("__") - current = override_dict - for part in parts[:-1]: - if part not in current: - current[part] = {} - current = current[part] - current[parts[-1]] = value - base_config = merge_configs(base_config, override_dict) - - return BeadConfig.model_validate(base_config) + profile_dict: dict[str, ComposeValue] = json.loads( + get_profile(profile).model_dump_json() + ) + + all_overrides: list[str] = list(overrides) + for key, value in kw_overrides.items(): + dotted = key.replace("__", ".") + # Use yaml.safe_dump to preserve the value's type when it's + # parsed back in parse_override (e.g. int / float / bool). + import yaml # noqa: PLC0415 + + all_overrides.append(f"{dotted}={yaml.safe_dump(value).strip()}") + + return compose( + config_path, + schema=BeadConfig, + profile_dict=profile_dict, + overrides=all_overrides, + extra=extra, + ) diff --git a/bead/config/protocol.py b/bead/config/protocol.py index 94157cb..62b70ef 100644 --- a/bead/config/protocol.py +++ b/bead/config/protocol.py @@ -27,7 +27,12 @@ import didactic.api as dx from bead.data.base import BeadBaseModel -from bead.protocol.anchor import ResponseSpace, SemanticAnchor, SemanticPoles +from bead.protocol.anchor import ( + ResponseSpace, + ScaleType, + SemanticAnchor, + SemanticPoles, +) from bead.protocol.context import get_context_predicate from bead.protocol.drift import ( DriftGuard, @@ -120,6 +125,13 @@ class AnchorSpec(BeadBaseModel): Low-pole label, when ordered. Defaults to ``None``. semantic_pole_high : str | None High-pole label, when ordered. Defaults to ``None``. + scale_type : ScaleType | None + Explicit scale-type override. Set to + :attr:`ScaleType.FORCED_CHOICE` for N-alternative + forced-choice questions whose options are positional labels + (``("first", "second")``) and whose per-item alternatives + vary across items. Defaults to ``None`` (inferred from + ``options`` and ``is_ordered``). required_span_labels : frozenset[str] Span labels every realization must reference. Defaults to the empty set. @@ -142,6 +154,7 @@ class AnchorSpec(BeadBaseModel): is_ordered: bool = True semantic_pole_low: str | None = None semantic_pole_high: str | None = None + scale_type: ScaleType | None = None required_span_labels: frozenset[str] = dx.field(default_factory=frozenset) required_keywords: frozenset[str] = dx.field(default_factory=frozenset) embedding_center: tuple[float, ...] | None = None @@ -178,6 +191,7 @@ def build(self) -> SemanticAnchor: options=self.options, is_ordered=self.is_ordered, semantic_poles=poles, + scale_type=self.scale_type, ) return SemanticAnchor( name=self.name, diff --git a/bead/config/resolvers.py b/bead/config/resolvers.py new file mode 100644 index 0000000..8dc2fbb --- /dev/null +++ b/bead/config/resolvers.py @@ -0,0 +1,115 @@ +"""Bead-aware interpolation resolvers. + +Registered at import time against the +:mod:`bead.config.compose.interpolation` registry. These resolvers +reference bead-side concepts (paths, anchors) and therefore would +*not* be part of an extracted ``didacticonf`` package. + +Resolvers +--------- +- ``${bead.path:rel}`` — join ``rel`` against the value at + ``paths.data_dir`` in the composed config. Convenient shorthand + for ``${paths.data_dir}/rel``. + +The compose pipeline sets the in-flight root via a contextvar so +:func:`bead.config.compose.active_root` returns the dict being +interpolated. + +Post-validation anchor resolution (``${bead.anchor:name[,attr]}``) +needs a validated :class:`AnnotationProtocol` and therefore lives at +post-validation time. See :func:`resolve_anchor_attributes`. +""" + +from __future__ import annotations + +import re +from pathlib import PurePosixPath +from typing import cast + +from bead.config.compose import active_root, register_resolver, resolve +from bead.config.compose.errors import InterpolationError +from bead.config.compose.interpolation import ComposeValue + + +def _bead_path(*args: str) -> str: + """``${bead.path:rel}`` — join ``rel`` against ``paths.data_dir``. + + Reads ``paths.data_dir`` from the in-flight composed root. The + resulting string uses forward slashes (``PurePosixPath``); + callers wrap with :class:`pathlib.Path` as needed. + """ + if not args: + raise InterpolationError("bead.path requires a relative path") + rel = ",".join(args) + + root = active_root() + if root is None: + raise InterpolationError( + "bead.path called outside an active compose() pipeline" + ) + paths_section = root.get("paths") + if not isinstance(paths_section, dict): + raise InterpolationError("bead.path requires a 'paths' section in the config") + data_dir: ComposeValue = paths_section.get("data_dir") + if data_dir is None: + raise InterpolationError("bead.path requires paths.data_dir to be set") + if not isinstance(data_dir, str): + resolved = resolve(data_dir, root=root) + if not isinstance(resolved, str): + raise InterpolationError( + f"paths.data_dir resolved to {type(resolved).__name__}, expected str" + ) + data_dir = resolved + + return str(PurePosixPath(data_dir) / rel) + + +register_resolver("bead.path", _bead_path, replace=True) + + +# --------------------------------------------------------------------------- +# Post-validation anchor resolution +# --------------------------------------------------------------------------- + + +_ANCHOR_PATTERN: re.Pattern[str] = re.compile(r"\$\{bead\.anchor:([^}]+)\}") + + +def resolve_anchor_attributes( + text: str, + *, + protocol: object, +) -> str: + """Replace ``${bead.anchor:name[,attr]}`` references in ``text``. + + Used by application code after the protocol is materialized. + ``attr`` defaults to ``"canonical_prompt"`` and may be any + attribute name on :class:`~bead.protocol.SemanticAnchor`. + + Parameters + ---------- + text : str + Text containing ``${bead.anchor:name}`` or + ``${bead.anchor:name,attr}`` expressions. + protocol : AnnotationProtocol + Validated protocol whose ``family_by_name(name).anchor`` is + consulted. + + Returns + ------- + str + ``text`` with every recognized expression substituted. + """ + + def _replace(match: re.Match[str]) -> str: + spec = match.group(1) + if "," in spec: + name, _, attr = spec.partition(",") + name, attr = name.strip(), attr.strip() + else: + name, attr = spec.strip(), "canonical_prompt" + family = cast("object", protocol.family_by_name(name)) # type: ignore[attr-defined] + anchor = cast("object", family.anchor) # type: ignore[attr-defined] + return str(getattr(anchor, attr)) + + return _ANCHOR_PATTERN.sub(_replace, text) diff --git a/bead/protocol/anchor.py b/bead/protocol/anchor.py index 7b0ba95..492bd0b 100644 --- a/bead/protocol/anchor.py +++ b/bead/protocol/anchor.py @@ -15,6 +15,7 @@ from __future__ import annotations +from enum import StrEnum from typing import Self import didactic.api as dx @@ -22,6 +23,35 @@ from bead.data.base import BeadBaseModel +class ScaleType(StrEnum): + """Classification of response-scale structure. + + Attributes + ---------- + BINARY : str + Two unordered options with fixed, content-bearing labels (for + example ``("no", "yes")``). Modeled via Bernoulli likelihoods. + Wire value: ``"binary"``. + ORDINAL : str + Ordered options forming an ordinal scale. Wire value: + ``"ordinal"``. + NOMINAL : str + Unordered multi-option scale. Wire value: ``"nominal"``. + FORCED_CHOICE : str + N-alternative forced choice with *positional* response labels + (for example ``("first", "second")``). The per-item content + of the alternatives varies between items and lives on the + :class:`~bead.items.item.Item` itself; the encoding's labels + identify which alternative was chosen. Wire value: + ``"forced_choice"``. + """ + + BINARY = "binary" + ORDINAL = "ordinal" + NOMINAL = "nominal" + FORCED_CHOICE = "forced_choice" + + class SemanticPoles(BeadBaseModel): """Pole labels for an ordered response scale. @@ -72,6 +102,14 @@ class ResponseSpace(BeadBaseModel): Pole labels for ordered scales (for example ``low="never"``, ``high="always"``). ``None`` for unordered (categorical) response spaces. Defaults to ``None``. + scale_type : ScaleType | None + Explicit scale-type classification. ``None`` (default) leaves + :func:`~bead.protocol.encode_response_space` to infer the kind + from ``options`` and ``is_ordered``. Set explicitly to + :attr:`ScaleType.FORCED_CHOICE` when the labels are positional + and the per-item alternatives vary across items (the + 2-option-unordered shape that the inference rule would + otherwise classify as ``BINARY``). Examples -------- @@ -92,6 +130,7 @@ class ResponseSpace(BeadBaseModel): options: tuple[str, ...] is_ordered: bool = True semantic_poles: dx.Embed[SemanticPoles] | None = None + scale_type: ScaleType | None = None def __len__(self) -> int: """Return the number of response options.""" diff --git a/bead/protocol/encoding.py b/bead/protocol/encoding.py index 8318ebe..c5e2e13 100644 --- a/bead/protocol/encoding.py +++ b/bead/protocol/encoding.py @@ -26,32 +26,18 @@ class based on the scale type. from __future__ import annotations -from enum import StrEnum from typing import Self import didactic.api as dx from bead.data.base import BeadBaseModel -from bead.protocol.anchor import ResponseSpace, SemanticPoles +from bead.protocol.anchor import ResponseSpace, ScaleType, SemanticPoles - -class ScaleType(StrEnum): - """Classification of response-scale structure. - - Attributes - ---------- - BINARY : str - Two unordered options. Wire value: ``"binary"``. - ORDINAL : str - Ordered options forming an ordinal scale. Wire value: - ``"ordinal"``. - NOMINAL : str - Unordered multi-option scale. Wire value: ``"nominal"``. - """ - - BINARY = "binary" - ORDINAL = "ordinal" - NOMINAL = "nominal" +__all__ = [ + "ResponseEncoding", + "ScaleType", + "encode_response_space", +] class ResponseEncoding(BeadBaseModel): @@ -126,6 +112,11 @@ def _check_levels_match_labels(self) -> Self: f"BINARY scale must have exactly 2 levels, got " f"{self.n_levels} in encoding {self.name!r}" ) + if self.scale_type == ScaleType.FORCED_CHOICE and self.n_levels < 2: + raise ValueError( + f"FORCED_CHOICE scale must have at least 2 levels, " + f"got {self.n_levels} in encoding {self.name!r}" + ) return self @property @@ -143,6 +134,11 @@ def is_nominal(self) -> bool: """Whether the response scale is unordered multi-option.""" return self.scale_type == ScaleType.NOMINAL + @property + def is_forced_choice(self) -> bool: + """Whether the response scale uses positional forced-choice labels.""" + return self.scale_type == ScaleType.FORCED_CHOICE + def label_to_index(self, label: str) -> int: """Convert a response label to its integer index. @@ -221,12 +217,15 @@ def _classify_scale(response_space: ResponseSpace) -> ScaleType: def encode_response_space( name: str, response_space: ResponseSpace, + *, + scale_type: ScaleType | None = None, ) -> ResponseEncoding: """Build a :class:`ResponseEncoding` from a :class:`ResponseSpace`. This is the primary bridge from the protocol layer to the modeling layer. The resulting encoding shares its labels with the response - space and inherits the space's ordering as a :class:`ScaleType`. + space and inherits the space's ordering as a :class:`ScaleType`, + unless ``scale_type`` is set to override the inferred kind. Parameters ---------- @@ -235,6 +234,11 @@ def encode_response_space( ``"completion"``). response_space : ResponseSpace The response space to encode. + scale_type : ScaleType | None, optional + Override the kind inferred from the response space. Required + when declaring a forced-choice encoding, since forced-choice + and binary share the "two unordered options" shape but are + modeled differently. Returns ------- @@ -251,12 +255,26 @@ def encode_response_space( >>> enc.is_binary True + + >>> rs = ResponseSpace( + ... options=("first", "second"), is_ordered=False + ... ) + >>> enc = encode_response_space( + ... "acceptability", rs, scale_type=ScaleType.FORCED_CHOICE + ... ) + >>> enc.is_forced_choice + True """ - scale_type = _classify_scale(response_space) + if scale_type is not None: + resolved = scale_type + elif response_space.scale_type is not None: + resolved = response_space.scale_type + else: + resolved = _classify_scale(response_space) return ResponseEncoding( name=name, n_levels=len(response_space.options), - scale_type=scale_type, + scale_type=resolved, labels=response_space.options, semantic_poles=response_space.semantic_poles, ) diff --git a/bead/protocol/items.py b/bead/protocol/items.py index 786e1ae..51def8f 100644 --- a/bead/protocol/items.py +++ b/bead/protocol/items.py @@ -59,6 +59,7 @@ ScaleType.BINARY: "binary", ScaleType.ORDINAL: "ordinal_scale", ScaleType.NOMINAL: "categorical", + ScaleType.FORCED_CHOICE: "forced_choice", } """The canonical :class:`ScaleType` → :class:`TaskType` mapping.""" @@ -103,7 +104,11 @@ def family_to_item_template( populate :attr:`TaskSpec.scale_bounds` (``0`` to ``n_levels - 1``) and :attr:`TaskSpec.scale_labels` (one :class:`ScalePointLabel` per option). Binary and nominal scales - populate :attr:`TaskSpec.options`. + populate :attr:`TaskSpec.options` with the anchor's labels. + Forced-choice scales leave :attr:`TaskSpec.options` unset (the + per-item alternatives live on each :class:`Item` rather than on + the template); the anchor's labels remain accessible via + ``family.anchor.response_space.options``. The ``prompt`` field of the template's :class:`TaskSpec` is the anchor's canonical prompt (with ``[[label]]`` references intact); @@ -140,6 +145,12 @@ def family_to_item_template( for i, label in enumerate(encoding.labels) ) options: tuple[str, ...] | None = None + elif encoding.is_forced_choice: + # forced-choice options live on each Item (the pair-specific + # text); the template carries no per-template options. + scale_bounds = None + scale_labels = () + options = None else: scale_bounds = None scale_labels = () diff --git a/docs/user-guide/api/training.md b/docs/user-guide/api/training.md index 6a4c6a3..5e4c69c 100644 --- a/docs/user-guide/api/training.md +++ b/docs/user-guide/api/training.md @@ -153,14 +153,9 @@ model = ForcedChoiceModel(config=config) # Prepare training data labels = [0, 1, 0, 1, 0] # Human judgments (0 or 1 for 2AFC) -participant_ids = ["p1", "p1", "p2", "p2", "p1"] # Participant IDs -# Train model -model.train( - items=training_items, - labels=labels, - participant_ids=participant_ids, -) +# Train model (fixed-effects mode does not use participant_ids) +model.train(items=training_items, labels=labels) # After training, predict on new items print(f"Model trained on {len(training_items)} items") diff --git a/docs/user-guide/configuration.md b/docs/user-guide/configuration.md index 78c278d..b53f27e 100644 --- a/docs/user-guide/configuration.md +++ b/docs/user-guide/configuration.md @@ -81,18 +81,44 @@ from bead.config import load_from_env config = load_from_env(config) ``` -Merge multiple configurations: +Compose multiple configurations and apply CLI-style overrides: ```python -from bead.config import merge_configs +from bead.config import load_config + +# extra files overlay after the primary YAML +# overrides are dotted-key=value strings (YAML-parsed for typing) +config = load_config( + "config.yaml", + extra=["overlays/local.yaml"], + overrides=["paths.data_dir=/tmp/data", "logging.level=DEBUG"], +) +``` -base = load_config("base.yaml") -overrides = load_config("overrides.yaml") +Configs may also reference each other through a top-level +`defaults:` list (paths resolve next to the primary YAML; bare +names resolve to `.yaml` or `.toml`): -# Later configs override earlier ones -merged = merge_configs([base, overrides]) +```yaml +defaults: + - protocol/argument_structure # protocol/argument_structure.yaml + - logging/verbose +paths: + data_dir: "${oc.env:BEAD_DATA,/tmp/data}" + out_dir: "${paths.data_dir}/out" ``` +Interpolation follows the OmegaConf grammar: +`${section.field}` absolute references, `${.x}` / `${..y}` +relative references, `${a.b[0]}` and `${a.b.0}` list indexing, +`${a.${b}}` nested expressions, `\${literal}` escape, and the +built-in resolvers (`oc.env`, `oc.select`, `oc.decode`, +`oc.deprecated`, `oc.create`, `oc.dict.keys`, `oc.dict.values`). +Register custom resolvers with +`bead.config.compose.register_resolver(name, fn)`. + +TOML configs (`.toml`) load the same way as YAML. + ## Configuration Profiles Use predefined profiles for different environments: diff --git a/docs/user-guide/protocols.md b/docs/user-guide/protocols.md index fa4e17e..156174c 100644 --- a/docs/user-guide/protocols.md +++ b/docs/user-guide/protocols.md @@ -405,6 +405,34 @@ for realization, item in pairs: in the active-learning registry. There is no other mapping: every protocol family produces exactly one `ItemTemplate`. +### Forced-choice anchors + +`ScaleType.FORCED_CHOICE` covers N-alternative forced-choice +questions where the *response space* is a fixed positional label +set (e.g. `("first", "second")`) but the per-item alternatives +vary across items. Declare the scale type explicitly on the +anchor — it cannot be inferred from the response space alone: + +```python +anchor = SemanticAnchor( + name="acceptability", + canonical_prompt="Which sentence sounds more natural?", + response_space=ResponseSpace( + options=("first", "second"), + is_ordered=False, + scale_type=ScaleType.FORCED_CHOICE, + ), + ..., +) +``` + +`family_to_item_template` maps `FORCED_CHOICE` to +`task_type="forced_choice"` with `task_spec.options=None` +(per-item alternatives live on each `Item`); the active-learning +registry routes the resulting encoding to `ForcedChoiceModel`. + +In YAML, set `scale_type: "forced_choice"` on the `AnchorSpec`. + ## Bridging to deployment `bead.deployment.protocol_trials.protocol_to_jspsych_trials` is the diff --git a/gallery/eng/argument_structure/Makefile b/gallery/eng/argument_structure/Makefile index 4e3cd2d..9fdcad0 100644 --- a/gallery/eng/argument_structure/Makefile +++ b/gallery/eng/argument_structure/Makefile @@ -89,9 +89,22 @@ all: data pipeline-dry-run ## Run complete pipeline with test data @echo "$(GREEN)✓ Complete pipeline executed successfully$(NC)" .PHONY: data -data: lexicons templates fill-templates cross-product 2afc-pairs ## Generate all data files +data: validate-protocol lexicons templates fill-templates cross-product 2afc-pairs ## Generate all data files @echo "$(GREEN)✓ All data files generated$(NC)" +.PHONY: validate-protocol +validate-protocol: ## Validate the protocol section of config.yaml builds cleanly + @echo "$(BLUE)Validating protocol declaration...$(NC)" + @$(PYTHON) -c "from protocol import build_protocol; \ + p = build_protocol('config.yaml'); \ + assert p.families, 'No families declared in protocol'; \ + anchor = p.families[0].anchor; \ + print(' protocol:', p.name); \ + print(' family:', anchor.name); \ + print(' prompt:', anchor.canonical_prompt); \ + print(' scale_type:', anchor.response_space.scale_type)" + @echo "$(GREEN)✓ Protocol validated$(NC)" + # ============================================================================ # Data Generation Targets # ============================================================================ diff --git a/gallery/eng/argument_structure/README.md b/gallery/eng/argument_structure/README.md index 0abc2a8..3bf3c83 100644 --- a/gallery/eng/argument_structure/README.md +++ b/gallery/eng/argument_structure/README.md @@ -1,9 +1,20 @@ # Argument Structure Active Learning Pipeline -**Last Updated:** February 2026 +**Last Updated:** May 2026 A framework for collecting human judgments on argument structure alternations using active learning with convergence detection to human-level inter-annotator agreement. +The 2AFC acceptability question this gallery measures is declared once +in `config.yaml` under `protocol:` as a `SemanticAnchor` with +`scale_type: forced_choice`, materialized by +`protocol.py:build_protocol()`, and threaded through every downstream +stage — `create_2afc_pairs.py` writes the anchor name into every +pair's `item_metadata`, `generate_deployment.py` builds its +`ItemTemplate` via the canonical `family_to_item_template` bridge, and +`simulate_pipeline.py` reads response-space labels off the same +anchor. Run `make validate-protocol` to verify the protocol section +builds cleanly before any data-generation step. + ## Overview This project implements a human-in-the-loop active learning pipeline for studying **argument structure alternations** in English. The pipeline: @@ -98,15 +109,15 @@ Programmatic control using Python scripts. Best for batch operations, complex lo **Quick Example**: ```python from bead.resources.adapters.glazing import GlazingAdapter -from bead.templates.filler import TemplateFiller +from bead.templates.filler import CSPFiller # Stage 1: Import lexicons adapter = GlazingAdapter(resource="verbnet") items = adapter.fetch_items(query="break", language_code="en") -# Stage 2: Fill templates -filler = TemplateFiller(templates, lexicons) -filled = filler.fill(strategy="exhaustive") +# Stage 2: Fill templates (CSPFiller is the canonical concrete filler) +filler = CSPFiller(lexicon) +filled = list(filler.fill(template)) # ... (6 stages total) ``` @@ -127,87 +138,87 @@ The pipeline consists of 10 main scripts organized into 4 stages: ``` ┌─────────────────────────────────────────────────────────────────┐ -│ Stage 1: Resource Generation │ +│ Stage 1: Resource Generation │ ├─────────────────────────────────────────────────────────────────┤ -│ │ +│ │ │ 1. generate_lexicons.py │ -│ ├─ VerbNet verbs (via GlazingAdapter) │ -│ ├─ Morphological forms (via UniMorphAdapter) │ -│ └─ Controlled lexicons (from resources/ CSVs) │ -│ → Output: lexicons/*.jsonl (19,160+ entries) │ -│ │ +│ ├─ VerbNet verbs (via GlazingAdapter) │ +│ ├─ Morphological forms (via UniMorphAdapter) │ +│ └─ Controlled lexicons (from resources/ CSVs) │ +│ → Output: lexicons/*.jsonl (19,160+ entries) │ +│ │ └─────────────────────────────────────────────────────────────────┘ ┌─────────────────────────────────────────────────────────────────┐ -│ Stage 2: Template Generation & Filling │ +│ Stage 2: Template Generation & Filling │ ├─────────────────────────────────────────────────────────────────┤ -│ │ +│ │ │ 2. generate_templates.py │ -│ ├─ Extract all verb-specific VerbNet frames │ -│ ├─ Map to MegaAttitude clausal structures │ -│ └─ Generate DSL constraints │ -│ → Output: templates/verbnet_frames.jsonl (21,453 templates)│ -│ │ -│ 3. extract_generic_templates.py │ -│ └─ Abstract verb-specific → generic structures │ -│ → Output: templates/generic_frames.jsonl (26 templates) │ -│ │ -│ 4. fill_templates.py │ -│ ├─ Fill templates using MixedFillingStrategy │ -│ ├─ Phase 1: Exhaustive filling (det, be, verb slots) │ -│ └─ Phase 2: MLM-based filling (noun, prep, adj slots) │ -│ → Output: filled_templates/generic_frames_filled.jsonl │ -│ │ +│ ├─ Extract all verb-specific VerbNet frames │ +│ ├─ Map to MegaAttitude clausal structures │ +│ └─ Generate DSL constraints │ +│ → Output: templates/verbnet_frames.jsonl (21,453 templates) │ +│ │ +│ 3. extract_generic_templates.py │ +│ └─ Abstract verb-specific → generic structures │ +│ → Output: templates/generic_frames.jsonl (26 templates) │ +│ │ +│ 4. fill_templates.py │ +│ ├─ Fill templates using MixedFillingStrategy │ +│ ├─ Phase 1: Exhaustive filling (det, be, verb slots) │ +│ └─ Phase 2: MLM-based filling (noun, prep, adj slots) │ +│ → Output: filled_templates/generic_frames_filled.jsonl │ +│ │ └─────────────────────────────────────────────────────────────────┘ ┌─────────────────────────────────────────────────────────────────┐ -│ Stage 3: Item Generation & List Partitioning │ +│ Stage 3: Item Generation & List Partitioning │ ├─────────────────────────────────────────────────────────────────┤ -│ │ -│ 5. generate_cross_product.py │ -│ └─ Cross all verbs × all generic frames │ -│ → Output: items/cross_product_items.jsonl (74,880 items) │ -│ │ -│ 6. create_2afc_pairs.py │ -│ ├─ Load filled templates from previous step │ -│ ├─ Score with language model (GPT-2) │ -│ ├─ Create minimal pairs (same_verb, different_verb) │ -│ └─ Stratify by LM score quantiles │ -│ → Output: items/2afc_pairs.jsonl │ -│ │ -│ 7. generate_lists.py │ -│ ├─ Partition 2AFC pairs into balanced lists │ -│ ├─ Apply list constraints (balance, uniqueness, etc.) │ -│ └─ Apply batch constraints (coverage, diversity) │ -│ → Output: lists/experiment_lists.jsonl │ -│ │ +│ │ +│ 5. generate_cross_product.py │ +│ └─ Cross all verbs × all generic frames │ +│ → Output: items/cross_product_items.jsonl (74,880 items) │ +│ │ +│ 6. create_2afc_pairs.py │ +│ ├─ Load filled templates from previous step │ +│ ├─ Score with language model (GPT-2) │ +│ ├─ Create minimal pairs (same_verb, different_verb) │ +│ └─ Stratify by LM score quantiles │ +│ → Output: items/2afc_pairs.jsonl │ +│ │ +│ 7. generate_lists.py │ +│ ├─ Partition 2AFC pairs into balanced lists │ +│ ├─ Apply list constraints (balance, uniqueness, etc.) │ +│ └─ Apply batch constraints (coverage, diversity) │ +│ → Output: lists/experiment_lists.jsonl │ +│ │ └─────────────────────────────────────────────────────────────────┘ ┌─────────────────────────────────────────────────────────────────┐ -│ Stage 4: Deployment & Active Learning │ +│ Stage 4: Deployment & Active Learning │ ├─────────────────────────────────────────────────────────────────┤ -│ │ -│ 8. generate_deployment.py │ -│ ├─ Generate jsPsych experiments (local + JATOS versions) │ -│ ├─ Local: Standalone for testing (no server required) │ -│ ├─ JATOS: Production deployment with Prolific support │ -│ └─ Create JATOS .jzip packages │ -│ → Output: deployment/local/* + deployment/jatos/* │ -│ │ -│ 9. simulate_pipeline.py (testing/validation) │ -│ ├─ Simulate human judgments (LM-based annotator) │ -│ ├─ Test active learning loop │ -│ └─ Validate convergence detection │ -│ → Output: simulation_output/simulation_results.json │ -│ │ -│ 10. run_pipeline.py (production) │ -│ ├─ Load configuration (config.yaml) │ -│ ├─ Initialize convergence detector │ -│ ├─ Run active learning loop │ -│ ├─ Monitor human-model agreement │ -│ └─ Stop when converged to human IAA │ -│ → Output: results/pipeline_results.json │ -│ │ +│ │ +│ 8. generate_deployment.py │ +│ ├─ Generate jsPsych experiments (local + JATOS versions) │ +│ ├─ Local: Standalone for testing (no server required) │ +│ ├─ JATOS: Production deployment with Prolific support │ +│ └─ Create JATOS .jzip packages │ +│ → Output: deployment/local/* + deployment/jatos/* │ +│ │ +│ 9. simulate_pipeline.py (testing/validation) │ +│ ├─ Simulate human judgments (LM-based annotator) │ +│ ├─ Test active learning loop │ +│ └─ Validate convergence detection │ +│ → Output: simulation_output/simulation_results.json │ +│ │ +│ 10. run_pipeline.py (production) │ +│ ├─ Load configuration (config.yaml) │ +│ ├─ Initialize convergence detector │ +│ ├─ Run active learning loop │ +│ ├─ Monitor human-model agreement │ +│ └─ Stop when converged to human IAA │ +│ → Output: results/pipeline_results.json │ +│ │ └─────────────────────────────────────────────────────────────────┘ ``` @@ -234,11 +245,11 @@ Provides `OtherNounRenderer` class for handling repeated noun slots: ```python from utils.renderers import OtherNounRenderer -from bead.templates.filler import TemplateFiller +from bead.templates.filler import CSPFiller renderer = OtherNounRenderer() -filler = TemplateFiller(templates, lexicons, renderer=renderer) -filled = filler.fill(strategy="exhaustive") +filler = CSPFiller(lexicon, renderer=renderer) +filled = list(filler.fill(template)) ``` **Rendering rules**: @@ -361,6 +372,7 @@ project: # Project metadata paths: # Directory structure resources: # Lexicon and template paths template: # Template filling strategies +protocol: # Annotation protocol declaration (anchor + drift) items: # Item construction lists: # List partitioning deployment: # jsPsych/JATOS settings @@ -368,6 +380,34 @@ active_learning: # Sampling strategy training: # Convergence detection ``` +**`protocol`** - Annotation Protocol Declaration + +The 2AFC acceptability question is declared once here as a +`SemanticAnchor` with `scale_type: forced_choice`. Every downstream +stage materializes the live `AnnotationProtocol` via +`protocol.py:build_protocol()` rather than hard-coding prompt +strings or response labels. + +```yaml +protocol: + name: "argument-structure-acceptability" + drift: + min_length: 10 + require_question_mark: true + keyword_case_sensitive: false + families: + - anchor: + name: "acceptability" + target_property: "acceptability" + canonical_prompt: "Which sentence sounds more natural?" + options: ["first", "second"] + is_ordered: false + scale_type: "forced_choice" + required_keywords: ["natural"] + description: "2AFC acceptability judgment over a minimal pair." + realization_kind: "template" +``` + #### Section Details **`project`** - Project Metadata @@ -517,10 +557,10 @@ from utils.verbnet_parser import VerbNetExtractor # fill_templates.py from utils.renderers import OtherNounRenderer -from bead.templates.filler import TemplateFiller +from bead.templates.filler import CSPFiller renderer = OtherNounRenderer() -filler = TemplateFiller(templates, lexicons, renderer=renderer) +filler = CSPFiller(lexicon, renderer=renderer) ``` This keeps the configuration file language-agnostic while allowing language-specific extensions through the plugin system. @@ -1286,6 +1326,7 @@ make clean # Remove generated files ### Data Generation ```bash +make validate-protocol # Build the AnnotationProtocol from config.yaml make lexicons # Generate lexicon files make verbnet-templates # Generate verb-specific VerbNet templates make templates # Extract generic frame structures @@ -1546,12 +1587,13 @@ gallery/eng/argument_structure/ ├── Makefile # Build automation (500+ lines, 40+ targets) ├── config.yaml # Pipeline configuration │ +├── protocol.py # [0] Materialize AnnotationProtocol from config.yaml ├── generate_lexicons.py # [1] Extract VerbNet verbs + bleached lexicons ├── generate_templates.py # [2] Generate verb-specific VerbNet templates ├── extract_generic_templates.py # [3] Extract 26 generic frame structures ├── fill_templates.py # [4] Fill templates with MLM strategy (optional) ├── generate_cross_product.py # [5] Generate verb × frame cross-product -├── create_2afc_pairs.py # [6] Create 2AFC pairs with LM scoring +├── create_2afc_pairs.py # [6] Create 2AFC pairs with LM scoring (anchor-tagged) ├── generate_lists.py # [7] Partition pairs into experiment lists ├── generate_deployment.py # [8] Generate jsPsych/JATOS deployment ├── simulate_pipeline.py # [9] Simulate active learning (testing) @@ -1567,6 +1609,7 @@ gallery/eng/argument_structure/ │ ├── tests/ # Test suite │ ├── __init__.py +│ ├── test_protocol.py # AnnotationProtocol round-trip + bridge tests │ └── test_simulation.py # Simulation tests │ ├── resources/ # Reference data diff --git a/gallery/eng/argument_structure/config.yaml b/gallery/eng/argument_structure/config.yaml index 6d3bcfc..c7edfcf 100644 --- a/gallery/eng/argument_structure/config.yaml +++ b/gallery/eng/argument_structure/config.yaml @@ -105,6 +105,25 @@ template: # adjectives: mlm (unconstrained) adjective: { strategy: "mlm", beam_size: 3, max_fills: 5, enforce_unique: true } +# protocol: canonical declaration of the question being asked +protocol: + name: "argument-structure-acceptability" + drift: + min_length: 10 + require_question_mark: true + keyword_case_sensitive: false + families: + - anchor: + name: "acceptability" + target_property: "acceptability" + canonical_prompt: "Which sentence sounds more natural?" + options: ["first", "second"] + is_ordered: false + scale_type: "forced_choice" + required_keywords: ["natural"] + description: "2AFC acceptability judgment over a minimal pair." + realization_kind: "template" + # items items: judgment_type: "forced_choice" diff --git a/gallery/eng/argument_structure/create_2afc_pairs.py b/gallery/eng/argument_structure/create_2afc_pairs.py index cdef497..fdbb71d 100755 --- a/gallery/eng/argument_structure/create_2afc_pairs.py +++ b/gallery/eng/argument_structure/create_2afc_pairs.py @@ -31,6 +31,8 @@ print_warning, ) from bead.items.forced_choice import create_forced_choice_items_from_groups + +from protocol import ACCEPTABILITY_ANCHOR_NAME from bead.items.item import Item from bead.items.scoring import LanguageModelScorer from bead.lists.stratification import assign_quantiles_by_uuid @@ -228,7 +230,13 @@ def extract_text(item: Item) -> str: print_success(f"Created {len(different_verb_items):,} different-verb pairs") - return same_verb_items + different_verb_items + # Thread the protocol anchor name onto every pair so downstream + # JATOS-result → AnnotationRecord conversion can match responses + # back to the canonical 2AFC acceptability anchor. + all_pairs = same_verb_items + different_verb_items + for fc_item in all_pairs: + fc_item.item_metadata["anchor"] = ACCEPTABILITY_ANCHOR_NAME + return all_pairs def assign_quantiles_to_pairs( diff --git a/gallery/eng/argument_structure/generate_deployment.py b/gallery/eng/argument_structure/generate_deployment.py index 534f033..cd2c9d1 100644 --- a/gallery/eng/argument_structure/generate_deployment.py +++ b/gallery/eng/argument_structure/generate_deployment.py @@ -35,12 +35,11 @@ from bead.deployment.jspsych.config import ChoiceConfig, ExperimentConfig from bead.deployment.jspsych.generator import JsPsychExperimentGenerator from bead.items.item import Item -from bead.items.item_template import ( - ItemTemplate, - PresentationSpec, - TaskSpec, -) +from bead.items.item_template import ItemTemplate from bead.lists import ExperimentList, ListCollection +from bead.protocol.items import family_to_item_template + +from protocol import acceptability_family, build_protocol def load_config(config_path: Path) -> dict: @@ -72,23 +71,15 @@ def load_items_by_uuid(pairs_path: Path) -> dict[UUID, Item]: return items_dict -def create_minimal_item_template() -> ItemTemplate: - """Create a minimal ItemTemplate for 2AFC forced choice items. +def create_minimal_item_template(config_path: Path) -> ItemTemplate: + """Build the 2AFC ItemTemplate from the protocol declared in config.yaml. - Since our 2AFC items are already fully rendered, we just need a minimal - template to satisfy the deployment generator's requirements. + The template's prompt, response options, and judgment type are pulled + from the ``protocol.families[].anchor`` block via the canonical + :func:`bead.protocol.items.family_to_item_template` bridge. """ - return ItemTemplate( - name="2afc_forced_choice", - description="Two-alternative forced choice item", - judgment_type="acceptability", - task_type="forced_choice", - task_spec=TaskSpec( - prompt="Which sentence sounds more natural?", - options=["Option A", "Option B"], - ), - presentation_spec=PresentationSpec(mode="static"), - ) + family = acceptability_family(build_protocol(config_path)) + return family_to_item_template(family, judgment_type="acceptability") def main() -> None: @@ -159,7 +150,7 @@ def main() -> None: # Create minimal template console.print() print_header("[4/6] Creating Item Template") - template = create_minimal_item_template() + template = create_minimal_item_template(config_path) templates_dict = {template.id: template} print_success("Created minimal ItemTemplate for 2AFC items\n") diff --git a/gallery/eng/argument_structure/protocol.py b/gallery/eng/argument_structure/protocol.py new file mode 100644 index 0000000..b9171a2 --- /dev/null +++ b/gallery/eng/argument_structure/protocol.py @@ -0,0 +1,45 @@ +"""Protocol declarations for the argument-structure gallery. + +The 2AFC acceptability question is declared once in ``config.yaml`` +under the ``protocol:`` section and materialized here as a live +:class:`~bead.protocol.AnnotationProtocol`. Every downstream script +(``create_2afc_pairs.py``, ``generate_deployment.py``, +``run_pipeline.py``, ``simulate_pipeline.py``) imports +:func:`build_protocol` so the prompt string, response options, and +drift thresholds have a single source of truth. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Final + +import yaml + +from bead.config.protocol import ProtocolConfig +from bead.protocol import AnnotationProtocol, QuestionFamily, SemanticAnchor + +ACCEPTABILITY_ANCHOR_NAME: Final = "acceptability" + + +def load_protocol_config(config_path: Path | str = "config.yaml") -> ProtocolConfig: + """Parse the ``protocol:`` section of ``config_path`` into a ProtocolConfig.""" + data = yaml.safe_load(Path(config_path).read_text(encoding="utf-8")) + return ProtocolConfig.model_validate(data["protocol"]) + + +def build_protocol(config_path: Path | str = "config.yaml") -> AnnotationProtocol: + """Build the live :class:`AnnotationProtocol` from ``config_path``.""" + return load_protocol_config(config_path).build() + + +def acceptability_family( + protocol: AnnotationProtocol, +) -> QuestionFamily: + """Return the acceptability family from ``protocol``.""" + return protocol.family_by_name(ACCEPTABILITY_ANCHOR_NAME) + + +def acceptability_anchor(protocol: AnnotationProtocol) -> SemanticAnchor: + """Return the acceptability anchor from ``protocol``.""" + return acceptability_family(protocol).anchor diff --git a/gallery/eng/argument_structure/simulate_pipeline.py b/gallery/eng/argument_structure/simulate_pipeline.py index f21a748..a2c286f 100644 --- a/gallery/eng/argument_structure/simulate_pipeline.py +++ b/gallery/eng/argument_structure/simulate_pipeline.py @@ -45,9 +45,12 @@ from bead.evaluation.convergence import ConvergenceDetector from bead.evaluation.interannotator import InterAnnotatorMetrics from bead.items.item import Item -from bead.items.item_template import ItemTemplate, PresentationSpec, TaskSpec +from bead.items.item_template import ItemTemplate +from bead.protocol.items import family_to_item_template from bead.simulation.annotators.base import SimulatedAnnotator +from protocol import acceptability_family, build_protocol + def load_2afc_pairs(path: Path, limit: int | None = None, skip: int = 0) -> list[Item]: """Load 2AFC pairs from JSONL. @@ -66,35 +69,37 @@ def load_2afc_pairs(path: Path, limit: int | None = None, skip: int = 0) -> list list[Item] List of items """ - items = [] + items: list[Item] = [] with open(path) as f: for i, line in enumerate(f): if i < skip: continue if limit and (i - skip) >= limit: break - data = json.loads(line) - items.append(Item(**data)) + items.append(Item.model_validate_json(line)) return items -def get_forced_choice_template() -> ItemTemplate: - """Create ItemTemplate for 2AFC forced choice task. +_CONFIG_PATH = Path(__file__).resolve().parent / "config.yaml" - Returns - ------- - ItemTemplate - Template configured for forced_choice task using proper TaskSpec + +def get_forced_choice_template() -> ItemTemplate: + """Build the 2AFC ItemTemplate from the configured protocol. + + The prompt and task-type come from the + ``protocol.families[].anchor`` declaration in ``config.yaml`` + via :func:`bead.protocol.items.family_to_item_template`. The + canonical bridge leaves ``task_spec.options`` unset because the + per-item alternatives (the two sentences) live on each + :class:`Item`; the simulator however samples from response-space + labels (``"first"`` / ``"second"``), so we splice those onto + ``task_spec.options`` here. """ - return ItemTemplate( - name="2AFC Forced Choice", - judgment_type="preference", - task_type="forced_choice", - task_spec=TaskSpec( - prompt="Which sentence sounds more natural?", - options=["option_a", "option_b"], - ), - presentation_spec=PresentationSpec(mode="static"), + family = acceptability_family(build_protocol(_CONFIG_PATH)) + template = family_to_item_template(family, judgment_type="acceptability") + response_options = tuple(family.anchor.response_space.options) + return template.with_( + task_spec=template.task_spec.with_(options=response_options) ) @@ -227,10 +232,10 @@ def run_simulation( # compute simulated human agreement (sample twice with different seeds) # create two new annotators with different random states for agreement calculation annotator_sample1 = SimulatedAnnotator.from_config( - annotator_config.model_copy(update={"random_state": (random_state or 0) + 1000}) + annotator_config.with_(random_state=(random_state or 0) + 1000) ) annotator_sample2 = SimulatedAnnotator.from_config( - annotator_config.model_copy(update={"random_state": (random_state or 0) + 2000}) + annotator_config.with_(random_state=(random_state or 0) + 2000) ) sample1 = annotator_sample1.annotate_batch(initial_items, item_template) diff --git a/gallery/eng/argument_structure/tests/test_protocol.py b/gallery/eng/argument_structure/tests/test_protocol.py new file mode 100644 index 0000000..a09715c --- /dev/null +++ b/gallery/eng/argument_structure/tests/test_protocol.py @@ -0,0 +1,53 @@ +"""Round-trip tests for the gallery's protocol declaration.""" + +from __future__ import annotations + +import sys +from pathlib import Path + +GALLERY_DIR = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(GALLERY_DIR)) + +from protocol import ( # noqa: E402 + ACCEPTABILITY_ANCHOR_NAME, + acceptability_anchor, + acceptability_family, + build_protocol, +) + +from bead.active_learning.models.registry import ( # noqa: E402 + model_class_for_encoding, +) +from bead.protocol.encoding import ( # noqa: E402 + ScaleType, + encode_response_space, +) +from bead.protocol.items import family_to_item_template # noqa: E402 + +CONFIG_PATH = GALLERY_DIR / "config.yaml" + + +def test_protocol_has_acceptability_family() -> None: + protocol = build_protocol(CONFIG_PATH) + assert len(protocol.families) == 1 + assert protocol.families[0].anchor.name == ACCEPTABILITY_ANCHOR_NAME + + +def test_anchor_carries_forced_choice_scale_type() -> None: + anchor = acceptability_anchor(build_protocol(CONFIG_PATH)) + assert anchor.response_space.scale_type is ScaleType.FORCED_CHOICE + assert anchor.response_space.options == ("first", "second") + + +def test_family_to_item_template_round_trips_prompt() -> None: + family = acceptability_family(build_protocol(CONFIG_PATH)) + template = family_to_item_template(family, judgment_type="acceptability") + assert template.task_type == "forced_choice" + assert template.task_spec.prompt == family.anchor.canonical_prompt + + +def test_forced_choice_encoding_picks_forced_choice_model() -> None: + anchor = acceptability_anchor(build_protocol(CONFIG_PATH)) + encoding = encode_response_space(anchor.name, anchor.response_space) + model_cls = model_class_for_encoding(encoding) + assert model_cls.__name__ == "ForcedChoiceModel" diff --git a/gallery/eng/argument_structure/tests/test_simulation.py b/gallery/eng/argument_structure/tests/test_simulation.py index 14dfa17..fd8ca68 100644 --- a/gallery/eng/argument_structure/tests/test_simulation.py +++ b/gallery/eng/argument_structure/tests/test_simulation.py @@ -29,19 +29,23 @@ class TestGetForcedChoiceTemplate: """Test suite for get_forced_choice_template function.""" def test_returns_valid_template(self) -> None: - """Test that function returns valid ItemTemplate.""" + """Test that function returns a forced-choice ItemTemplate from the protocol.""" template = get_forced_choice_template() - assert template.name == "2AFC Forced Choice" + # Template name comes from the protocol anchor (acceptability) + assert template.name == "acceptability" assert template.task_type == "forced_choice" - assert template.judgment_type == "preference" + assert template.judgment_type == "acceptability" def test_template_has_required_task_spec(self) -> None: - """Test that template has proper task spec with options.""" + """Test that template has the prompt declared in config.yaml.""" template = get_forced_choice_template() assert template.task_spec.prompt == "Which sentence sounds more natural?" - assert template.task_spec.options == ["option_a", "option_b"] + # ``get_forced_choice_template`` enriches the bare protocol + # template with response-space labels so the simulator can + # sample from them. + assert template.task_spec.options == ("first", "second") def test_template_presentation_spec(self) -> None: """Test that template has static presentation mode.""" diff --git a/pyproject.toml b/pyproject.toml index 474111e..2af2ed5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "bead" -version = "0.4.0" +version = "0.5.0" description = "Lexicon and Template Collection Construction Pipeline for Acceptability and Inference Judgment Data" authors = [{name = "Aaron Steven White", email = "aaron.white@rochester.edu"}] readme = "README.md" diff --git a/tests/active_learning/models/binary/test_mixed_effects.py b/tests/active_learning/models/binary/test_mixed_effects.py index f3fc567..07233b7 100644 --- a/tests/active_learning/models/binary/test_mixed_effects.py +++ b/tests/active_learning/models/binary/test_mixed_effects.py @@ -33,9 +33,8 @@ def test_train_with_fixed_mode( ) model = BinaryModel(config) - # Fixed effects: use placeholder participant_ids - participant_ids = ["default"] * len(sample_items) - metrics = model.train(sample_items, sample_labels, participant_ids) + # Fixed effects: participant_ids are not used. + metrics = model.train(sample_items, sample_labels) assert "train_accuracy" in metrics assert "train_loss" in metrics @@ -55,11 +54,8 @@ def test_predict_with_fixed_mode( ) model = BinaryModel(config) - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) - - # Predict with same participant_ids - predictions = model.predict(sample_items[:5], participant_ids[:5]) + model.train(sample_items, sample_labels) + predictions = model.predict(sample_items[:5]) assert len(predictions) == 5 for pred in predictions: @@ -327,10 +323,8 @@ def test_predict_proba_with_fixed_mode( ) model = BinaryModel(config) - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) - - proba = model.predict_proba(sample_items[:5], participant_ids[:5]) + model.train(sample_items, sample_labels) + proba = model.predict_proba(sample_items[:5]) assert proba.shape == (5, 2) # 2 classes for binary # Each row should sum to 1 @@ -362,16 +356,19 @@ def test_predict_proba_validates_participant_ids_length( ) -> None: """Test that predict_proba validates participant_ids length.""" config = BinaryModelConfig( - model_name="bert-base-uncased", num_epochs=1, device="cpu" + model_name="bert-base-uncased", + num_epochs=1, + device="cpu", + mixed_effects=MixedEffectsConfig(mode="random_intercepts"), ) model = BinaryModel(config) - participant_ids = ["default"] * len(sample_items) + participant_ids = ["alice", "bob"] * (len(sample_items) // 2) model.train(sample_items, sample_labels, participant_ids) # Wrong length for predict_proba with pytest.raises(ValueError, match="Length mismatch"): - model.predict_proba(sample_items[:5], ["default"] * 3) + model.predict_proba(sample_items[:5], ["alice"] * 3) class TestSaveLoad: @@ -528,9 +525,8 @@ def test_single_output_unit(self, sample_items: list[Item]) -> None: model = BinaryModel(config) labels = ["yes" if i % 2 == 0 else "no" for i in range(20)] - participant_ids = ["default"] * 20 - model.train(sample_items, labels, participant_ids) + model.train(sample_items, labels) # num_classes=1 for true binary classification (single output unit) assert model.num_classes == 1 # But we still have 2 label names @@ -549,12 +545,11 @@ def test_different_label_names(self, sample_items: list[Item]) -> None: # Use "true" and "false" instead of "yes" and "no" labels = ["true" if i % 2 == 0 else "false" for i in range(20)] - participant_ids = ["default"] * 20 - model.train(sample_items, labels, participant_ids) + model.train(sample_items, labels) assert model.label_names == ["false", "true"] # Sorted alphabetically - predictions = model.predict(sample_items[:5], participant_ids[:5]) + predictions = model.predict(sample_items[:5]) for pred in predictions: assert pred.predicted_class in ["true", "false"] @@ -593,10 +588,8 @@ def test_probabilities_sum_to_one( ) model = BinaryModel(config) - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) - - predictions = model.predict(sample_items[:5], participant_ids[:5]) + model.train(sample_items, sample_labels) + predictions = model.predict(sample_items[:5]) for pred in predictions: # Sum of probabilities should be 1.0 prob_sum = sum(pred.probabilities.values()) @@ -611,7 +604,6 @@ def test_rejects_non_binary_labels(self, sample_items: list[Item]) -> None: # Three different labels (not binary!) - must match sample_items length labels = (["yes", "no", "maybe"] * 6) + ["yes", "no"] - participant_ids = ["default"] * 20 with pytest.raises(ValueError, match="exactly 2 classes"): - model.train(sample_items, labels, participant_ids) + model.train(sample_items, labels) diff --git a/tests/active_learning/models/categorical/test_mixed_effects.py b/tests/active_learning/models/categorical/test_mixed_effects.py index 8f29a69..caa5012 100644 --- a/tests/active_learning/models/categorical/test_mixed_effects.py +++ b/tests/active_learning/models/categorical/test_mixed_effects.py @@ -60,8 +60,7 @@ def test_train_with_fixed_mode( model = CategoricalModel(config) # Fixed effects: use placeholder participant_ids - participant_ids = ["default"] * len(sample_items) - metrics = model.train(sample_items, sample_labels, participant_ids) + metrics = model.train(sample_items, sample_labels) assert "train_accuracy" in metrics assert "train_loss" in metrics @@ -80,30 +79,25 @@ def test_predict_with_fixed_mode( mixed_effects=MixedEffectsConfig(mode="fixed"), ) model = CategoricalModel(config) - - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) + model.train(sample_items, sample_labels) # Predict with same participant_ids - predictions = model.predict(sample_items[:5], participant_ids[:5]) + predictions = model.predict(sample_items[:5]) assert len(predictions) == 5 for pred in predictions: assert pred.predicted_class in ["entailment", "neutral", "contradiction"] assert 0.0 <= pred.confidence <= 1.0 - def test_train_requires_participant_ids( + def test_train_accepts_default_fixed_mode( self, sample_items: list[Item], sample_labels: list[str] ) -> None: - """Test that train requires participant_ids parameter.""" + """Default mixed_effects.mode is 'fixed'; train runs without participant_ids.""" config = CategoricalModelConfig( model_name="bert-base-uncased", num_epochs=1, device="cpu" ) model = CategoricalModel(config) - - # Should work with participant_ids - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) + model.train(sample_items, sample_labels) def test_train_validates_participant_ids_length( self, sample_items: list[Item], sample_labels: list[str] @@ -363,11 +357,9 @@ def test_predict_proba_with_fixed_mode( mixed_effects=MixedEffectsConfig(mode="fixed"), ) model = CategoricalModel(config) + model.train(sample_items, sample_labels) - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) - - proba = model.predict_proba(sample_items[:5], participant_ids[:5]) + proba = model.predict_proba(sample_items[:5]) assert proba.shape == (5, 3) # 3 classes # Each row should sum to 1 @@ -403,16 +395,19 @@ def test_predict_proba_validates_participant_ids_length( ) -> None: """Test that predict_proba validates participant_ids length.""" config = CategoricalModelConfig( - model_name="bert-base-uncased", num_epochs=1, device="cpu" + model_name="bert-base-uncased", + num_epochs=1, + device="cpu", + mixed_effects=MixedEffectsConfig(mode="random_intercepts"), ) model = CategoricalModel(config) - participant_ids = ["default"] * len(sample_items) + participant_ids = ["alice", "bob"] * (len(sample_items) // 2) model.train(sample_items, sample_labels, participant_ids) # Wrong length for predict_proba with pytest.raises(ValueError, match="Length mismatch"): - model.predict_proba(sample_items[:5], ["default"] * 3) + model.predict_proba(sample_items[:5], ["alice"] * 3) class TestSaveLoad: diff --git a/tests/active_learning/models/forced_choice/test_mixed_effects.py b/tests/active_learning/models/forced_choice/test_mixed_effects.py index 60f9f28..932ec0f 100644 --- a/tests/active_learning/models/forced_choice/test_mixed_effects.py +++ b/tests/active_learning/models/forced_choice/test_mixed_effects.py @@ -56,8 +56,7 @@ def test_train_with_fixed_mode( model = ForcedChoiceModel(config) # Fixed effects: use placeholder participant_ids - participant_ids = ["default"] * len(sample_items) - metrics = model.train(sample_items, sample_labels, participant_ids) + metrics = model.train(sample_items, sample_labels) assert "train_accuracy" in metrics assert "train_loss" in metrics @@ -76,30 +75,25 @@ def test_predict_with_fixed_mode( mixed_effects=MixedEffectsConfig(mode="fixed"), ) model = ForcedChoiceModel(config) - - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) + model.train(sample_items, sample_labels) # Predict with same participant_ids - predictions = model.predict(sample_items[:5], participant_ids[:5]) + predictions = model.predict(sample_items[:5]) assert len(predictions) == 5 for pred in predictions: assert pred.predicted_class in ["option_a", "option_b"] assert 0.0 <= pred.confidence <= 1.0 - def test_train_requires_participant_ids( + def test_train_accepts_default_fixed_mode( self, sample_items: list[Item], sample_labels: list[str] ) -> None: - """Test that train requires participant_ids parameter.""" + """Default mixed_effects.mode is 'fixed'; train runs without participant_ids.""" config = ForcedChoiceModelConfig( model_name="bert-base-uncased", num_epochs=1, device="cpu" ) model = ForcedChoiceModel(config) - - # Should work with participant_ids - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) + model.train(sample_items, sample_labels) def test_train_validates_participant_ids_length( self, sample_items: list[Item], sample_labels: list[str] @@ -338,11 +332,9 @@ def test_predict_proba_with_fixed_mode( mixed_effects=MixedEffectsConfig(mode="fixed"), ) model = ForcedChoiceModel(config) + model.train(sample_items, sample_labels) - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) - - proba = model.predict_proba(sample_items[:5], participant_ids[:5]) + proba = model.predict_proba(sample_items[:5]) assert proba.shape == (5, 2) # Each row should sum to 1 @@ -378,11 +370,14 @@ def test_predict_proba_validates_participant_ids_length( ) -> None: """Test that predict_proba validates participant_ids length.""" config = ForcedChoiceModelConfig( - model_name="bert-base-uncased", num_epochs=1, device="cpu" + model_name="bert-base-uncased", + num_epochs=1, + device="cpu", + mixed_effects=MixedEffectsConfig(mode="random_intercepts"), ) model = ForcedChoiceModel(config) - participant_ids = ["default"] * len(sample_items) + participant_ids = ["alice", "bob"] * (len(sample_items) // 2) model.train(sample_items, sample_labels, participant_ids) # Wrong length for predict_proba diff --git a/tests/active_learning/models/magnitude/test_mixed_effects.py b/tests/active_learning/models/magnitude/test_mixed_effects.py index dc08626..e68aecd 100644 --- a/tests/active_learning/models/magnitude/test_mixed_effects.py +++ b/tests/active_learning/models/magnitude/test_mixed_effects.py @@ -34,8 +34,7 @@ def test_train_with_fixed_mode_unbounded( model = MagnitudeModel(config) # Fixed effects: use placeholder participant_ids - participant_ids = ["default"] * len(sample_items) - metrics = model.train(sample_items, sample_unbounded_labels, participant_ids) + metrics = model.train(sample_items, sample_unbounded_labels) assert "train_mse" in metrics assert "train_loss" in metrics @@ -58,9 +57,7 @@ def test_train_with_fixed_mode_bounded( mixed_effects=MixedEffectsConfig(mode="fixed"), ) model = MagnitudeModel(config) - - participant_ids = ["default"] * len(sample_items) - metrics = model.train(sample_items, sample_bounded_labels, participant_ids) + metrics = model.train(sample_items, sample_bounded_labels) assert "train_mse" in metrics assert "train_loss" in metrics @@ -78,12 +75,10 @@ def test_predict_with_fixed_mode( mixed_effects=MixedEffectsConfig(mode="fixed"), ) model = MagnitudeModel(config) - - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_unbounded_labels, participant_ids) + model.train(sample_items, sample_unbounded_labels) # Predict with same participant_ids - predictions = model.predict(sample_items[:5], participant_ids[:5]) + predictions = model.predict(sample_items[:5]) assert len(predictions) == 5 for pred in predictions: @@ -143,10 +138,9 @@ def test_train_validates_bounds(self, sample_items: list[Item]) -> None: # Label outside bounds labels = ["50.0"] * 19 + ["150.0"] # 150.0 > max_value=100.0 - participant_ids = ["default"] * len(sample_items) with pytest.raises(ValueError, match="outside bounds"): - model.train(sample_items, labels, participant_ids) + model.train(sample_items, labels) class TestRandomInterceptsMode: @@ -448,9 +442,11 @@ def test_save_and_load_preserves_random_effects( # Check intercepts preserved loaded_bias_p1 = model2.random_effects.intercepts["mu"]["p1"] assert loaded_bias_p1.shape == orig_bias_p1.shape - # Check values are close (may have small numerical differences) - assert float(loaded_bias_p1[0]) == pytest.approx( - float(orig_bias_p1[0]), abs=1e-5 + # Check values are close (may have small numerical differences). + # ``.item()`` detaches the tensor so the comparison does not + # emit a "converting tensor with requires_grad=True" warning. + assert loaded_bias_p1[0].item() == pytest.approx( + orig_bias_p1[0].item(), abs=1e-5 ) def test_save_and_load_preserves_config( @@ -471,8 +467,7 @@ def test_save_and_load_preserves_config( # Need to adjust labels to bounded range bounded_labels = [str(float(i * 5)) for i in range(20)] - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, bounded_labels, participant_ids) + model.train(sample_items, bounded_labels) with tempfile.TemporaryDirectory() as tmpdir: model.save(tmpdir) @@ -504,12 +499,11 @@ def test_unbounded_normal_distribution( ) model = MagnitudeModel(config) - participant_ids = ["default"] * len(sample_items) - metrics = model.train(sample_items, sample_unbounded_labels, participant_ids) + metrics = model.train(sample_items, sample_unbounded_labels) assert "train_mse" in metrics # Predictions can be any value (unbounded) - predictions = model.predict(sample_items[:5], participant_ids[:5]) + predictions = model.predict(sample_items[:5]) for pred in predictions: val = float(pred.predicted_class) assert isinstance(val, float) @@ -530,12 +524,11 @@ def test_bounded_truncated_normal_distribution( ) model = MagnitudeModel(config) - participant_ids = ["default"] * len(sample_items) - metrics = model.train(sample_items, sample_bounded_labels, participant_ids) + metrics = model.train(sample_items, sample_bounded_labels) assert "train_mse" in metrics # Predictions should be within bounds - predictions = model.predict(sample_items[:5], participant_ids[:5]) + predictions = model.predict(sample_items[:5]) for pred in predictions: val = float(pred.predicted_class) assert 0.0 <= val <= 100.0 @@ -556,14 +549,11 @@ def test_truncated_normal_handles_endpoints( ) model = MagnitudeModel(config) - participant_ids = ["default"] * len(sample_items) # Should handle endpoints without errors - metrics = model.train( - sample_items, sample_bounded_endpoint_labels, participant_ids - ) + metrics = model.train(sample_items, sample_bounded_endpoint_labels) assert "train_mse" in metrics - predictions = model.predict(sample_items[:5], participant_ids[:5]) + predictions = model.predict(sample_items[:5]) for pred in predictions: val = float(pred.predicted_class) assert 0.0 <= val <= 100.0 diff --git a/tests/active_learning/models/multi_select/test_multi_select_mixed_effects.py b/tests/active_learning/models/multi_select/test_multi_select_mixed_effects.py index ec463ef..90e96b9 100644 --- a/tests/active_learning/models/multi_select/test_multi_select_mixed_effects.py +++ b/tests/active_learning/models/multi_select/test_multi_select_mixed_effects.py @@ -34,8 +34,7 @@ def test_train_with_fixed_mode( model = MultiSelectModel(config) # Fixed effects: use placeholder participant_ids - participant_ids = ["default"] * len(sample_items) - metrics = model.train(sample_items, sample_labels, participant_ids) + metrics = model.train(sample_items, sample_labels) assert "train_accuracy" in metrics assert "train_loss" in metrics @@ -54,12 +53,10 @@ def test_predict_with_fixed_mode( mixed_effects=MixedEffectsConfig(mode="fixed"), ) model = MultiSelectModel(config) - - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) + model.train(sample_items, sample_labels) # Predict with same participant_ids - predictions = model.predict(sample_items[:5], participant_ids[:5]) + predictions = model.predict(sample_items[:5]) assert len(predictions) == 5 for pred in predictions: @@ -89,14 +86,17 @@ def test_train_requires_participant_ids( def test_length_mismatch_raises_error( self, sample_items: list[Item], sample_labels: list[str] ) -> None: - """Test that length mismatch between items and participant_ids raises error.""" + """Length mismatch between items and participant_ids raises.""" config = MultiSelectModelConfig( - model_name="bert-base-uncased", num_epochs=1, device="cpu" + model_name="bert-base-uncased", + num_epochs=1, + device="cpu", + mixed_effects=MixedEffectsConfig(mode="random_intercepts"), ) model = MultiSelectModel(config) # Mismatched lengths - participant_ids = ["default"] * (len(sample_items) - 1) + participant_ids = ["alice"] * (len(sample_items) - 1) with pytest.raises(ValueError, match="Length mismatch"): model.train(sample_items, sample_labels, participant_ids) @@ -104,14 +104,17 @@ def test_length_mismatch_raises_error( def test_empty_participant_id_raises_error( self, sample_items: list[Item], sample_labels: list[str] ) -> None: - """Test that empty participant_id raises error.""" + """Empty participant_id strings raise.""" config = MultiSelectModelConfig( - model_name="bert-base-uncased", num_epochs=1, device="cpu" + model_name="bert-base-uncased", + num_epochs=1, + device="cpu", + mixed_effects=MixedEffectsConfig(mode="random_intercepts"), ) model = MultiSelectModel(config) # Include empty participant_id - participant_ids = ["alice", "", "bob"] + ["default"] * (len(sample_items) - 3) + participant_ids = ["alice", "", "bob"] + ["x"] * (len(sample_items) - 3) with pytest.raises(ValueError, match="cannot contain empty strings"): model.train(sample_items, sample_labels, participant_ids) @@ -471,11 +474,9 @@ def test_sigmoid_output_independent_probabilities( mixed_effects=MixedEffectsConfig(mode="fixed"), ) model = MultiSelectModel(config) + model.train(sample_items, sample_labels) - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) - - predictions = model.predict([sample_items[0]], [participant_ids[0]]) + predictions = model.predict([sample_items[0]]) # Probabilities are independent - sum may not be 1.0 prob_sum = sum(predictions[0].probabilities.values()) @@ -496,11 +497,9 @@ def test_empty_selection_possible(self, sample_items: list[Item]) -> None: mixed_effects=MixedEffectsConfig(mode="fixed"), ) model = MultiSelectModel(config) + model.train(sample_items[:10], labels) - participant_ids = ["default"] * len(sample_items[:10]) - model.train(sample_items[:10], labels, participant_ids) - - predictions = model.predict([sample_items[0]], [participant_ids[0]]) + predictions = model.predict([sample_items[0]]) # Should return a prediction (possibly empty) selected = json.loads(predictions[0].predicted_class) @@ -519,11 +518,9 @@ def test_all_options_selected_possible(self, sample_items: list[Item]) -> None: mixed_effects=MixedEffectsConfig(mode="fixed"), ) model = MultiSelectModel(config) + model.train(sample_items[:10], labels) - participant_ids = ["default"] * len(sample_items[:10]) - model.train(sample_items[:10], labels, participant_ids) - - predictions = model.predict([sample_items[0]], [participant_ids[0]]) + predictions = model.predict([sample_items[0]]) # Should return a prediction selected = json.loads(predictions[0].predicted_class) @@ -541,9 +538,7 @@ def test_hamming_accuracy_metric( mixed_effects=MixedEffectsConfig(mode="fixed"), ) model = MultiSelectModel(config) - - participant_ids = ["default"] * len(sample_items) - metrics = model.train(sample_items, sample_labels, participant_ids) + metrics = model.train(sample_items, sample_labels) # Hamming accuracy should be in [0, 1] assert 0.0 <= metrics["train_accuracy"] <= 1.0 @@ -560,10 +555,8 @@ def test_invalid_label_format_raises_error(self, sample_items: list[Item]) -> No ) model = MultiSelectModel(config) - participant_ids = ["default"] * len(sample_items) - with pytest.raises(ValueError, match="valid JSON"): - model.train(sample_items, labels_invalid, participant_ids) + model.train(sample_items, labels_invalid) def test_invalid_option_in_label_raises_error( self, sample_items: list[Item] @@ -577,10 +570,8 @@ def test_invalid_option_in_label_raises_error( ) model = MultiSelectModel(config) - participant_ids = ["default"] * len(sample_items) - with pytest.raises(ValueError, match="Invalid option"): - model.train(sample_items, labels_invalid, participant_ids) + model.train(sample_items, labels_invalid) class TestDualEncoderMode: @@ -599,15 +590,13 @@ def test_dual_encoder_mode_train_and_predict( mixed_effects=MixedEffectsConfig(mode="fixed"), ) model = MultiSelectModel(config) - - participant_ids = ["default"] * len(sample_items) - metrics = model.train(sample_items, sample_labels, participant_ids) + metrics = model.train(sample_items, sample_labels) assert "train_accuracy" in metrics assert "train_loss" in metrics # Predict with dual encoder - predictions = model.predict(sample_items[:3], participant_ids[:3]) + predictions = model.predict(sample_items[:3]) assert len(predictions) == 3 for pred in predictions: @@ -623,10 +612,8 @@ def test_predict_before_training_raises_error( ) model = MultiSelectModel(config) - participant_ids = ["default"] * len(sample_items) - with pytest.raises(ValueError, match="not trained"): - model.predict(sample_items[:1], participant_ids[:1]) + model.predict(sample_items[:1]) def test_predict_proba_before_training_raises_error( self, sample_items: list[Item] @@ -637,7 +624,5 @@ def test_predict_proba_before_training_raises_error( ) model = MultiSelectModel(config) - participant_ids = ["default"] * len(sample_items) - with pytest.raises(ValueError, match="not trained"): - model.predict_proba(sample_items[:1], participant_ids[:1]) + model.predict_proba(sample_items[:1]) diff --git a/tests/active_learning/models/ordinal_scale/test_mixed_effects.py b/tests/active_learning/models/ordinal_scale/test_mixed_effects.py index 868a38b..cc5b728 100644 --- a/tests/active_learning/models/ordinal_scale/test_mixed_effects.py +++ b/tests/active_learning/models/ordinal_scale/test_mixed_effects.py @@ -9,6 +9,7 @@ from bead.active_learning.config import MixedEffectsConfig from bead.active_learning.models.ordinal_scale import OrdinalScaleModel from bead.config.active_learning import OrdinalScaleModelConfig +from bead.data.range import Range from bead.items.item import Item # mark all tests in this module as slow model training tests @@ -31,9 +32,8 @@ def test_train_with_fixed_mode( ) model = OrdinalScaleModel(config) - # Fixed effects: use placeholder participant_ids - participant_ids = ["default"] * len(sample_items) - metrics = model.train(sample_items, sample_labels, participant_ids) + # Fixed effects: participant_ids are not used. + metrics = model.train(sample_items, sample_labels) assert "train_mse" in metrics assert "train_loss" in metrics @@ -53,11 +53,8 @@ def test_predict_with_fixed_mode( ) model = OrdinalScaleModel(config) - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) - - # Predict with same participant_ids - predictions = model.predict(sample_items[:5], participant_ids[:5]) + model.train(sample_items, sample_labels) + predictions = model.predict(sample_items[:5]) assert len(predictions) == 5 for pred in predictions: @@ -110,10 +107,9 @@ def test_train_validates_label_bounds(self, sample_items: list[Item]) -> None: # Label outside bounds labels = ["0.5"] * 19 + ["1.5"] # 1.5 > scale_max=1.0 - participant_ids = ["default"] * len(sample_items) with pytest.raises(ValueError, match="outside bounds"): - model.train(sample_items, labels, participant_ids) + model.train(sample_items, labels) class TestRandomInterceptsMode: @@ -341,11 +337,10 @@ def test_predict_proba_with_fixed_mode( ) model = OrdinalScaleModel(config) - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) + model.train(sample_items, sample_labels) # predict_proba should return μ values - proba = model.predict_proba(sample_items[:5], participant_ids[:5]) + proba = model.predict_proba(sample_items[:5]) assert proba.shape == (5, 1) assert all(0.0 <= val[0] <= 1.0 for val in proba) @@ -447,8 +442,7 @@ def test_train_with_endpoint_values( ) model = OrdinalScaleModel(config) - participant_ids = ["default"] * len(sample_items) - metrics = model.train(sample_items, sample_endpoint_labels, participant_ids) + metrics = model.train(sample_items, sample_endpoint_labels) # Should train successfully with endpoints assert "train_mse" in metrics @@ -463,16 +457,13 @@ def test_predict_clamps_to_bounds( num_epochs=1, batch_size=4, device="cpu", - scale_min=0.0, - scale_max=1.0, + scale=Range[float](min=0.0, max=1.0), mixed_effects=MixedEffectsConfig(mode="fixed"), ) model = OrdinalScaleModel(config) - participant_ids = ["default"] * len(sample_items) - model.train(sample_items, sample_labels, participant_ids) - - predictions = model.predict(sample_items, participant_ids) + model.train(sample_items, sample_labels) + predictions = model.predict(sample_items) # All predictions should be in bounds for pred in predictions: diff --git a/tests/config/compose/__init__.py b/tests/config/compose/__init__.py new file mode 100644 index 0000000..9f06acd --- /dev/null +++ b/tests/config/compose/__init__.py @@ -0,0 +1 @@ +"""Tests for the bead.config.compose subpackage.""" diff --git a/tests/config/compose/conftest.py b/tests/config/compose/conftest.py new file mode 100644 index 0000000..7b1174f --- /dev/null +++ b/tests/config/compose/conftest.py @@ -0,0 +1,28 @@ +"""Shared fixtures for compose tests. + +Defines a small fake didactic schema so the tests exercise the +subpackage without depending on ``BeadConfig``. This is the same +discipline the subpackage uses internally: tests against a generic +``dx.Model`` so they survive the eventual extraction. +""" + +from __future__ import annotations + +import didactic.api as dx + + +class FakeNested(dx.Model): + """Two-field nested model used in tests.""" + + value: str = "" + count: int = 0 + + +class FakeSchema(dx.Model): + """Small two-level didactic schema used as the compose target.""" + + name: str = "" + paths: dict[str, str] = dx.field(default_factory=dict) + items: tuple[str, ...] = () + nested: dx.Embed[FakeNested] = dx.field(default_factory=FakeNested) + enabled: bool = False diff --git a/tests/config/compose/test_compose_basics.py b/tests/config/compose/test_compose_basics.py new file mode 100644 index 0000000..285f0fc --- /dev/null +++ b/tests/config/compose/test_compose_basics.py @@ -0,0 +1,157 @@ +"""End-to-end compose pipeline tests against the FakeSchema fixture.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from bead.config.compose import ConfigError, compose + +from .conftest import FakeNested, FakeSchema + + +def _write_yaml(path: Path, content: str) -> Path: + path.write_text(content, encoding="utf-8") + return path + + +def test_minimal_yaml_loads(tmp_path: Path) -> None: + cfg_path = _write_yaml(tmp_path / "cfg.yaml", "name: hello\n") + config = compose(cfg_path, schema=FakeSchema) + assert config.name == "hello" + + +def test_profile_dict_merged_first(tmp_path: Path) -> None: + cfg_path = _write_yaml(tmp_path / "cfg.yaml", "name: from_yaml\n") + config = compose( + cfg_path, + schema=FakeSchema, + profile_dict={"name": "from_profile", "enabled": True}, + ) + assert config.name == "from_yaml" # YAML beats profile + assert config.enabled is True + + +def test_overrides_take_precedence(tmp_path: Path) -> None: + cfg_path = _write_yaml(tmp_path / "cfg.yaml", "name: yaml\n") + config = compose( + cfg_path, + schema=FakeSchema, + overrides=["name=cli"], + ) + assert config.name == "cli" + + +def test_override_typed_value(tmp_path: Path) -> None: + cfg_path = _write_yaml(tmp_path / "cfg.yaml", "nested:\n count: 1\n") + config = compose( + cfg_path, + schema=FakeSchema, + overrides=["nested.count=42"], + ) + assert config.nested.count == 42 + + +def test_strict_unknown_key_raises(tmp_path: Path) -> None: + cfg_path = _write_yaml(tmp_path / "cfg.yaml", "not_a_real_key: 5\n") + with pytest.raises(ConfigError, match="Unknown config key"): + compose(cfg_path, schema=FakeSchema) + + +def test_strict_unknown_nested_key_raises(tmp_path: Path) -> None: + cfg_path = _write_yaml( + tmp_path / "cfg.yaml", + "nested:\n not_a_real_subkey: 5\n", + ) + with pytest.raises(ConfigError, match="nested.not_a_real_subkey"): + compose(cfg_path, schema=FakeSchema) + + +def test_defaults_list_merges_left_to_right(tmp_path: Path) -> None: + base = _write_yaml(tmp_path / "base.yaml", "name: base\nenabled: true\n") + middle = _write_yaml(tmp_path / "middle.yaml", "name: middle\n") + primary = _write_yaml( + tmp_path / "primary.yaml", + f"defaults:\n - {base.stem}\n - {middle.stem}\nname: final\n", + ) + config = compose(primary, schema=FakeSchema) + assert config.name == "final" + assert config.enabled is True + + +def test_defaults_list_without_extension(tmp_path: Path) -> None: + _write_yaml(tmp_path / "base.yaml", "name: base\n") + primary = _write_yaml(tmp_path / "primary.yaml", "defaults:\n - base\n") + config = compose(primary, schema=FakeSchema) + assert config.name == "base" + + +def test_extra_overlays(tmp_path: Path) -> None: + primary = _write_yaml(tmp_path / "primary.yaml", "name: A\n") + overlay = _write_yaml(tmp_path / "overlay.yaml", "name: B\n") + config = compose( + primary, + schema=FakeSchema, + extra=[overlay], + ) + assert config.name == "B" + + +def test_toml_supported(tmp_path: Path) -> None: + cfg_path = tmp_path / "cfg.toml" + cfg_path.write_text('name = "from_toml"\n', encoding="utf-8") + config = compose(cfg_path, schema=FakeSchema) + assert config.name == "from_toml" + + +def test_unsupported_suffix_raises(tmp_path: Path) -> None: + bad = tmp_path / "cfg.xml" + bad.write_text("", encoding="utf-8") + with pytest.raises(ConfigError, match="Unsupported"): + compose(bad, schema=FakeSchema) + + +def test_interpolation_resolves_against_composed_root(tmp_path: Path) -> None: + cfg_path = _write_yaml( + tmp_path / "cfg.yaml", + "paths:\n data_dir: /tmp/x\n out_dir: ${paths.data_dir}/out\n", + ) + config = compose(cfg_path, schema=FakeSchema) + assert config.paths["out_dir"] == "/tmp/x/out" + + +def test_env_interpolation_in_yaml( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setenv("BEAD_DATA", "/data/from/env") + cfg_path = _write_yaml( + tmp_path / "cfg.yaml", + "paths:\n data_dir: ${oc.env:BEAD_DATA}\n", + ) + config = compose(cfg_path, schema=FakeSchema) + assert config.paths["data_dir"] == "/data/from/env" + + +def test_no_config_path_uses_profile_and_overrides() -> None: + config = compose( + schema=FakeSchema, + profile_dict={"name": "P"}, + overrides=["nested.count=3"], + ) + assert config.name == "P" + assert config.nested.count == 3 + + +def test_bad_override_no_equals() -> None: + with pytest.raises(ConfigError, match="missing '='"): + compose(schema=FakeSchema, overrides=["no_equals_here"]) + + +def test_resolved_root_must_be_mapping(tmp_path: Path) -> None: + cfg_path = tmp_path / "cfg.toml" + cfg_path.write_text("name = 'ok'\n", encoding="utf-8") + # FakeSchema requires a top-level mapping; this is the happy case. + config = compose(cfg_path, schema=FakeSchema) + assert isinstance(config, FakeSchema) + _ = FakeNested # imported for test isolation typing diff --git a/tests/config/compose/test_interpolation_basics.py b/tests/config/compose/test_interpolation_basics.py new file mode 100644 index 0000000..f18b71b --- /dev/null +++ b/tests/config/compose/test_interpolation_basics.py @@ -0,0 +1,120 @@ +"""Core interpolation grammar: absolute and relative references, +list indexing, type preservation, escapes, cycles.""" + +from __future__ import annotations + +import pytest + +from bead.config.compose import InterpolationError, resolve + + +def test_absolute_reference_string() -> None: + root = {"paths": {"data_dir": "/tmp/bead"}} + assert resolve("${paths.data_dir}", root=root) == "/tmp/bead" + + +def test_absolute_reference_typed() -> None: + """Standalone ${...} preserves the referenced value's type.""" + root = {"counts": {"n": 7}} + assert resolve("${counts.n}", root=root) == 7 + assert isinstance(resolve("${counts.n}", root=root), int) + + +def test_substring_substitution_coerces_to_str() -> None: + root = {"counts": {"n": 7}} + assert resolve("you have ${counts.n} items", root=root) == "you have 7 items" + + +def test_relative_reference_one_up() -> None: + """${.x} resolves against the parent of the current node.""" + root = {"section": {"x": "value", "ref": "${.x}"}} + out = resolve(root, root=root) + assert isinstance(out, dict) + section = out["section"] + assert isinstance(section, dict) + assert section["ref"] == "value" + + +def test_relative_reference_two_up() -> None: + root = { + "a": { + "b": {"ref": "${..target}"}, + "target": "found", + } + } + out = resolve(root, root=root) + assert isinstance(out, dict) + a = out["a"] + assert isinstance(a, dict) + b = a["b"] + assert isinstance(b, dict) + assert b["ref"] == "found" + + +def test_list_indexing_bracketed() -> None: + root = {"items": ["zero", "one", "two"]} + assert resolve("${items[1]}", root=root) == "one" + + +def test_list_indexing_dotted() -> None: + root = {"items": ["zero", "one", "two"]} + assert resolve("${items.2}", root=root) == "two" + + +def test_nested_interpolation() -> None: + """The inner ${...} is resolved first, then spliced into the outer path.""" + root = {"which": "alpha", "alpha": "value-A", "beta": "value-B"} + assert resolve("${${which}}", root=root) == "value-A" + + +def test_escape_literal_dollar_brace() -> None: + """\\${literal} produces a literal ${literal}.""" + root: dict = {} + assert resolve("\\${not_resolved}", root=root) == "${not_resolved}" + + +def test_missing_reference_raises() -> None: + root = {"a": {}} + with pytest.raises(InterpolationError, match="unresolved"): + resolve("${a.b}", root=root) + + +def test_cycle_detection() -> None: + root = {"a": "${b}", "b": "${a}"} + with pytest.raises(InterpolationError, match="cycle"): + resolve("${a}", root=root) + + +def test_relative_above_root_raises() -> None: + root = {"x": "${..y}"} + with pytest.raises(InterpolationError, match="above the root"): + resolve(root, root=root) + + +def test_dict_value_resolves_recursively() -> None: + root = { + "paths": {"data_dir": "/tmp"}, + "out": {"items": "${paths.data_dir}/items"}, + } + out = resolve(root, root=root) + assert isinstance(out, dict) + out_section = out["out"] + assert isinstance(out_section, dict) + assert out_section["items"] == "/tmp/items" + + +def test_concatenation_with_multiple_interpolations() -> None: + root = {"a": "X", "b": "Y"} + assert resolve("[${a}_${b}]", root=root) == "[X_Y]" + + +def test_list_index_out_of_range() -> None: + root = {"items": ["a"]} + with pytest.raises(InterpolationError, match="out of range"): + resolve("${items[5]}", root=root) + + +def test_dict_indexed_with_integer_raises() -> None: + root = {"section": {"foo": "bar"}} + with pytest.raises(InterpolationError, match="dict"): + resolve("${section[0]}", root=root) diff --git a/tests/config/compose/test_interpolation_custom.py b/tests/config/compose/test_interpolation_custom.py new file mode 100644 index 0000000..ed91153 --- /dev/null +++ b/tests/config/compose/test_interpolation_custom.py @@ -0,0 +1,43 @@ +"""Custom resolver registration.""" + +from __future__ import annotations + +import pytest + +from bead.config.compose import ( + list_resolvers, + register_resolver, + resolve, + unregister_resolver, +) + + +def test_register_and_use() -> None: + register_resolver("test.upper", lambda s: s.upper(), replace=True) + try: + assert resolve("${test.upper:hello}", root={}) == "HELLO" + assert "test.upper" in list_resolvers() + finally: + unregister_resolver("test.upper") + + +def test_register_existing_without_replace_raises() -> None: + register_resolver("test.echo", lambda s: s, replace=True) + try: + with pytest.raises(ValueError, match="already registered"): + register_resolver("test.echo", lambda s: s + "!") + finally: + unregister_resolver("test.echo") + + +def test_register_replace_ok() -> None: + register_resolver("test.x", lambda s: s + "a", replace=True) + register_resolver("test.x", lambda s: s + "b", replace=True) + try: + assert resolve("${test.x:y}", root={}) == "yb" + finally: + unregister_resolver("test.x") + + +def test_unregister_unknown_is_noop() -> None: + unregister_resolver("test.never_registered") diff --git a/tests/config/compose/test_interpolation_resolvers.py b/tests/config/compose/test_interpolation_resolvers.py new file mode 100644 index 0000000..9806774 --- /dev/null +++ b/tests/config/compose/test_interpolation_resolvers.py @@ -0,0 +1,51 @@ +"""Built-in resolver behaviour: oc.env, oc.decode, etc.""" + +from __future__ import annotations + +import base64 +import os + +import pytest + +from bead.config.compose import InterpolationError, resolve + + +def test_oc_env_set(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("BEAD_TEST_VAR", "hello") + assert resolve("${oc.env:BEAD_TEST_VAR}", root={}) == "hello" + + +def test_oc_env_default(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("BEAD_TEST_VAR", raising=False) + assert resolve("${oc.env:BEAD_TEST_VAR,fallback}", root={}) == "fallback" + + +def test_oc_env_missing_raises(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("BEAD_TEST_VAR", raising=False) + with pytest.raises(InterpolationError, match="not set"): + resolve("${oc.env:BEAD_TEST_VAR}", root={}) + + +def test_oc_decode_base64() -> None: + encoded = base64.b64encode(b"secret").decode("ascii") + assert resolve(f"${{oc.decode:{encoded}}}", root={}) == "secret" + + +def test_oc_decode_unknown_encoding() -> None: + with pytest.raises(InterpolationError, match="unknown encoding"): + resolve("${oc.decode:foo,rot13}", root={}) + + +def test_unknown_resolver() -> None: + with pytest.raises(InterpolationError, match="Unknown resolver"): + resolve("${nope:x}", root={}) + + +def test_resolver_arg_is_interpolated() -> None: + """Resolver arguments themselves resolve before the resolver runs.""" + os.environ["BEAD_T_NAME"] = "WORLD" + try: + root = {"key": "BEAD_T_NAME"} + assert resolve("${oc.env:${key}}", root=root) == "WORLD" + finally: + os.environ.pop("BEAD_T_NAME", None) diff --git a/tests/config/test_loader.py b/tests/config/test_loader.py deleted file mode 100644 index e8018c3..0000000 --- a/tests/config/test_loader.py +++ /dev/null @@ -1,268 +0,0 @@ -"""Tests for configuration loading from YAML files.""" - -from pathlib import Path - -import pytest -import yaml -from didactic.api import ValidationError - -from bead.config.config import BeadConfig -from bead.config.loader import load_config, load_yaml_file, merge_configs - - -class TestMergeConfigs: - """Tests for merge_configs function.""" - - def test_merge_flat_dicts(self) -> None: - """Test merging flat dictionaries.""" - base = {"a": 1, "b": 2} - override = {"b": 3, "c": 4} - result = merge_configs(base, override) - assert result == {"a": 1, "b": 3, "c": 4} - - def test_merge_nested_dicts(self) -> None: - """Test deep merge of nested dictionaries.""" - base = {"a": 1, "b": {"c": 2, "d": 3}} - override = {"b": {"d": 4, "e": 5}} - result = merge_configs(base, override) - assert result == {"a": 1, "b": {"c": 2, "d": 4, "e": 5}} - - def test_merge_deeply_nested_dicts(self) -> None: - """Test deep merge with multiple nesting levels.""" - base = {"a": {"b": {"c": 1}}} - override = {"a": {"b": {"d": 2}}} - result = merge_configs(base, override) - assert result == {"a": {"b": {"c": 1, "d": 2}}} - - def test_merge_override_replaces_non_dict(self) -> None: - """Test that non-dict values are replaced entirely.""" - base = {"a": {"b": 1}} - override = {"a": 2} - result = merge_configs(base, override) - assert result == {"a": 2} - - def test_merge_with_empty_override(self) -> None: - """Test merge with empty override dictionary.""" - base = {"a": 1, "b": 2} - override: dict[str, int] = {} - result = merge_configs(base, override) - assert result == {"a": 1, "b": 2} - - def test_merge_with_empty_base(self) -> None: - """Test merge with empty base dictionary.""" - base: dict[str, int] = {} - override = {"a": 1, "b": 2} - result = merge_configs(base, override) - assert result == {"a": 1, "b": 2} - - def test_merge_preserves_types(self) -> None: - """Test that types are preserved during merge.""" - base = {"a": 1, "b": "string", "c": [1, 2, 3]} - override = {"d": 2.5} - result = merge_configs(base, override) - assert result["a"] == 1 - assert result["b"] == "string" - assert result["c"] == [1, 2, 3] - assert result["d"] == 2.5 - - def test_merge_with_none_values(self) -> None: - """Test merge with None values in override.""" - base = {"a": 1, "b": 2} - override = {"b": None, "c": 3} - result = merge_configs(base, override) - assert result == {"a": 1, "b": None, "c": 3} - - -class TestLoadYamlFile: - """Tests for load_yaml_file function.""" - - def test_load_valid_yaml(self, tmp_path: Path) -> None: - """Test loading a valid YAML file.""" - config_file = tmp_path / "config.yaml" - config_file.write_text("profile: test\nlogging:\n level: DEBUG\n") - result = load_yaml_file(config_file) - assert result == {"profile": "test", "logging": {"level": "DEBUG"}} - - def test_load_empty_yaml(self, tmp_path: Path) -> None: - """Test loading an empty YAML file.""" - config_file = tmp_path / "empty.yaml" - config_file.write_text("") - result = load_yaml_file(config_file) - assert result == {} - - def test_load_complex_nested_yaml(self, tmp_path: Path) -> None: - """Test loading complex nested YAML structure.""" - yaml_content = """ -profile: dev -paths: - data_dir: /data - output_dir: /output - cache_dir: /cache -logging: - level: INFO - console: true - file: false -""" - config_file = tmp_path / "complex.yaml" - config_file.write_text(yaml_content) - result = load_yaml_file(config_file) - assert result["profile"] == "dev" - assert result["paths"]["data_dir"] == "/data" - assert result["logging"]["level"] == "INFO" - assert result["logging"]["console"] is True - - def test_load_nonexistent_file(self, tmp_path: Path) -> None: - """Test loading a file that doesn't exist.""" - config_file = tmp_path / "nonexistent.yaml" - with pytest.raises(FileNotFoundError, match="Configuration file not found"): - load_yaml_file(config_file) - - def test_load_malformed_yaml(self, tmp_path: Path) -> None: - """Test loading malformed YAML file.""" - config_file = tmp_path / "malformed.yaml" - config_file.write_text("profile: test\n invalid:\n this is: [not valid") - with pytest.raises(yaml.YAMLError, match="Failed to parse YAML"): - load_yaml_file(config_file) - - def test_load_yaml_with_string_path(self, tmp_path: Path) -> None: - """Test loading YAML with string path instead of Path object.""" - config_file = tmp_path / "config.yaml" - config_file.write_text("profile: test\n") - result = load_yaml_file(str(config_file)) - assert result == {"profile": "test"} - - def test_load_yaml_with_lists(self, tmp_path: Path) -> None: - """Test loading YAML containing lists.""" - yaml_content = """ -items: - - name: item1 - value: 1 - - name: item2 - value: 2 -""" - config_file = tmp_path / "lists.yaml" - config_file.write_text(yaml_content) - result = load_yaml_file(config_file) - assert len(result["items"]) == 2 - assert result["items"][0]["name"] == "item1" - - -class TestLoadConfig: - """Tests for load_config function.""" - - def test_load_config_from_profile_default(self) -> None: - """Test loading config from default profile.""" - config = load_config(profile="default") - assert isinstance(config, BeadConfig) - assert config.profile == "default" - - def test_load_config_from_profile_dev(self) -> None: - """Test loading config from dev profile.""" - config = load_config(profile="dev") - assert isinstance(config, BeadConfig) - assert config.profile == "dev" - - def test_load_config_from_profile_prod(self) -> None: - """Test loading config from prod profile.""" - config = load_config(profile="prod") - assert isinstance(config, BeadConfig) - assert config.profile == "prod" - - def test_load_config_from_profile_test(self) -> None: - """Test loading config from test profile.""" - config = load_config(profile="test") - assert isinstance(config, BeadConfig) - assert config.profile == "test" - - def test_load_config_with_none_path(self) -> None: - """Test loading config with None path uses profile defaults.""" - config = load_config(config_path=None, profile="dev") - assert config.profile == "dev" - assert config.logging.level == "DEBUG" - - def test_load_config_from_yaml_file(self, tmp_path: Path) -> None: - """Test loading config from YAML file.""" - config_file = tmp_path / "config.yaml" - config_file.write_text( - """ -profile: test -logging: - level: WARNING -""" - ) - config = load_config(config_path=config_file) - assert config.logging.level == "WARNING" - - def test_load_config_with_yaml_and_profile(self, tmp_path: Path) -> None: - """Test loading config with both YAML file and profile.""" - config_file = tmp_path / "config.yaml" - config_file.write_text("logging:\n level: ERROR\n") - config = load_config(config_path=config_file, profile="dev") - # YAML should override profile - assert config.logging.level == "ERROR" - # But profile should still be set - assert config.profile == "dev" - - def test_load_config_with_single_level_override(self) -> None: - """Test loading config with single-level keyword override.""" - config = load_config(logging__console=False) - assert config.logging.console is False - - def test_load_config_with_nested_override(self) -> None: - """Test loading config with nested keyword override.""" - config = load_config(profile="default", logging__level="CRITICAL") - assert config.logging.level == "CRITICAL" - - def test_load_config_with_deeply_nested_override(self) -> None: - """Test loading config with deeply nested keyword override.""" - config = load_config(profile="default", paths__data_dir="/custom/data/path") - assert str(config.paths.data_dir) == "/custom/data/path" - - def test_load_config_precedence_profile_yaml_override(self, tmp_path: Path) -> None: - """Test configuration precedence: profile < yaml < override.""" - config_file = tmp_path / "config.yaml" - config_file.write_text("logging:\n level: WARNING\n") - config = load_config( - config_path=config_file, profile="dev", logging__level="CRITICAL" - ) - # Override should win - assert config.logging.level == "CRITICAL" - - def test_load_config_with_nonexistent_file(self, tmp_path: Path) -> None: - """Test loading config with nonexistent file raises error.""" - config_file = tmp_path / "nonexistent.yaml" - with pytest.raises(FileNotFoundError): - load_config(config_path=config_file) - - def test_load_config_with_malformed_yaml(self, tmp_path: Path) -> None: - """Test loading config with malformed YAML raises error.""" - config_file = tmp_path / "malformed.yaml" - config_file.write_text("this is: [not: valid yaml") - with pytest.raises(yaml.YAMLError): - load_config(config_path=config_file) - - def test_load_config_with_invalid_values(self, tmp_path: Path) -> None: - """Test loading config with invalid values raises ValidationError.""" - config_file = tmp_path / "invalid.yaml" - config_file.write_text("logging:\n level: INVALID_LEVEL\n") - with pytest.raises(ValidationError): - load_config(config_path=config_file) - - def test_load_config_multiple_overrides(self) -> None: - """Test loading config with multiple keyword overrides.""" - config = load_config( - profile="default", - logging__level="DEBUG", - logging__console=False, - paths__data_dir="/data", - ) - assert config.logging.level == "DEBUG" - assert config.logging.console is False - assert str(config.paths.data_dir) == "/data" - - def test_load_config_with_string_path(self, tmp_path: Path) -> None: - """Test loading config with string path instead of Path object.""" - config_file = tmp_path / "config.yaml" - config_file.write_text("profile: test\n") - config = load_config(config_path=str(config_file)) - assert config.profile == "test" diff --git a/tests/protocol/test_encoding.py b/tests/protocol/test_encoding.py index 17b7f56..5071d0f 100644 --- a/tests/protocol/test_encoding.py +++ b/tests/protocol/test_encoding.py @@ -19,6 +19,7 @@ def test_str_values(self) -> None: assert ScaleType.BINARY.value == "binary" assert ScaleType.ORDINAL.value == "ordinal" assert ScaleType.NOMINAL.value == "nominal" + assert ScaleType.FORCED_CHOICE.value == "forced_choice" class TestResponseEncoding: @@ -57,6 +58,41 @@ def test_index_out_of_range(self) -> None: with pytest.raises(IndexError): enc.index_to_label(5) + def test_forced_choice_encoding(self) -> None: + enc = ResponseEncoding( + name="acceptability", + n_levels=2, + scale_type=ScaleType.FORCED_CHOICE, + labels=("first", "second"), + ) + assert enc.is_forced_choice + assert not enc.is_binary + + def test_forced_choice_requires_two_or_more_levels(self) -> None: + with pytest.raises(Exception, match="at least 2 levels"): + ResponseEncoding( + name="x", + n_levels=1, + scale_type=ScaleType.FORCED_CHOICE, + labels=("only",), + ) + + def test_encode_response_space_with_explicit_forced_choice(self) -> None: + rs = ResponseSpace( + options=("first", "second"), + is_ordered=False, + scale_type=ScaleType.FORCED_CHOICE, + ) + enc = encode_response_space("acceptability", rs) + assert enc.is_forced_choice + + def test_encode_response_space_scale_type_arg_overrides(self) -> None: + rs = ResponseSpace(options=("a", "b"), is_ordered=False) + enc = encode_response_space( + "x", rs, scale_type=ScaleType.FORCED_CHOICE + ) + assert enc.is_forced_choice + def test_scale_predicates(self) -> None: enc = self._build() assert enc.is_ordinal is True diff --git a/tests/protocol/test_items_bridge.py b/tests/protocol/test_items_bridge.py index ae2ec94..37fde33 100644 --- a/tests/protocol/test_items_bridge.py +++ b/tests/protocol/test_items_bridge.py @@ -4,6 +4,7 @@ import pytest +from bead.active_learning.models import ForcedChoiceModel, model_class_for_encoding from bead.items.item_template import ItemTemplate, PresentationSpec from bead.protocol import ( AnnotationProtocol, @@ -13,6 +14,7 @@ ResponseSpace, ScaleType, SemanticAnchor, + encode_response_space, family_to_item_template, protocol_to_item_templates, realization_to_item, @@ -33,6 +35,49 @@ def test_ordinal_maps(self) -> None: def test_nominal_maps(self) -> None: assert scale_type_to_task_type(ScaleType.NOMINAL) == "categorical" + def test_forced_choice_maps(self) -> None: + assert ( + scale_type_to_task_type(ScaleType.FORCED_CHOICE) + == "forced_choice" + ) + + +def _build_forced_choice_anchor() -> SemanticAnchor: + return SemanticAnchor( + name="acceptability", + target_property="acceptability", + canonical_prompt="Which sentence sounds more natural?", + response_space=ResponseSpace( + options=("first", "second"), + is_ordered=False, + scale_type=ScaleType.FORCED_CHOICE, + ), + required_keywords=frozenset({"natural"}), + ) + + +class TestForcedChoiceFamilyTemplate: + """family_to_item_template handles forced_choice anchors.""" + + def test_forced_choice_template(self) -> None: + family = QuestionFamily(anchor=_build_forced_choice_anchor()) + template = family_to_item_template( + family, judgment_type="acceptability" + ) + assert template.task_type == "forced_choice" + # forced-choice templates carry no per-template options; + # the per-item alternatives live on each Item. + assert template.task_spec.options is None + assert template.task_spec.scale_bounds is None + assert template.task_spec.scale_labels == () + + def test_model_class_for_forced_choice_encoding(self) -> None: + anchor = _build_forced_choice_anchor() + encoding = encode_response_space( + anchor.name, anchor.response_space + ) + assert model_class_for_encoding(encoding) is ForcedChoiceModel + def _build_binary_anchor() -> SemanticAnchor: return SemanticAnchor( diff --git a/uv.lock b/uv.lock index 0098559..da0e0c8 100644 --- a/uv.lock +++ b/uv.lock @@ -157,7 +157,7 @@ wheels = [ [[package]] name = "bead" -version = "0.4.0" +version = "0.5.0" source = { editable = "." } dependencies = [ { name = "accelerate" }, From 325f7fac2e0eee4ed2332b9d9418c561748fdc95 Mon Sep 17 00:00:00 2001 From: Aaron Steven White Date: Tue, 12 May 2026 12:08:19 -0400 Subject: [PATCH 3/4] Fixes CI lint+format: drop unused json imports, wrap long lines --- bead/active_learning/models/forced_choice.py | 1 - bead/active_learning/models/free_text.py | 1 - bead/active_learning/models/magnitude.py | 1 - bead/active_learning/models/ordinal_scale.py | 1 - bead/cli/items_factories.py | 24 +++++++++++++++----- tests/protocol/test_encoding.py | 4 +--- tests/protocol/test_items_bridge.py | 13 +++-------- 7 files changed, 22 insertions(+), 23 deletions(-) diff --git a/bead/active_learning/models/forced_choice.py b/bead/active_learning/models/forced_choice.py index 376786e..1ca09da 100644 --- a/bead/active_learning/models/forced_choice.py +++ b/bead/active_learning/models/forced_choice.py @@ -2,7 +2,6 @@ from __future__ import annotations -import json import tempfile from pathlib import Path diff --git a/bead/active_learning/models/free_text.py b/bead/active_learning/models/free_text.py index 69feba9..6b77ea5 100644 --- a/bead/active_learning/models/free_text.py +++ b/bead/active_learning/models/free_text.py @@ -9,7 +9,6 @@ from __future__ import annotations -import json from pathlib import Path import numpy as np diff --git a/bead/active_learning/models/magnitude.py b/bead/active_learning/models/magnitude.py index 662bacf..0dd1ab0 100644 --- a/bead/active_learning/models/magnitude.py +++ b/bead/active_learning/models/magnitude.py @@ -8,7 +8,6 @@ from __future__ import annotations -import json import tempfile from pathlib import Path diff --git a/bead/active_learning/models/ordinal_scale.py b/bead/active_learning/models/ordinal_scale.py index f2ab1c7..224c7ec 100644 --- a/bead/active_learning/models/ordinal_scale.py +++ b/bead/active_learning/models/ordinal_scale.py @@ -6,7 +6,6 @@ from __future__ import annotations -import json import tempfile from pathlib import Path diff --git a/bead/cli/items_factories.py b/bead/cli/items_factories.py index 93bc61c..3289e3d 100644 --- a/bead/cli/items_factories.py +++ b/bead/cli/items_factories.py @@ -156,7 +156,9 @@ def create_forced_choice_from_texts( """ try: # Load texts - texts: list[str] = [line.strip() for line in texts_file.read_text().splitlines() if line.strip()] + texts: list[str] = [ + line.strip() for line in texts_file.read_text().splitlines() if line.strip() + ] print_info(f"Loaded {len(texts)} texts") # Create items by generating all combinations of n_alternatives from texts @@ -286,7 +288,9 @@ def create_ordinal_scale_from_texts( """ try: # Load texts - texts = [line.strip() for line in texts_file.read_text().splitlines() if line.strip()] + texts = [ + line.strip() for line in texts_file.read_text().splitlines() if line.strip() + ] print_info(f"Loaded {len(texts)} texts") # Create items @@ -459,7 +463,9 @@ def create_binary_from_texts( """ try: # Load texts - texts = [line.strip() for line in texts_file.read_text().splitlines() if line.strip()] + texts = [ + line.strip() for line in texts_file.read_text().splitlines() if line.strip() + ] print_info(f"Loaded {len(texts)} texts") # Create items @@ -535,7 +541,9 @@ def create_multi_select_from_texts( """ try: # Load texts - texts: list[str] = [line.strip() for line in texts_file.read_text().splitlines() if line.strip()] + texts: list[str] = [ + line.strip() for line in texts_file.read_text().splitlines() if line.strip() + ] print_info(f"Loaded {len(texts)} texts") # Parse options @@ -619,7 +627,9 @@ def create_magnitude_from_texts( """ try: # Load texts - texts = [line.strip() for line in texts_file.read_text().splitlines() if line.strip()] + texts = [ + line.strip() for line in texts_file.read_text().splitlines() if line.strip() + ] print_info(f"Loaded {len(texts)} texts") # Create items @@ -679,7 +689,9 @@ def create_free_text_from_texts( """ try: # Load texts - texts = [line.strip() for line in texts_file.read_text().splitlines() if line.strip()] + texts = [ + line.strip() for line in texts_file.read_text().splitlines() if line.strip() + ] print_info(f"Loaded {len(texts)} texts") # Create items diff --git a/tests/protocol/test_encoding.py b/tests/protocol/test_encoding.py index 5071d0f..76bce22 100644 --- a/tests/protocol/test_encoding.py +++ b/tests/protocol/test_encoding.py @@ -88,9 +88,7 @@ def test_encode_response_space_with_explicit_forced_choice(self) -> None: def test_encode_response_space_scale_type_arg_overrides(self) -> None: rs = ResponseSpace(options=("a", "b"), is_ordered=False) - enc = encode_response_space( - "x", rs, scale_type=ScaleType.FORCED_CHOICE - ) + enc = encode_response_space("x", rs, scale_type=ScaleType.FORCED_CHOICE) assert enc.is_forced_choice def test_scale_predicates(self) -> None: diff --git a/tests/protocol/test_items_bridge.py b/tests/protocol/test_items_bridge.py index 37fde33..d89a4b6 100644 --- a/tests/protocol/test_items_bridge.py +++ b/tests/protocol/test_items_bridge.py @@ -36,10 +36,7 @@ def test_nominal_maps(self) -> None: assert scale_type_to_task_type(ScaleType.NOMINAL) == "categorical" def test_forced_choice_maps(self) -> None: - assert ( - scale_type_to_task_type(ScaleType.FORCED_CHOICE) - == "forced_choice" - ) + assert scale_type_to_task_type(ScaleType.FORCED_CHOICE) == "forced_choice" def _build_forced_choice_anchor() -> SemanticAnchor: @@ -61,9 +58,7 @@ class TestForcedChoiceFamilyTemplate: def test_forced_choice_template(self) -> None: family = QuestionFamily(anchor=_build_forced_choice_anchor()) - template = family_to_item_template( - family, judgment_type="acceptability" - ) + template = family_to_item_template(family, judgment_type="acceptability") assert template.task_type == "forced_choice" # forced-choice templates carry no per-template options; # the per-item alternatives live on each Item. @@ -73,9 +68,7 @@ def test_forced_choice_template(self) -> None: def test_model_class_for_forced_choice_encoding(self) -> None: anchor = _build_forced_choice_anchor() - encoding = encode_response_space( - anchor.name, anchor.response_space - ) + encoding = encode_response_space(anchor.name, anchor.response_space) assert model_class_for_encoding(encoding) is ForcedChoiceModel From ce0a4ec91ad40e83a3ffb4eb7ed9aadf03927716 Mon Sep 17 00:00:00 2001 From: Aaron Steven White Date: Tue, 12 May 2026 12:29:41 -0400 Subject: [PATCH 4/4] Fixes pyright 1.1.409: Iterator -> Generator on @contextmanager --- bead/config/compose/interpolation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bead/config/compose/interpolation.py b/bead/config/compose/interpolation.py index 9acb86c..3d9b8a6 100644 --- a/bead/config/compose/interpolation.py +++ b/bead/config/compose/interpolation.py @@ -29,7 +29,7 @@ from __future__ import annotations -from collections.abc import Callable, Iterator +from collections.abc import Callable, Generator from contextlib import contextmanager from contextvars import ContextVar from dataclasses import dataclass, field @@ -84,7 +84,7 @@ def active_root() -> dict[str, ComposeValue] | None: @contextmanager def _activate_root( root: dict[str, ComposeValue], -) -> Iterator[None]: +) -> Generator[None]: token = _ACTIVE_ROOT.set(root) try: yield