diff --git a/CHANGELOG.md b/CHANGELOG.md index 44aa7e3..1b78669 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,143 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.4.0] - 2026-05-07 + +### Added + +#### Pipeline-wide integration of the protocol layer + +- `bead.labels` is the single canonical home for the + `[[label]]` / `[[label:text]]` / `[[label|transform]]` syntax. + `parse_label_refs`, `find_label_names`, and `replace_label_refs` + replace the three independent regex implementations that previously + lived in `bead.protocol.drift`, `bead.deployment.jspsych.trials`, + and `bead.items.span_labeling`. +- `bead.config.protocol.ProtocolConfig` plugs into `BeadConfig.protocol` + with declarative TOML/YAML configuration: anchor specs, drift + settings, realization strategies (template / contextual / lm), and + family composition. `ProtocolConfig.build(lm_client=..., cache=...)` + materializes a live `AnnotationProtocol`. +- `bead.protocol.items` provides the canonical + `QuestionRealization → Item` and protocol-wide + `family_to_item_template` / `protocol_to_item_templates` / + `realize_protocol_to_items` bridges, plus `scale_type_to_task_type` + as the single canonical mapping from `ScaleType` to `TaskType`. +- `bead.active_learning.models.registry` exposes + `MODEL_CLASSES` / `CONFIG_CLASSES` and + `model_class_for_task_type` / `config_class_for_task_type` / + `model_class_for_encoding` / `config_class_for_encoding` as the + single canonical task-type → model-class / config-class registry. + `bead.cli.models` and `bead.cli.training` consume the registry + directly, replacing two parallel string-keyed dicts and a dynamic + `_import_class` helper. +- `bead.deployment.protocol_trials.protocol_to_jspsych_trials` is the + canonical end-to-end bridge from an `AnnotationProtocol` and a + sequence of `ProtocolContext` records to a flat list of jsPsych + trial dicts. +- `bead.data_collection.jatos_results_to_annotation_records` converts + raw JATOS results into `AnnotationRecord` instances, the input + shape consumed by `annotator_reliability` and + `InterAnnotatorMetrics`. +- `bead protocol` CLI subcommand: `bead protocol validate`, + `bead protocol realize`, `bead protocol items` drive the + configured protocol from the shell. + +### Changed + +- `LMRealization` accepts a `ModelOutputCache` (the bead-wide + content-addressable cache) via its required `cache` keyword and a + required `model_name` keyword for cache-key isolation. The internal + FIFO dict and the `cache` / `max_cache_size` / `clear_cache` / + `cache_size` parameters and methods are removed; the + `ModelOutputCache` is the single canonical caching surface. +- `bead.cli.models` no longer maintains `TASK_TYPE_MODELS` / + `TASK_TYPE_CONFIGS` string-path dicts or the `_import_class` + helper; they are replaced by direct calls into + `bead.active_learning.models.registry`. `bead.cli.training` follows + the same pattern. +- `bead.deployment.jspsych.trials._parse_prompt_references`, + `_SpanReference`, `_SPAN_REF_PATTERN`, and the duplicated + `_SPAN_REF_PATTERN` in `bead.items.span_labeling` are removed in + favor of `bead.labels.parse_label_refs` / `LabelRef`. + +#### `bead.protocol`: annotation protocol primitives + +A new top-level package providing a type-theoretic stack for defining +annotation protocols: anchors as types, contexts as dependent +indices, realization strategies as computational content, and drift +guards as type-checkers. + +- `bead.protocol.anchor` defines `SemanticAnchor` (the type-level + spec of a question, with required span labels, required keywords, + optional embedding center and `max_drift`) and `ResponseSpace` / + `SemanticPoles`. +- `bead.protocol.context` defines a generic `ProtocolContext` and + `ContextItem` plus a module-level **predicate registry** + (`register_context_predicate`, `get_context_predicate`, + `list_context_predicates`) for callers to register named context + predicates at import time. +- `bead.protocol.realization` provides `RealizationStrategy` + (`typing.Protocol`), `TemplateRealization`, + `ContextualTemplateRealization` (rule-based selection from ranked + variants), and `LMRealization` (with caching and FIFO eviction) + plus an `LMClient` `Protocol` with explicit + `temperature` / `max_tokens` keyword parameters. +- `bead.protocol.drift` defines `DriftScore`, the `DriftValidator` + `Protocol`, and three concrete validators + (`StructuralDriftValidator`, `EmbeddingDriftValidator`, + `PerplexityDriftValidator`) plus a composite `DriftGuard`. The + embedding and perplexity validators consume narrow + `EmbeddingAdapter` / `PerplexityAdapter` `Protocol`s, so any object + exposing the right method (including bead's + `bead.items.adapters.ModelAdapter`) conforms. +- `bead.protocol.family` defines `QuestionFamily` (with explicit + `depends_on` for conditional dependencies) and `AnnotationProtocol` + (the iterated dependent product), with `realize_all` threading + responses through the context. `AnnotationProtocol` rejects + duplicate anchor names, self-dependencies, and forward / unknown + `depends_on` references at construction and on `append`. +- `bead.protocol.encoding` defines `ScaleType` + (`StrEnum: binary / ordinal / nominal`) and `ResponseEncoding` (with + invariant validators for `n_levels == len(labels)`, label + uniqueness, and `BINARY` having exactly 2 levels), plus + `encode_response_space` as the bridge from `ResponseSpace`. +- `bead.protocol.diagnostics` defines `DiagnosticLevel`, + `DiagnosticRecord`, `DatasetReport` (immutable, with `with_*` + mutators), `ConditionalObservationValidator` (which operates on + `AnnotationProtocol.depends_on`), and the `RecordLike` `Protocol` + for the structural record shape consumed by the validator. +- `LMRealization` raises `RuntimeError` on backend failures and on + empty / whitespace-only responses (instead of caching an empty + string). + +#### `bead.evaluation.reliability`: per-annotator reliability + +- `AnnotationRecord` is a `BeadBaseModel` with the canonical + `(annotator_id, item_id, question_name, response_label)` shape. +- `annotator_reliability(records, encodings=...)` returns + per-annotator response distributions and Shannon entropy in bits, + optionally filtering unrecognized labels. +- `low_entropy_annotators(profiles, threshold=...)` flags annotators + who collapse the response space. + +### Documentation + +- `docs/api/protocol.md` and `docs/api/evaluation.md` updates expose + the new modules through `mkdocstrings`. +- `docs/user-guide/protocols.md` walks through anchors, contexts + (including the predicate registry and per-dependent attributes), + the three realization strategies, drift validation (with the named + `EmbeddingAdapter` and `PerplexityAdapter` Protocols), protocol + composition, the structural construction-time invariants, the + `encode_response_space` bridge to modeling, conditional-observation + diagnostics (including the `RecordLike` Protocol), and reliability. +- The protocol layer is cross-linked from + `docs/user-guide/concepts.md`, `docs/user-guide/index.md`, + `docs/index.md`, the project `README.md`, and a new "Protocol layer" + paragraph in `docs/developer-guide/architecture.md` that places it + as a cross-cutting layer feeding into the existing 6-stage pipeline. + ## [0.3.0] - 2026-05-06 ### Changed diff --git a/README.md b/README.md index dce2d9a..e8c479c 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,7 @@ lists.save("lists/experiment.jsonl") - **Constraint satisfaction**: batch and list-level constraints for balanced designs - **Model integration**: HuggingFace, OpenAI, Anthropic with caching - **Active learning**: uncertainty sampling with convergence detection +- **Annotation protocols**: type-theoretic stack of `SemanticAnchor` (the question type), `ProtocolContext` (the dependent index), `RealizationStrategy` (template / contextual / LM phrasings), and `DriftGuard` (the type-checker over realized prompts), composed into conditional `AnnotationProtocol`s - **jsPsych 8.x**: Material Design UI with JATOS deployment ## CLI diff --git a/bead/__init__.py b/bead/__init__.py index 9e9d745..74233be 100644 --- a/bead/__init__.py +++ b/bead/__init__.py @@ -6,6 +6,6 @@ from __future__ import annotations -__version__ = "0.3.0" +__version__ = "0.4.0" __author__ = "Aaron Steven White" __email__ = "aaron.white@rochester.edu" diff --git a/bead/active_learning/config.py b/bead/active_learning/config.py index 49eb4ad..33221d5 100644 --- a/bead/active_learning/config.py +++ b/bead/active_learning/config.py @@ -7,9 +7,9 @@ import didactic.api as dx __all__ = [ - "VarianceComponents", - "RandomEffectsSpec", "MixedEffectsConfig", + "RandomEffectsSpec", + "VarianceComponents", ] diff --git a/bead/active_learning/models/__init__.py b/bead/active_learning/models/__init__.py index 85573b5..b90a8b0 100644 --- a/bead/active_learning/models/__init__.py +++ b/bead/active_learning/models/__init__.py @@ -9,8 +9,19 @@ from bead.active_learning.models.magnitude import MagnitudeModel from bead.active_learning.models.multi_select import MultiSelectModel from bead.active_learning.models.ordinal_scale import OrdinalScaleModel +from bead.active_learning.models.registry import ( + CONFIG_CLASSES, + MODEL_CLASSES, + ModelConfig, + config_class_for_encoding, + config_class_for_task_type, + model_class_for_encoding, + model_class_for_task_type, +) __all__ = [ + "CONFIG_CLASSES", + "MODEL_CLASSES", "ActiveLearningModel", "BinaryModel", "CategoricalModel", @@ -18,7 +29,12 @@ "ForcedChoiceModel", "FreeTextModel", "MagnitudeModel", + "ModelConfig", "ModelPrediction", "MultiSelectModel", "OrdinalScaleModel", + "config_class_for_encoding", + "config_class_for_task_type", + "model_class_for_encoding", + "model_class_for_task_type", ] diff --git a/bead/active_learning/models/registry.py b/bead/active_learning/models/registry.py new file mode 100644 index 0000000..4e71087 --- /dev/null +++ b/bead/active_learning/models/registry.py @@ -0,0 +1,182 @@ +"""Single canonical registry mapping task types to active-learning models. + +bead's eight task types each correspond to exactly one +:class:`~bead.active_learning.models.base.ActiveLearningModel` subclass +and one +:class:`~bead.config.active_learning.BaseEncoderModelConfig`-derived +config class. This module exposes those two mappings as the single +source of truth used by: + +- :mod:`bead.cli.models` (CLI training commands) +- :mod:`bead.protocol.items` (protocol-layer integration) +- :func:`model_for_encoding` (protocol-encoding-driven model selection) + +There is no other place in the codebase that maps task types to model +or config classes. Adding a new task type requires updating both +mappings here and registering the new model module in +:mod:`bead.active_learning.models`. +""" + +from __future__ import annotations + +from typing import Final + +from bead.active_learning.models.base import ActiveLearningModel +from bead.active_learning.models.binary import BinaryModel +from bead.active_learning.models.categorical import CategoricalModel +from bead.active_learning.models.cloze import ClozeModel +from bead.active_learning.models.forced_choice import ForcedChoiceModel +from bead.active_learning.models.free_text import FreeTextModel +from bead.active_learning.models.magnitude import MagnitudeModel +from bead.active_learning.models.multi_select import MultiSelectModel +from bead.active_learning.models.ordinal_scale import OrdinalScaleModel +from bead.config.active_learning import ( + BinaryModelConfig, + CategoricalModelConfig, + ClozeModelConfig, + ForcedChoiceModelConfig, + FreeTextModelConfig, + MagnitudeModelConfig, + MultiSelectModelConfig, + OrdinalScaleModelConfig, +) +from bead.items.item_template import TaskType +from bead.protocol.encoding import ResponseEncoding +from bead.protocol.items import scale_type_to_task_type + +type ModelConfig = ( + BinaryModelConfig + | CategoricalModelConfig + | ClozeModelConfig + | ForcedChoiceModelConfig + | FreeTextModelConfig + | MagnitudeModelConfig + | MultiSelectModelConfig + | OrdinalScaleModelConfig +) +"""Union of every active-learning model-config class.""" + + +MODEL_CLASSES: Final[dict[TaskType, type[ActiveLearningModel]]] = { + "binary": BinaryModel, + "categorical": CategoricalModel, + "cloze": ClozeModel, + "forced_choice": ForcedChoiceModel, + "free_text": FreeTextModel, + "magnitude": MagnitudeModel, + "multi_select": MultiSelectModel, + "ordinal_scale": OrdinalScaleModel, +} +"""The canonical task-type → model-class mapping. + +Add a new task type by appending an entry here and a matching entry +in :data:`CONFIG_CLASSES`. Every keyed task type must be a +``TaskType`` literal (the ``"span_labeling"`` task type has no +active-learning model and is intentionally absent). +""" + + +CONFIG_CLASSES: Final[dict[TaskType, type[ModelConfig]]] = { + "binary": BinaryModelConfig, + "categorical": CategoricalModelConfig, + "cloze": ClozeModelConfig, + "forced_choice": ForcedChoiceModelConfig, + "free_text": FreeTextModelConfig, + "magnitude": MagnitudeModelConfig, + "multi_select": MultiSelectModelConfig, + "ordinal_scale": OrdinalScaleModelConfig, +} +"""The canonical task-type → config-class mapping.""" + + +def model_class_for_task_type(task_type: TaskType) -> type[ActiveLearningModel]: + """Return the model class registered for ``task_type``. + + Parameters + ---------- + task_type : TaskType + Task-type literal. + + Returns + ------- + type[ActiveLearningModel] + The registered subclass. + + Raises + ------ + KeyError + If ``task_type`` has no registered model (for example, + ``"span_labeling"``). + """ + return MODEL_CLASSES[task_type] + + +def config_class_for_task_type(task_type: TaskType) -> type[ModelConfig]: + """Return the config class registered for ``task_type``. + + Parameters + ---------- + task_type : TaskType + Task-type literal. + + Returns + ------- + type[ModelConfig] + The registered config class. + + Raises + ------ + KeyError + If ``task_type`` has no registered config. + """ + return CONFIG_CLASSES[task_type] + + +def model_class_for_encoding( + encoding: ResponseEncoding, +) -> type[ActiveLearningModel]: + """Pick the active-learning model class for a protocol encoding. + + Composes :func:`~bead.protocol.items.scale_type_to_task_type` with + :func:`model_class_for_task_type`. This is the canonical bridge + from a :class:`~bead.protocol.ResponseEncoding` to the model + class that should be trained on responses recorded under that + encoding. + + Parameters + ---------- + encoding : ResponseEncoding + Protocol-side response encoding. + + Returns + ------- + type[ActiveLearningModel] + The matching model class. + + Examples + -------- + >>> from bead.protocol import ResponseSpace, encode_response_space + >>> rs = ResponseSpace(options=("no", "yes"), is_ordered=False) + >>> enc = encode_response_space("dynamicity", rs) + >>> model_class_for_encoding(enc).__name__ + 'BinaryModel' + """ + return model_class_for_task_type(scale_type_to_task_type(encoding.scale_type)) + + +def config_class_for_encoding( + encoding: ResponseEncoding, +) -> type[ModelConfig]: + """Pick the active-learning config class for a protocol encoding. + + Parameters + ---------- + encoding : ResponseEncoding + Protocol-side response encoding. + + Returns + ------- + type[ModelConfig] + The matching config class. + """ + return config_class_for_task_type(scale_type_to_task_type(encoding.scale_type)) diff --git a/bead/cli/main.py b/bead/cli/main.py index 4166716..a9d7c8d 100644 --- a/bead/cli/main.py +++ b/bead/cli/main.py @@ -418,6 +418,7 @@ def _lazy_load(self, cmd_name: str) -> click.Command: "items": ("bead.cli.items", "items"), "lists": ("bead.cli.lists", "lists"), "models": ("bead.cli.models", "models"), + "protocol": ("bead.cli.protocol", "protocol"), "resources": ("bead.cli.resources", "resources"), "shell": ("bead.cli.shell", "shell"), "simulate": ("bead.cli.simulate", "simulate"), diff --git a/bead/cli/models.py b/bead/cli/models.py index 35c9fbb..5526d43 100644 --- a/bead/cli/models.py +++ b/bead/cli/models.py @@ -16,6 +16,11 @@ from rich.table import Table from bead.active_learning.config import MixedEffectsConfig +from bead.active_learning.models import ( + MODEL_CLASSES, + config_class_for_task_type, + model_class_for_task_type, +) from bead.cli.display import ( print_error, print_info, @@ -23,51 +28,10 @@ ) from bead.data.serialization import read_jsonlines from bead.items.item import Item +from bead.items.item_template import TaskType console = Console() -# Task type to model class mapping -TASK_TYPE_MODELS = { - "forced_choice": "bead.active_learning.models.forced_choice.ForcedChoiceModel", - "categorical": "bead.active_learning.models.categorical.CategoricalModel", - "binary": "bead.active_learning.models.binary.BinaryModel", - "multi_select": "bead.active_learning.models.multi_select.MultiSelectModel", - "ordinal_scale": "bead.active_learning.models.ordinal_scale.OrdinalScaleModel", - "magnitude": "bead.active_learning.models.magnitude.MagnitudeModel", - "free_text": "bead.active_learning.models.free_text.FreeTextModel", - "cloze": "bead.active_learning.models.cloze.ClozeModel", -} - -# Config classes for each task type -TASK_TYPE_CONFIGS = { - "forced_choice": "bead.config.active_learning.ForcedChoiceModelConfig", - "categorical": "bead.config.active_learning.CategoricalModelConfig", - "binary": "bead.config.active_learning.BinaryModelConfig", - "multi_select": "bead.config.active_learning.MultiSelectModelConfig", - "ordinal_scale": "bead.config.active_learning.OrdinalScaleModelConfig", - "magnitude": "bead.config.active_learning.MagnitudeModelConfig", - "free_text": "bead.config.active_learning.FreeTextModelConfig", - "cloze": "bead.config.active_learning.ClozeModelConfig", -} - - -def _import_class(module_path: str) -> type: - """Dynamically import a class from module path. - - Parameters - ---------- - module_path : str - Fully qualified path to class (e.g., 'bead.models.forced_choice.Model'). - - Returns - ------- - type - Imported class. - """ - module_name, class_name = module_path.rsplit(".", 1) - module = __import__(module_name, fromlist=[class_name]) - return getattr(module, class_name) - @click.group() def models() -> None: @@ -123,7 +87,7 @@ def models() -> None: @click.option( "--task-type", required=True, - type=click.Choice(list(TASK_TYPE_MODELS.keys())), + type=click.Choice(list(MODEL_CLASSES.keys())), help="Task type for model", ) @click.option( @@ -395,8 +359,8 @@ def train_model( mixed_effects_config = MixedEffectsConfig(mode=mode) # Import model class and config dynamically - model_class = _import_class(TASK_TYPE_MODELS[task_type]) - config_class = _import_class(TASK_TYPE_CONFIGS[task_type]) + model_class = model_class_for_task_type(cast(TaskType, task_type)) + config_class = config_class_for_task_type(cast(TaskType, task_type)) # Build model config config_dict = { @@ -581,22 +545,22 @@ def predict( "Model config missing 'task_type' field. " "This model may have been trained with an older version of bead." ) - print_info("Valid task types: " + ", ".join(TASK_TYPE_MODELS.keys())) + print_info("Valid task types: " + ", ".join(MODEL_CLASSES.keys())) ctx.exit(1) task_type = config_dict["task_type"] - if task_type not in TASK_TYPE_MODELS: + if task_type not in MODEL_CLASSES: print_error( f"Unknown task type '{task_type}' in model config. " - f"Valid types: {', '.join(TASK_TYPE_MODELS.keys())}" + f"Valid types: {', '.join(MODEL_CLASSES.keys())}" ) ctx.exit(1) print_success(f"Detected task type: {task_type}") # Import model class - model_class = _import_class(TASK_TYPE_MODELS[task_type]) - config_class = _import_class(TASK_TYPE_CONFIGS[task_type]) + model_class = model_class_for_task_type(cast(TaskType, task_type)) + config_class = config_class_for_task_type(cast(TaskType, task_type)) model_config = config_class(**config_dict) # Initialize model and load weights @@ -766,22 +730,22 @@ def predict_proba( "Model config missing 'task_type' field. " "This model may have been trained with an older version of bead." ) - print_info("Valid task types: " + ", ".join(TASK_TYPE_MODELS.keys())) + print_info("Valid task types: " + ", ".join(MODEL_CLASSES.keys())) ctx.exit(1) task_type = config_dict["task_type"] - if task_type not in TASK_TYPE_MODELS: + if task_type not in MODEL_CLASSES: print_error( f"Unknown task type '{task_type}' in model config. " - f"Valid types: {', '.join(TASK_TYPE_MODELS.keys())}" + f"Valid types: {', '.join(MODEL_CLASSES.keys())}" ) ctx.exit(1) print_success(f"Detected task type: {task_type}") # Import model class - model_class = _import_class(TASK_TYPE_MODELS[task_type]) - config_class = _import_class(TASK_TYPE_CONFIGS[task_type]) + model_class = model_class_for_task_type(cast(TaskType, task_type)) + config_class = config_class_for_task_type(cast(TaskType, task_type)) model_config = config_class(**config_dict) # Initialize model and load weights diff --git a/bead/cli/protocol.py b/bead/cli/protocol.py new file mode 100644 index 0000000..7ca84e8 --- /dev/null +++ b/bead/cli/protocol.py @@ -0,0 +1,212 @@ +"""CLI commands for the annotation-protocol layer. + +Exposes the :class:`~bead.config.protocol.ProtocolConfig`-driven +workflow as ``bead protocol`` subcommands. Every command operates on +a :class:`~bead.config.config.BeadConfig` loaded via +:func:`bead.config.loader.load_config`, so the same TOML / YAML +configuration drives both Python and CLI invocations. + +Subcommands: + +- ``validate`` reports a per-family summary of the configured protocol + and verifies that it materializes without errors. +- ``realize`` reads protocol contexts from a JSONL file and writes + realized questions (or full :class:`~bead.items.item.Item` objects + when ``--emit-items`` is given) to a JSONL output file. +- ``items`` renders the per-family + :class:`~bead.items.item_template.ItemTemplate` collection to JSONL + for downstream stages. +""" + +from __future__ import annotations + +from pathlib import Path + +import click + +from bead.cli.display import print_error, print_info, print_success +from bead.config.loader import load_config +from bead.data.serialization import read_jsonlines, write_jsonlines +from bead.items.cache import ModelOutputCache +from bead.items.item import Item +from bead.items.item_template import ItemTemplate +from bead.protocol import ( + AnnotationProtocol, + ProtocolContext, + family_to_item_template, + realize_protocol_to_items, +) + + +def _load_protocol(config_path: Path | None, profile: str) -> AnnotationProtocol: + """Load a :class:`BeadConfig` and materialize its protocol.""" + config = load_config(config_path=config_path, profile=profile) + cache = ModelOutputCache( + cache_dir=config.paths.cache_dir / "models", + backend="filesystem", + ) + return config.protocol.build(cache=cache) + + +@click.group() +def protocol() -> None: + r"""Annotation-protocol commands. + + Drive the :class:`~bead.protocol.AnnotationProtocol` declared in + the project's BeadConfig (``protocol`` section): validate the + configuration, realize prompts for a batch of contexts, and emit + the per-family ItemTemplate collection. + + \b + Examples: + # Verify a protocol declaration in bead.toml + $ bead protocol validate + + # Realize prompts for all contexts in contexts.jsonl + $ bead protocol realize contexts.jsonl realizations.jsonl + + # Emit ItemTemplates for downstream item construction + $ bead protocol items --judgment-type acceptability templates.jsonl + """ + + +@protocol.command() +@click.option( + "--config-file", + type=click.Path(exists=True, path_type=Path), + help="Path to a project config file (defaults to bead.toml).", +) +@click.option("--profile", default="default", help="Configuration profile.") +def validate(config_file: Path | None, profile: str) -> None: + """Validate the protocol configuration and report its families. + + Loads the configured :class:`AnnotationProtocol`, prints a one-line + summary per family (anchor name, scale type, number of options, + declared dependencies), and exits non-zero on any construction + error. + """ + try: + proto = _load_protocol(config_file, profile) + except Exception as exc: # noqa: BLE001 + print_error(f"Protocol failed to materialize: {exc}") + raise SystemExit(1) from exc + + print_success(f"Protocol {proto.name!r}: {len(proto)} families") + for family in proto.families: + anchor = family.anchor + rs = anchor.response_space + scale = "ordinal" if rs.is_ordered else "binary" if len(rs) == 2 else "nominal" + deps = ", ".join(family.depends_on) if family.depends_on else "(none)" + print_info( + f" {family.name:20s} scale={scale:8s} " + f"n_options={len(rs):2d} depends_on={deps}" + ) + + +@protocol.command() +@click.argument("contexts_file", type=click.Path(exists=True, path_type=Path)) +@click.argument("output_file", type=click.Path(path_type=Path)) +@click.option( + "--config-file", + type=click.Path(exists=True, path_type=Path), + help="Path to a project config file (defaults to bead.toml).", +) +@click.option("--profile", default="default", help="Configuration profile.") +@click.option( + "--emit-items", + is_flag=True, + help=( + "Emit full Item records bound to per-family ItemTemplates " + "instead of bare QuestionRealizations." + ), +) +@click.option( + "--judgment-type", + default="acceptability", + help=("Judgment type for emitted ItemTemplates (used when --emit-items is set)."), +) +def realize( + contexts_file: Path, + output_file: Path, + config_file: Path | None, + profile: str, + emit_items: bool, + judgment_type: str, +) -> None: + """Realize protocol questions for every context in CONTEXTS_FILE. + + CONTEXTS_FILE is a JSONL file of :class:`ProtocolContext` records + (one per line). OUTPUT_FILE will be written as JSONL with one + record per realized question (skipping non-applicable families). + """ + proto = _load_protocol(config_file, profile) + if len(proto) == 0: + print_error( + "Configured protocol is empty; nothing to realize. Add " + "families to the [protocol.families] section." + ) + raise SystemExit(1) + + contexts = read_jsonlines(contexts_file, ProtocolContext) + + if emit_items: + items: list[Item] = [] + for ctx in contexts: + for _realization, item in realize_protocol_to_items( + proto, + ctx, + judgment_type=judgment_type, # type: ignore[arg-type] + ): + items.append(item) + write_jsonlines(items, output_file) + print_success( + f"Wrote {len(items)} Items from {len(contexts)} contexts to {output_file}" + ) + return + + realizations = [] + for ctx in contexts: + realizations.extend(proto.realize_all(ctx)) + write_jsonlines(realizations, output_file) + print_success( + f"Wrote {len(realizations)} realizations from {len(contexts)} " + f"contexts to {output_file}" + ) + + +@protocol.command() +@click.argument("output_file", type=click.Path(path_type=Path)) +@click.option( + "--config-file", + type=click.Path(exists=True, path_type=Path), + help="Path to a project config file (defaults to bead.toml).", +) +@click.option("--profile", default="default", help="Configuration profile.") +@click.option( + "--judgment-type", + default="acceptability", + help="Judgment type assigned to every ItemTemplate.", +) +def items( + output_file: Path, + config_file: Path | None, + profile: str, + judgment_type: str, +) -> None: + """Emit per-family ItemTemplates as JSONL. + + Builds one :class:`~bead.items.item_template.ItemTemplate` per + family in the configured protocol and writes them to OUTPUT_FILE + as JSONL. The resulting file feeds Stage 3 (item construction). + """ + proto = _load_protocol(config_file, profile) + if len(proto) == 0: + print_error("Configured protocol is empty; no templates to emit.") + raise SystemExit(1) + + templates: list[ItemTemplate] = [ + family_to_item_template(family, judgment_type=judgment_type) # type: ignore[arg-type] + for family in proto.families + ] + write_jsonlines(templates, output_file) + print_success(f"Wrote {len(templates)} ItemTemplates to {output_file}") diff --git a/bead/cli/training.py b/bead/cli/training.py index 7167004..377119f 100644 --- a/bead/cli/training.py +++ b/bead/cli/training.py @@ -18,13 +18,17 @@ from sklearn.metrics import accuracy_score, precision_recall_fscore_support from sklearn.model_selection import KFold -from bead.cli.models import _import_class # type: ignore[attr-defined] +from bead.active_learning.models import ( + config_class_for_task_type, + model_class_for_task_type, +) from bead.cli.utils import print_error, print_info, print_success from bead.data.base import JsonValue from bead.data.serialization import read_jsonlines from bead.data_collection.jatos import JATOSDataCollector from bead.evaluation.interannotator import InterAnnotatorMetrics from bead.items.item import Item +from bead.items.item_template import TaskType console = Console() @@ -330,9 +334,7 @@ def evaluate( ctx.exit(1) # Load model - model_class_name = f"{task_type.title().replace('_', '')}Model" - model_module = f"bead.active_learning.models.{task_type}" - model_class = _import_class(f"{model_module}.{model_class_name}") + model_class = model_class_for_task_type(cast(TaskType, task_type)) model_instance = model_class.load(model_dir) print_success(f"Loaded model from {model_dir}") @@ -553,13 +555,9 @@ def cross_validate( ctx.exit(1) # Import model and config classes - model_class_name = f"{task_type.title().replace('_', '')}Model" - config_class_name = f"{task_type.title().replace('_', '')}ModelConfig" - model_module = f"bead.active_learning.models.{task_type}" - config_module = "bead.config.active_learning" - model_class = _import_class(f"{model_module}.{model_class_name}") - config_class = _import_class(f"{config_module}.{config_class_name}") + model_class = model_class_for_task_type(cast(TaskType, task_type)) + config_class = config_class_for_task_type(cast(TaskType, task_type)) # Create cross-validator cv = KFold(n_splits=k_folds, shuffle=True, random_state=random_seed) @@ -781,13 +779,9 @@ def learning_curve( ctx.exit(1) # Import model and config classes - model_class_name = f"{task_type.title().replace('_', '')}Model" - config_class_name = f"{task_type.title().replace('_', '')}ModelConfig" - model_module = f"bead.active_learning.models.{task_type}" - config_module = "bead.config.active_learning" - model_class = _import_class(f"{model_module}.{model_class_name}") - config_class = _import_class(f"{config_module}.{config_class_name}") + model_class = model_class_for_task_type(cast(TaskType, task_type)) + config_class = config_class_for_task_type(cast(TaskType, task_type)) # Parse train sizes sizes = [float(s.strip()) for s in train_sizes.split(",")] diff --git a/bead/config/__init__.py b/bead/config/__init__.py index 65ce36b..20efdd1 100644 --- a/bead/config/__init__.py +++ b/bead/config/__init__.py @@ -25,6 +25,14 @@ get_profile, list_profiles, ) +from bead.config.protocol import ( + AnchorSpec, + DriftConfig, + FamilySpec, + ProtocolConfig, + RealizationKind, + TemplateVariantSpec, +) from bead.config.resources import ResourceConfig from bead.config.serialization import save_yaml, to_yaml from bead.config.template import SlotStrategyConfig, TemplateConfig @@ -44,6 +52,13 @@ "DeploymentConfig", "ActiveLearningConfig", "LoggingConfig", + "ProtocolConfig", + # protocol sub-specs + "AnchorSpec", + "TemplateVariantSpec", + "FamilySpec", + "DriftConfig", + "RealizationKind", # defaults "DEFAULT_CONFIG", "get_default_config", diff --git a/bead/config/config.py b/bead/config/config.py index f06f05d..cc5b93a 100644 --- a/bead/config/config.py +++ b/bead/config/config.py @@ -12,6 +12,7 @@ from bead.config.list import ListConfig from bead.config.logging import LoggingConfig from bead.config.paths import PathsConfig +from bead.config.protocol import ProtocolConfig from bead.config.resources import ResourceConfig from bead.config.template import TemplateConfig @@ -48,6 +49,10 @@ def _default_logging() -> LoggingConfig: return LoggingConfig() +def _default_protocol() -> ProtocolConfig: + return ProtocolConfig() + + class BeadConfig(dx.Model): """Main configuration for the bead package. @@ -71,6 +76,8 @@ class BeadConfig(dx.Model): Active learning configuration. logging : LoggingConfig Logging configuration. + protocol : ProtocolConfig + Annotation-protocol configuration. """ profile: str = "default" @@ -86,6 +93,7 @@ class BeadConfig(dx.Model): default_factory=_default_active_learning ) logging: dx.Embed[LoggingConfig] = dx.field(default_factory=_default_logging) + protocol: dx.Embed[ProtocolConfig] = dx.field(default_factory=_default_protocol) def to_dict(self) -> dict[str, Any]: """Render the configuration as a plain ``dict``.""" diff --git a/bead/config/protocol.py b/bead/config/protocol.py new file mode 100644 index 0000000..94157cb --- /dev/null +++ b/bead/config/protocol.py @@ -0,0 +1,512 @@ +"""Configuration for the annotation-protocol layer. + +Declares :class:`ProtocolConfig` (the top-level stage config that +plugs into :class:`~bead.config.config.BeadConfig`) along with the +declarative specs (:class:`AnchorSpec`, :class:`TemplateVariantSpec`, +:class:`FamilySpec`, :class:`DriftConfig`) that materialize into +runtime :class:`~bead.protocol.SemanticAnchor`, +:class:`~bead.protocol.QuestionFamily`, and +:class:`~bead.protocol.AnnotationProtocol` objects. + +Configuration is *declarative*: anchors, drift thresholds, realization +strategies, and protocol composition are written in YAML or TOML, and +:meth:`ProtocolConfig.build` produces the live objects. Runtime-only +parameters (LM clients, embedding adapters, output caches) are passed +to :meth:`build` rather than stored in the config. + +Predicates are referenced *by registered name*; callables cannot be +serialized. Register predicates in the +:mod:`~bead.protocol.context` registry at import time, then refer to +them by name from a :class:`FamilySpec` or :class:`TemplateVariantSpec`. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +import didactic.api as dx + +from bead.data.base import BeadBaseModel +from bead.protocol.anchor import ResponseSpace, SemanticAnchor, SemanticPoles +from bead.protocol.context import get_context_predicate +from bead.protocol.drift import ( + DriftGuard, + EmbeddingDriftValidator, + PerplexityDriftValidator, + StructuralDriftValidator, +) +from bead.protocol.family import AnnotationProtocol, QuestionFamily +from bead.protocol.realization import ( + ContextualTemplateRealization, + LMClient, + LMRealization, + RealizationStrategy, + TemplateRealization, + TemplateVariant, +) + +if TYPE_CHECKING: + from bead.items.cache import ModelOutputCache + from bead.protocol.drift import EmbeddingAdapter, PerplexityAdapter + + +RealizationKind = Literal["template", "contextual", "lm"] +"""Discriminator for which realization strategy a family uses.""" + + +class TemplateVariantSpec(BeadBaseModel): + """Declarative form of :class:`TemplateVariant` for config files. + + Attributes + ---------- + template : str + Question template, possibly containing ``[[label]]`` references. + condition_name : str + Name of a registered context predicate. Looked up via + :func:`bead.protocol.context.get_context_predicate` at build + time. Defaults to ``"always"``. + priority : int + Higher-priority variants are tried first. Defaults to ``0``. + description : str + Human-readable description. Defaults to empty. + """ + + template: str + condition_name: str = "always" + priority: int = 0 + description: str = "" + + def build(self) -> TemplateVariant: + """Build a :class:`TemplateVariant` from this spec. + + Returns + ------- + TemplateVariant + Live variant with the named predicate resolved. + + Raises + ------ + KeyError + If ``condition_name`` is not registered. + """ + return TemplateVariant( + template=self.template, + condition=get_context_predicate(self.condition_name), + priority=self.priority, + description=self.description, + ) + + +class AnchorSpec(BeadBaseModel): + """Declarative form of :class:`SemanticAnchor` for config files. + + Pole labels are flattened to two string fields rather than a + nested :class:`SemanticPoles`; ``build()`` constructs the embedded + model. + + Attributes + ---------- + name : str + Short identifier. + target_property : str + The property being measured. + canonical_prompt : str + Reference phrasing. + options : tuple[str, ...] + Ordered response options. + is_ordered : bool + Whether the response space is ordinal. Defaults to ``True``. + semantic_pole_low : str | None + Low-pole label, when ordered. Defaults to ``None``. + semantic_pole_high : str | None + High-pole label, when ordered. Defaults to ``None``. + required_span_labels : frozenset[str] + Span labels every realization must reference. Defaults to + the empty set. + required_keywords : frozenset[str] + Keywords every realization must contain. Defaults to the + empty set. + embedding_center : tuple[float, ...] | None + Pre-computed canonical-prompt embedding. Defaults to ``None``. + max_drift : float + Maximum cosine distance for embedding drift. Defaults to + ``0.3``. + description : str + Human-readable description. + """ + + name: str + target_property: str + canonical_prompt: str + options: tuple[str, ...] + is_ordered: bool = True + semantic_pole_low: str | None = None + semantic_pole_high: str | None = None + required_span_labels: frozenset[str] = dx.field(default_factory=frozenset) + required_keywords: frozenset[str] = dx.field(default_factory=frozenset) + embedding_center: tuple[float, ...] | None = None + max_drift: float = 0.3 + description: str = "" + + def build(self) -> SemanticAnchor: + """Build a :class:`SemanticAnchor` from this spec. + + Returns + ------- + SemanticAnchor + Live anchor. + + Raises + ------ + ValueError + If exactly one of ``semantic_pole_low`` and + ``semantic_pole_high`` is supplied. + """ + if (self.semantic_pole_low is None) != (self.semantic_pole_high is None): + raise ValueError( + f"AnchorSpec {self.name!r} sets only one pole; both " + f"semantic_pole_low and semantic_pole_high must be set " + f"or both must be None" + ) + poles: SemanticPoles | None = None + if self.semantic_pole_low is not None and self.semantic_pole_high is not None: + poles = SemanticPoles( + low=self.semantic_pole_low, + high=self.semantic_pole_high, + ) + space = ResponseSpace( + options=self.options, + is_ordered=self.is_ordered, + semantic_poles=poles, + ) + return SemanticAnchor( + name=self.name, + target_property=self.target_property, + canonical_prompt=self.canonical_prompt, + response_space=space, + required_span_labels=self.required_span_labels, + required_keywords=self.required_keywords, + embedding_center=self.embedding_center, + max_drift=self.max_drift, + description=self.description, + ) + + +class DriftConfig(BeadBaseModel): + """Configuration for the drift guard applied to a protocol. + + Every realized prompt runs through one shared + :class:`~bead.protocol.DriftGuard` configured by this section. + + Attributes + ---------- + min_length : int + Minimum non-whitespace length for the structural validator. + Defaults to ``15``. + require_question_mark : bool + Whether a trailing ``?`` is required. Defaults to ``True``. + keyword_case_sensitive : bool + Whether structural keyword checks are case-sensitive. Defaults + to ``False``. + embedding_max_distance : float | None + Cosine-distance ceiling for the embedding validator. ``None`` + defers to each anchor's ``max_drift``. Defaults to ``None``. + enable_embedding : bool + Whether to add an :class:`EmbeddingDriftValidator`. Requires an + embedding adapter at build time. Defaults to ``False``. + enable_perplexity : bool + Whether to add a :class:`PerplexityDriftValidator`. Requires a + perplexity adapter at build time. Defaults to ``False``. + max_perplexity : float + Perplexity ceiling for the perplexity validator. Defaults to + ``100.0``. + """ + + min_length: int = 15 + require_question_mark: bool = True + keyword_case_sensitive: bool = False + embedding_max_distance: float | None = None + enable_embedding: bool = False + enable_perplexity: bool = False + max_perplexity: float = 100.0 + + def build( + self, + *, + embedding_adapter: EmbeddingAdapter | None = None, + perplexity_adapter: PerplexityAdapter | None = None, + ) -> DriftGuard: + """Build a :class:`DriftGuard` with structural + optional checks. + + Parameters + ---------- + embedding_adapter : EmbeddingAdapter | None, optional + Required when :attr:`enable_embedding` is ``True``. Defaults + to ``None``. + perplexity_adapter : PerplexityAdapter | None, optional + Required when :attr:`enable_perplexity` is ``True``. Defaults + to ``None``. + + Returns + ------- + DriftGuard + Live composite drift validator. + + Raises + ------ + ValueError + If a validator is enabled but its adapter was not supplied. + """ + guard = DriftGuard() + guard.add( + StructuralDriftValidator( + min_length=self.min_length, + require_question_mark=self.require_question_mark, + keyword_case_sensitive=self.keyword_case_sensitive, + ) + ) + if self.enable_embedding: + if embedding_adapter is None: + raise ValueError( + "drift.enable_embedding=True but no " + "embedding_adapter was supplied to build()" + ) + guard.add( + EmbeddingDriftValidator( + embedding_adapter, + max_distance=self.embedding_max_distance, + ) + ) + if self.enable_perplexity: + if perplexity_adapter is None: + raise ValueError( + "drift.enable_perplexity=True but no " + "perplexity_adapter was supplied to build()" + ) + guard.add( + PerplexityDriftValidator( + perplexity_adapter, + max_perplexity=self.max_perplexity, + ) + ) + return guard + + +class FamilySpec(BeadBaseModel): + """Declarative form of :class:`QuestionFamily` for config files. + + Attributes + ---------- + anchor : AnchorSpec + The anchor declaration. Built into a + :class:`SemanticAnchor` at build time. + realization_kind : RealizationKind + Which realization strategy to use. + template : str | None + Used when ``realization_kind="template"``. ``None`` defers to + the anchor's canonical prompt. + variants : tuple[TemplateVariantSpec, ...] + Used when ``realization_kind="contextual"``. Empty tuple is + invalid for that kind. + fallback : str | None + Fallback template used when no variant matches. ``None`` + defers to the anchor's canonical prompt. + condition_name : str + Registered predicate name controlling family applicability. + Defaults to ``"always"``. + depends_on : tuple[str, ...] + Names of anchors whose responses must precede this family in + the protocol. Defaults to the empty tuple. + fallback_on_drift : bool + Whether to fall back to the canonical prompt on drift failure. + Defaults to ``True``. + """ + + anchor: dx.Embed[AnchorSpec] + realization_kind: RealizationKind = "template" + template: str | None = None + variants: tuple[dx.Embed[TemplateVariantSpec], ...] = () + fallback: str | None = None + condition_name: str = "always" + depends_on: tuple[str, ...] = () + fallback_on_drift: bool = True + + def _build_realization( + self, + *, + lm_client: LMClient | None, + lm_model_name: str, + cache: ModelOutputCache | None, + lm_temperature: float, + lm_max_tokens: int, + ) -> RealizationStrategy: + """Construct the realization strategy named by ``realization_kind``.""" + if self.realization_kind == "template": + return TemplateRealization(template=self.template) + if self.realization_kind == "contextual": + if not self.variants: + raise ValueError( + f"FamilySpec {self.anchor.name!r} has " + f"realization_kind='contextual' but variants is empty" + ) + return ContextualTemplateRealization( + variants=tuple(v.build() for v in self.variants), + fallback=self.fallback, + ) + if self.realization_kind == "lm": + if lm_client is None: + raise ValueError( + f"FamilySpec {self.anchor.name!r} has " + f"realization_kind='lm' but no lm_client was " + f"supplied to ProtocolConfig.build()" + ) + return LMRealization( + lm_client, + model_name=lm_model_name, + cache=cache, + temperature=lm_temperature, + max_tokens=lm_max_tokens, + ) + raise ValueError(f"Unknown realization_kind: {self.realization_kind!r}") + + def build( + self, + *, + drift_guard: DriftGuard, + lm_client: LMClient | None, + lm_model_name: str, + cache: ModelOutputCache | None, + lm_temperature: float, + lm_max_tokens: int, + ) -> QuestionFamily: + """Build a :class:`QuestionFamily` from this spec. + + Parameters + ---------- + drift_guard : DriftGuard + Shared drift guard for the protocol. + lm_client : LMClient | None + LM backend; required when ``realization_kind == "lm"``. + lm_model_name : str + Cache-key prefix for LM realizations. + cache : ModelOutputCache | None + Output cache for LM realizations. + lm_temperature : float + Sampling temperature for LM realizations. + lm_max_tokens : int + Maximum response length for LM realizations. + + Returns + ------- + QuestionFamily + Live family. + """ + return QuestionFamily( + anchor=self.anchor.build(), + realization=self._build_realization( + lm_client=lm_client, + lm_model_name=lm_model_name, + cache=cache, + lm_temperature=lm_temperature, + lm_max_tokens=lm_max_tokens, + ), + drift_guard=drift_guard, + condition=get_context_predicate(self.condition_name), + depends_on=self.depends_on, + fallback_on_drift=self.fallback_on_drift, + ) + + +def _default_drift() -> DriftConfig: + return DriftConfig() + + +class ProtocolConfig(BeadBaseModel): + """Top-level annotation-protocol stage configuration. + + Plugs into :class:`~bead.config.config.BeadConfig` as the + ``protocol`` field. Declares the families, drift settings, and + LM defaults for an annotation protocol that can be loaded from + YAML or TOML and materialized via :meth:`build`. + + Attributes + ---------- + name : str + Descriptive protocol name. Defaults to empty. + families : tuple[FamilySpec, ...] + Declarative family specs in protocol order. Defaults to the + empty tuple. + drift : DriftConfig + Drift-guard configuration shared by all families. Defaults to + a structural-only guard with the standard defaults. + lm_model_name : str + Cache-key prefix for LM realizations. Used when any family + has ``realization_kind="lm"``. Defaults to empty (forces the + caller to set it explicitly when LM realizations are used). + lm_temperature : float + Default sampling temperature for LM realizations. Defaults to + ``0.3``. + lm_max_tokens : int + Default maximum response length for LM realizations. Defaults + to ``200``. + """ + + name: str = "" + families: tuple[dx.Embed[FamilySpec], ...] = () + drift: dx.Embed[DriftConfig] = dx.field(default_factory=_default_drift) + lm_model_name: str = "" + lm_temperature: float = 0.3 + lm_max_tokens: int = 200 + + def build( + self, + *, + lm_client: LMClient | None = None, + cache: ModelOutputCache | None = None, + embedding_adapter: EmbeddingAdapter | None = None, + perplexity_adapter: PerplexityAdapter | None = None, + ) -> AnnotationProtocol: + """Materialize the configured protocol. + + Parameters + ---------- + lm_client : LMClient | None, optional + LM backend, required if any family declares + ``realization_kind="lm"``. Defaults to ``None``. + cache : ModelOutputCache | None, optional + Output cache for LM realizations. Defaults to ``None``. + embedding_adapter : EmbeddingAdapter | None, optional + Required when ``drift.enable_embedding=True``. Defaults to + ``None``. + perplexity_adapter : PerplexityAdapter | None, optional + Required when ``drift.enable_perplexity=True``. Defaults to + ``None``. + + Returns + ------- + AnnotationProtocol + Live protocol with every family materialized in declared + order. + + Raises + ------ + ValueError + If a required runtime dependency is not supplied or a + family declares an unknown realization kind. + """ + guard = self.drift.build( + embedding_adapter=embedding_adapter, + perplexity_adapter=perplexity_adapter, + ) + families = [ + family_spec.build( + drift_guard=guard, + lm_client=lm_client, + lm_model_name=self.lm_model_name, + cache=cache, + lm_temperature=self.lm_temperature, + lm_max_tokens=self.lm_max_tokens, + ) + for family_spec in self.families + ] + return AnnotationProtocol(families=families, name=self.name) diff --git a/bead/data_collection/__init__.py b/bead/data_collection/__init__.py index f4867bc..31cbf7b 100644 --- a/bead/data_collection/__init__.py +++ b/bead/data_collection/__init__.py @@ -3,9 +3,11 @@ from bead.data_collection.jatos import JATOSDataCollector from bead.data_collection.merger import DataMerger from bead.data_collection.prolific import ProlificDataCollector +from bead.data_collection.records import jatos_results_to_annotation_records __all__ = [ + "DataMerger", "JATOSDataCollector", "ProlificDataCollector", - "DataMerger", + "jatos_results_to_annotation_records", ] diff --git a/bead/data_collection/records.py b/bead/data_collection/records.py new file mode 100644 index 0000000..49c28ee --- /dev/null +++ b/bead/data_collection/records.py @@ -0,0 +1,137 @@ +"""Bridge from JATOS results to bead annotation records. + +JATOS returns experimental results as nested JSON: each study run +contains a ``data`` array of trial objects, each carrying the +metadata serialized by +:func:`bead.deployment.jspsych.trials._serialize_item_metadata` and a +jsPsych ``response`` field. + +This module is the single canonical conversion from that +representation into :class:`~bead.evaluation.AnnotationRecord` +instances, the input shape consumed by every reliability, +inter-annotator-agreement, and conditional-observation check in bead. +There is no other path from raw JATOS output to bead records. +""" + +from __future__ import annotations + +from collections.abc import Iterable, Mapping +from typing import Any + +from bead.data.base import JsonValue +from bead.evaluation.reliability import AnnotationRecord + + +def _coerce_response_label(response: JsonValue) -> str: + """Normalize a jsPsych ``response`` value to a string label. + + Parameters + ---------- + response : JsonValue + The raw value emitted by jsPsych. For binary, categorical, and + forced-choice tasks this is already a string label; for + ordinal, magnitude, and similar numeric tasks it is an int or + float; jsPsych may also wrap the response in an object with a + ``"response"`` key for some plugin variants. + + Returns + ------- + str + String form suitable for :class:`AnnotationRecord.response_label`. + """ + if isinstance(response, str): + return response + if isinstance(response, bool): + return "true" if response else "false" + if isinstance(response, int | float): + return str(response) + if isinstance(response, dict) and "response" in response: + inner = response["response"] + if isinstance(inner, str): + return inner + return str(inner) + return str(response) + + +def _annotator_id( + result: Mapping[str, Any], + *, + annotator_id_key: str, +) -> str | None: + """Extract the annotator id from a JATOS result envelope. + + Looks first in the URL query parameters (``urlQueryParameters``) + for the configured key (typically ``"PROLIFIC_PID"``), then falls + back to the JATOS-assigned ``worker_id``. + """ + url_params = result.get("urlQueryParameters") + if isinstance(url_params, Mapping) and annotator_id_key in url_params: + candidate = url_params[annotator_id_key] + if isinstance(candidate, str) and candidate: + return candidate + worker_id = result.get("worker_id") + if isinstance(worker_id, str) and worker_id: + return worker_id + if isinstance(worker_id, int): + return str(worker_id) + return None + + +def jatos_results_to_annotation_records( + results: Iterable[Mapping[str, Any]], + *, + annotator_id_key: str = "PROLIFIC_PID", +) -> tuple[AnnotationRecord, ...]: + """Convert a sequence of JATOS results to :class:`AnnotationRecord`s. + + Each JATOS result is expected to be the dict shape returned by + :class:`~bead.data_collection.JATOSDataCollector` (a study run + with a ``data`` field carrying jsPsych trial dicts). Trials that + lack ``item_id`` or ``template_name`` are silently skipped, since + they correspond to non-question trials such as instructions or + consent. + + Parameters + ---------- + results : Iterable[Mapping[str, Any]] + JATOS result envelopes. + annotator_id_key : str, optional + Query-parameter key carrying the annotator identifier. + Defaults to ``"PROLIFIC_PID"``. + + Returns + ------- + tuple[AnnotationRecord, ...] + One record per (item, template_name) trial. Records appear in + result-then-trial order. Trials missing required fields are + skipped. + """ + records: list[AnnotationRecord] = [] + for result in results: + annotator_id = _annotator_id(result, annotator_id_key=annotator_id_key) + if annotator_id is None: + continue + trials = result.get("data") + if not isinstance(trials, list): + continue + for trial in trials: + if not isinstance(trial, Mapping): + continue + item_id = trial.get("item_id") + question_name = trial.get("template_name") + response = trial.get("response") + if ( + not isinstance(item_id, str) + or not isinstance(question_name, str) + or response is None + ): + continue + records.append( + AnnotationRecord( + annotator_id=annotator_id, + item_id=item_id, + question_name=question_name, + response_label=_coerce_response_label(response), + ) + ) + return tuple(records) diff --git a/bead/deployment/jspsych/trials.py b/bead/deployment/jspsych/trials.py index 69c38f4..631bc82 100644 --- a/bead/deployment/jspsych/trials.py +++ b/bead/deployment/jspsych/trials.py @@ -8,7 +8,6 @@ from __future__ import annotations -import re from dataclasses import dataclass from bead.data.base import JsonValue @@ -24,6 +23,7 @@ from bead.items.item import Item from bead.items.item_template import ItemTemplate from bead.items.spans import Span +from bead.labels import parse_label_refs from bead.transforms.base import TransformContext, TransformRegistry @@ -1090,62 +1090,6 @@ def _generate_span_stimulus_html( # prompt span reference resolution -_SPAN_REF_PATTERN = re.compile(r"\[\[([^\]:|]+?)(?::([^\]|]+?))?(?:\|([^\]]+?))?\]\]") - - -@dataclass(frozen=True) -class _SpanReference: - """A parsed ``[[label]]``, ``[[label:text]]``, or ``[[label|transform]]``.""" - - label: str - display_text: str | None - transforms: list[str] - match_start: int - match_end: int - - -def _parse_prompt_references(prompt: str) -> list[_SpanReference]: - """Parse span references from a prompt string. - - Supports three syntax forms (and their combinations): - - - ``[[label]]`` — auto-fill display text from span tokens - - ``[[label:text]]`` — explicit display text - - ``[[label|transform1|transform2]]`` — auto-fill with transforms - - ``[[label:text|transform1]]`` — explicit text with transforms - - Parameters - ---------- - prompt : str - Prompt string potentially containing span references. - - Returns - ------- - list[_SpanReference] - Parsed references in order of appearance. - """ - refs: list[_SpanReference] = [] - - for m in _SPAN_REF_PATTERN.finditer(prompt): - raw_transforms = m.group(3) - transform_names = ( - [t.strip() for t in raw_transforms.split("|") if t.strip()] - if raw_transforms - else [] - ) - - refs.append( - _SpanReference( - label=m.group(1).strip(), - display_text=m.group(2).strip() if m.group(2) else None, - transforms=transform_names, - match_start=m.start(), - match_end=m.end(), - ) - ) - - return refs - def _auto_fill_span_text(label: str, item: Item) -> str: """Reconstruct display text from a span's tokens. @@ -1240,7 +1184,7 @@ def _resolve_prompt_references( KeyError If a transform name is not found in the registry. """ - refs = _parse_prompt_references(prompt) + refs = parse_label_refs(prompt) if not refs: return prompt @@ -1264,7 +1208,7 @@ def _resolve_prompt_references( # apply transforms if requested and a registry is available if ref.transforms and transform_registry is not None: context = _build_transform_context(ref.label, item) - pipeline = transform_registry.resolve_pipeline(ref.transforms) + pipeline = transform_registry.resolve_pipeline(list(ref.transforms)) display = pipeline(display, context) light = color_map.light_by_label.get(ref.label, "#BBDEFB") @@ -1275,7 +1219,7 @@ def _resolve_prompt_references( f'' f"{ref.label}" ) - result = result[: ref.match_start] + html + result[ref.match_end :] + result = result[: ref.start_offset] + html + result[ref.end_offset :] return result diff --git a/bead/deployment/protocol_trials.py b/bead/deployment/protocol_trials.py new file mode 100644 index 0000000..60724bd --- /dev/null +++ b/bead/deployment/protocol_trials.py @@ -0,0 +1,108 @@ +"""Bridge from the protocol layer to jsPsych deployment. + +End-to-end path from a configured :class:`AnnotationProtocol` and a +sequence of :class:`~bead.protocol.ProtocolContext` records to a list +of jsPsych trial objects ready for batch deployment. + +This is the canonical bridge to deployment. There is no other way to +materialize a protocol-defined experiment. +""" + +from __future__ import annotations + +from collections.abc import Iterable + +from bead.data.base import JsonValue +from bead.deployment.jspsych.config import ( + ChoiceConfig, + ExperimentConfig, + RatingScaleConfig, +) +from bead.deployment.jspsych.trials import create_trial +from bead.items.item_template import ItemTemplate, JudgmentType, PresentationSpec +from bead.protocol.context import ProtocolContext +from bead.protocol.family import AnnotationProtocol +from bead.protocol.items import ( + protocol_to_item_templates, + realize_protocol_to_items, +) + + +def protocol_to_jspsych_trials( + protocol: AnnotationProtocol, + contexts: Iterable[ProtocolContext], + *, + experiment_config: ExperimentConfig, + judgment_type: JudgmentType, + presentation_spec: PresentationSpec | None = None, + rating_config: RatingScaleConfig | None = None, + choice_config: ChoiceConfig | None = None, +) -> list[dict[str, JsonValue]]: + """Materialize an entire protocol as a flat list of jsPsych trials. + + Each :class:`ProtocolContext` is realized through every applicable + :class:`~bead.protocol.QuestionFamily`. Each resulting realization + is packaged as an :class:`~bead.items.item.Item` bound to the + family's :class:`ItemTemplate` and turned into a jsPsych trial + via :func:`bead.deployment.jspsych.trials.create_trial`. Trials + are returned in + ``(context_order, family_order)`` order: every realized question + for the first context comes first, then the second context, and + so on. + + Parameters + ---------- + protocol : AnnotationProtocol + Configured protocol whose families to realize. + contexts : Iterable[ProtocolContext] + Contexts to realize, one per annotation target. + experiment_config : ExperimentConfig + Shared experiment configuration applied to every trial. + judgment_type : JudgmentType + Common judgment type assigned to every per-family + :class:`ItemTemplate`. + presentation_spec : PresentationSpec | None, optional + Common presentation spec across families. Defaults to a fresh + :class:`PresentationSpec` per template. + rating_config : RatingScaleConfig | None, optional + Configuration for rating-scale trials (ordinal task type). + choice_config : ChoiceConfig | None, optional + Configuration for choice trials (binary, categorical, or + forced-choice task types). + + Returns + ------- + list[dict[str, JsonValue]] + Flat list of jsPsych trial dicts in ``trial_number`` order. + """ + templates: dict[str, ItemTemplate] = protocol_to_item_templates( + protocol, + judgment_type=judgment_type, + presentation_spec=presentation_spec, + ) + + template_by_id = {t.id: t for t in templates.values()} + + trials: list[dict[str, JsonValue]] = [] + trial_number = 0 + for ctx in contexts: + for _realization, item in realize_protocol_to_items( + protocol, + ctx, + judgment_type=judgment_type, + item_templates=templates, + presentation_spec=presentation_spec, + ): + trials.append( + create_trial( + item=item, + template=template_by_id[item.item_template_id], + experiment_config=experiment_config, + trial_number=trial_number, + rating_config=rating_config, + choice_config=choice_config, + ) + ) + trial_number += 1 + + return trials diff --git a/bead/evaluation/__init__.py b/bead/evaluation/__init__.py index 3ac6cf0..0a68007 100644 --- a/bead/evaluation/__init__.py +++ b/bead/evaluation/__init__.py @@ -1,13 +1,24 @@ """Evaluation module for model and human performance assessment. Provides cross-validation, inter-annotator agreement metrics, model -performance metrics, and convergence detection for active learning. +performance metrics, convergence detection for active learning, and +per-annotator reliability summaries. """ from bead.evaluation.convergence import ConvergenceDetector from bead.evaluation.interannotator import InterAnnotatorMetrics +from bead.evaluation.reliability import ( + AnnotationRecord, + AnnotatorReliability, + annotator_reliability, + low_entropy_annotators, +) __all__ = [ - "InterAnnotatorMetrics", + "AnnotationRecord", + "AnnotatorReliability", "ConvergenceDetector", + "InterAnnotatorMetrics", + "annotator_reliability", + "low_entropy_annotators", ] diff --git a/bead/evaluation/reliability.py b/bead/evaluation/reliability.py new file mode 100644 index 0000000..c692986 --- /dev/null +++ b/bead/evaluation/reliability.py @@ -0,0 +1,279 @@ +"""Per-annotator reliability summaries. + +Sits next to :class:`bead.evaluation.InterAnnotatorMetrics`. Where the +inter-annotator metrics quantify *agreement* across raters, this +module quantifies *response diversity* of each individual rater. Low +within-annotator entropy is a flag that the annotator is collapsing +the response space (always picking ``"yes"``, always picking the +midpoint, and so on), which biases agreement metrics in misleading +directions. + +The canonical input is a sequence of :class:`AnnotationRecord` +instances, each carrying an ``annotator_id``, ``item_id``, +``response_label``, and ``question_name``. The Shannon entropy of +each annotator's per-question response distribution is computed in +bits. +""" + +from __future__ import annotations + +import math +from collections.abc import Mapping, Sequence + +import didactic.api as dx + +from bead.data.base import BeadBaseModel +from bead.protocol.encoding import ResponseEncoding + + +class AnnotationRecord(BeadBaseModel): + """A single annotator response. + + Canonical record shape consumed by reliability and inter-annotator + metrics. Conforms structurally to + :class:`bead.protocol.diagnostics.RecordLike`. + + Attributes + ---------- + annotator_id : str + Identifier of the annotator who produced the response. + item_id : str + Identifier of the annotation item. + question_name : str + Anchor name of the question that was answered. + response_label : str + The annotator's response label (must be one of the labels of + the corresponding :class:`ResponseEncoding`). + """ + + annotator_id: str + item_id: str + question_name: str + response_label: str + + +class AnnotatorReliability(BeadBaseModel): + """Per-annotator reliability summary. + + Captures how diverse a single annotator's responses are within + each question. Low entropy means the annotator collapses the + response space. + + Attributes + ---------- + annotator_id : str + The annotator's identifier. + n_responses : int + Total responses from this annotator across all questions. + response_distribution : dict[str, dict[str, int]] + Per-question distribution of responses, keyed by anchor name + and then by response label, with counts as values. + entropy_per_question : dict[str, float] + Per-question Shannon entropy in bits. ``0.0`` when the + annotator only used one label for that question. + + Examples + -------- + >>> rel = AnnotatorReliability( + ... annotator_id="ann_1", + ... n_responses=4, + ... response_distribution={ + ... "completion": {"yes": 2, "no": 2}, + ... }, + ... entropy_per_question={"completion": 1.0}, + ... ) + >>> rel.entropy("completion") + 1.0 + >>> rel.entropy("missing") is None + True + """ + + annotator_id: str + n_responses: int = 0 + response_distribution: dict[str, dict[str, int]] = dx.field(default_factory=dict) + entropy_per_question: dict[str, float] = dx.field(default_factory=dict) + + def entropy(self, question_name: str) -> float | None: + """Return the Shannon entropy for one question, or ``None``. + + Parameters + ---------- + question_name : str + Anchor name to look up. + + Returns + ------- + float | None + Entropy in bits, or ``None`` if no responses were recorded + for this question. + """ + return self.entropy_per_question.get(question_name) + + +def _shannon_entropy(counts: Mapping[str, int]) -> float: + """Return the Shannon entropy in bits of a count distribution. + + Parameters + ---------- + counts : Mapping[str, int] + Per-label counts. Zero-count labels are treated as absent. + + Returns + ------- + float + Entropy in bits. ``0.0`` for an empty or singleton + distribution. + """ + total = sum(counts.values()) + if total == 0: + return 0.0 + entropy = 0.0 + for count in counts.values(): + if count <= 0: + continue + p = count / total + entropy -= p * math.log2(p) + return entropy + + +def annotator_reliability( + records: Sequence[AnnotationRecord], + encodings: Mapping[str, ResponseEncoding] | None = None, +) -> tuple[AnnotatorReliability, ...]: + """Compute per-annotator reliability summaries. + + Groups records by annotator, then by question, and computes + Shannon entropy in bits on each annotator-question label + distribution. When ``encodings`` is supplied, response labels not + present in the encoding for a question are silently skipped (a + common case after schema evolution). + + Parameters + ---------- + records : Sequence[AnnotationRecord] + All records across questions and annotators. + encodings : Mapping[str, ResponseEncoding] | None, optional + Per-question encodings used to filter unrecognized labels. + When ``None`` (the default), every label is counted. + + Returns + ------- + tuple[AnnotatorReliability, ...] + One summary per annotator, sorted by annotator id. + + Examples + -------- + >>> records = [ + ... AnnotationRecord(annotator_id="a1", item_id="i1", + ... question_name="q", response_label="yes"), + ... AnnotationRecord(annotator_id="a1", item_id="i2", + ... question_name="q", response_label="no"), + ... AnnotationRecord(annotator_id="a2", item_id="i1", + ... question_name="q", response_label="yes"), + ... AnnotationRecord(annotator_id="a2", item_id="i2", + ... question_name="q", response_label="yes"), + ... ] + >>> profiles = annotator_reliability(records) + >>> [(p.annotator_id, p.entropy("q")) for p in profiles] + [('a1', 1.0), ('a2', 0.0)] + """ + by_annotator: dict[str, list[AnnotationRecord]] = {} + for rec in records: + by_annotator.setdefault(rec.annotator_id, []).append(rec) + + summaries: list[AnnotatorReliability] = [] + for ann_id in sorted(by_annotator): + ann_records = by_annotator[ann_id] + distribution: dict[str, dict[str, int]] = {} + entropy_per_question: dict[str, float] = {} + + by_question: dict[str, list[str]] = {} + for rec in ann_records: + if encodings is not None: + encoding = encodings.get(rec.question_name) + if encoding is not None and rec.response_label not in encoding.labels: + continue + by_question.setdefault(rec.question_name, []).append(rec.response_label) + + for q_name, labels in by_question.items(): + counts: dict[str, int] = {} + for label in labels: + counts[label] = counts.get(label, 0) + 1 + distribution[q_name] = counts + entropy_per_question[q_name] = _shannon_entropy(counts) + + summaries.append( + AnnotatorReliability( + annotator_id=ann_id, + n_responses=sum(len(v) for v in by_question.values()), + response_distribution=distribution, + entropy_per_question=entropy_per_question, + ) + ) + + return tuple(summaries) + + +def low_entropy_annotators( + profiles: Sequence[AnnotatorReliability], + *, + threshold: float, + question_name: str | None = None, + require_min_responses: int = 1, +) -> tuple[str, ...]: + """Return annotator ids whose entropy falls at or below a threshold. + + Useful for flagging annotators who collapse the response space. + When ``question_name`` is supplied, the threshold is checked + against that one question's entropy; otherwise it is checked + against the *minimum* per-question entropy across every question + the annotator answered. + + Parameters + ---------- + profiles : Sequence[AnnotatorReliability] + Reliability summaries to scan. + threshold : float + Entropy ceiling in bits. Annotators with entropy at or below + this value are returned. + question_name : str | None, optional + Restrict the check to one question. Defaults to ``None`` (all + questions, returning the minimum). + require_min_responses : int, optional + Skip annotators whose response count is below this value. + Defaults to ``1``. + + Returns + ------- + tuple[str, ...] + Annotator ids meeting the criterion, sorted. + + Examples + -------- + >>> profiles = ( + ... AnnotatorReliability(annotator_id="a1", n_responses=10, + ... entropy_per_question={"q": 0.0}), + ... AnnotatorReliability(annotator_id="a2", n_responses=10, + ... entropy_per_question={"q": 0.95}), + ... ) + >>> low_entropy_annotators(profiles, threshold=0.5) + ('a1',) + """ + flagged: list[str] = [] + for profile in profiles: + if profile.n_responses < require_min_responses: + continue + if question_name is not None: + entropy = profile.entropy(question_name) + if entropy is None: + continue + if entropy <= threshold: + flagged.append(profile.annotator_id) + else: + entropies = tuple(profile.entropy_per_question.values()) + if not entropies: + continue + if min(entropies) <= threshold: + flagged.append(profile.annotator_id) + + return tuple(sorted(flagged)) diff --git a/bead/items/cache.py b/bead/items/cache.py index b8b3215..a3e40cb 100644 --- a/bead/items/cache.py +++ b/bead/items/cache.py @@ -480,7 +480,7 @@ def set( self, model_name: str, operation: str, - result: float | dict[str, float] | list[float] | np.ndarray, + result: str | float | dict[str, float] | list[float] | np.ndarray, model_version: str | None = None, **inputs: str | int | float | bool | None, ) -> None: @@ -491,9 +491,12 @@ def set( model_name Model identifier. operation - Operation type (e.g., "log_probability", "nli", "embedding"). + Operation type (e.g., "log_probability", "nli", "embedding", + "lm_completion"). result - Result to cache (log probability, NLI scores, embedding, etc.). + Result to cache. Strings (LM completions), floats (log + probabilities), float dicts (NLI scores), float lists, + and numpy arrays (embeddings) are supported. model_version Optional model version string for tracking. **inputs diff --git a/bead/items/span_labeling.py b/bead/items/span_labeling.py index 92aa2f6..a923cb7 100644 --- a/bead/items/span_labeling.py +++ b/bead/items/span_labeling.py @@ -14,7 +14,6 @@ from __future__ import annotations import json -import re import warnings from collections.abc import Callable from uuid import UUID, uuid4 @@ -25,11 +24,10 @@ Span, SpanSpec, ) +from bead.labels import parse_label_refs from bead.tokenization.config import TokenizerConfig from bead.tokenization.tokenizers import TokenizedText, create_tokenizer -_SPAN_REF_PATTERN = re.compile(r"\[\[([^\]:|]+?)(?::([^\]|]+?))?(?:\|([^\]]+?))?\]\]") - def tokenize_item( item: Item, @@ -324,12 +322,11 @@ def add_spans_to_item( if prompt_text: all_spans = list(item.spans) + spans span_labels = {s.label.label for s in all_spans if s.label is not None} - for match in _SPAN_REF_PATTERN.finditer(prompt_text): - ref_label = match.group(1) - if ref_label not in span_labels: + for ref in parse_label_refs(prompt_text): + if ref.label not in span_labels: warnings.warn( - f"Prompt contains [[{ref_label}]] but no span with " - f"label '{ref_label}' exists. Available labels: " + f"Prompt contains [[{ref.label}]] but no span with " + f"label '{ref.label}' exists. Available labels: " f"{sorted(span_labels)}", UserWarning, stacklevel=2, diff --git a/bead/labels.py b/bead/labels.py new file mode 100644 index 0000000..8f125db --- /dev/null +++ b/bead/labels.py @@ -0,0 +1,167 @@ +"""Prompt-template label-reference syntax. + +bead prompt templates use a single canonical syntax for embedding +references to named spans within a prompt: + +- ``[[label]]`` — the prompt is rendered with the span's reconstructed + text in this position +- ``[[label:text]]`` — the prompt is rendered with the explicit ``text`` + in this position (overrides the reconstructed text) +- ``[[label|transform1|transform2]]`` — the reconstructed text is + passed through the named transforms before rendering +- ``[[label:text|transform1]]`` — combines the explicit-text and + transform forms + +This module is the single canonical home for that syntax. Drift +validators, item-construction utilities, and the jsPsych deployment +layer all parse references through :func:`parse_label_refs` and never +through their own regular expressions. +""" + +from __future__ import annotations + +import re +from collections.abc import Callable + +from bead.data.base import BeadBaseModel + +LABEL_PATTERN: re.Pattern[str] = re.compile( + r"\[\[([^\]:|]+?)(?::([^\]|]+?))?(?:\|([^\]]+?))?\]\]" +) +"""Compiled regex for ``[[label]]`` / ``[[label:text]]`` / ``[[label|t]]``. + +Capture groups: ``(1)`` label name, ``(2)`` optional display text, +``(3)`` optional pipe-separated transform list. The pattern is +non-greedy and rejects ``]``, ``:``, and ``|`` characters inside the +label name. +""" + + +class LabelRef(BeadBaseModel): + """A parsed label reference. + + Attributes + ---------- + label : str + The label name. + display_text : str | None + Explicit display text supplied via ``[[label:text]]``, or + ``None`` when the reference is bare. + transforms : tuple[str, ...] + Transform names supplied via ``[[label|t1|t2]]``, in order. + Empty when no transforms were supplied. + start_offset : int + Inclusive character offset of the matched reference in the + original prompt. + end_offset : int + Exclusive character offset of the matched reference in the + original prompt. + """ + + label: str + display_text: str | None = None + transforms: tuple[str, ...] = () + start_offset: int = 0 + end_offset: int = 0 + + +def parse_label_refs(prompt: str) -> tuple[LabelRef, ...]: + """Parse every label reference in ``prompt``, in order. + + Parameters + ---------- + prompt : str + Prompt string potentially containing label references. + + Returns + ------- + tuple[LabelRef, ...] + Parsed references in order of appearance. Empty tuple when no + references match. + + Examples + -------- + >>> refs = parse_label_refs("Did [[situation|gerund]] happen?") + >>> refs[0].label + 'situation' + >>> refs[0].transforms + ('gerund',) + """ + refs: list[LabelRef] = [] + for match in LABEL_PATTERN.finditer(prompt): + raw_transforms = match.group(3) + if raw_transforms is None: + transforms: tuple[str, ...] = () + else: + transforms = tuple( + t.strip() for t in raw_transforms.split("|") if t.strip() + ) + display_text = match.group(2).strip() if match.group(2) else None + refs.append( + LabelRef( + label=match.group(1).strip(), + display_text=display_text, + transforms=transforms, + start_offset=match.start(), + end_offset=match.end(), + ) + ) + return tuple(refs) + + +def find_label_names(prompt: str) -> frozenset[str]: + """Return the set of label names referenced in ``prompt``. + + Convenience wrapper around :func:`parse_label_refs` that discards + everything except the label names. Used by structural drift + validation, where only the *which-labels-are-present* question + matters and display text and transforms are irrelevant. + + Parameters + ---------- + prompt : str + Prompt string potentially containing label references. + + Returns + ------- + frozenset[str] + Distinct label names referenced in the prompt. + + Examples + -------- + >>> find_label_names("Compare [[a]] and [[b:other]] and [[a|gerund]].") + frozenset({'a', 'b'}) + """ + return frozenset(ref.label for ref in parse_label_refs(prompt)) + + +def replace_label_refs( + prompt: str, + render: Callable[[LabelRef], str], +) -> str: + """Rewrite ``prompt`` by replacing each reference with rendered text. + + The ``render`` callable is invoked once per reference and must + return the string that should replace it. Replacements are applied + right-to-left so earlier matches' offsets remain valid. + + Parameters + ---------- + prompt : str + Prompt string potentially containing label references. + render : Callable[[LabelRef], str] + Function returning the replacement text for one reference. + + Returns + ------- + str + Prompt with every reference replaced. + """ + refs = parse_label_refs(prompt) + if not refs: + return prompt + result = prompt + for ref in reversed(refs): + replacement = render(ref) + result = result[: ref.start_offset] + replacement + result[ref.end_offset :] + return result diff --git a/bead/protocol/__init__.py b/bead/protocol/__init__.py new file mode 100644 index 0000000..2dfaf5d --- /dev/null +++ b/bead/protocol/__init__.py @@ -0,0 +1,125 @@ +"""Annotation protocol primitives. + +This package provides the type-theoretic stack for defining annotation +protocols independent of any specific linguistic domain. The design is +organized around four roles: + +- :class:`SemanticAnchor` is the *type* of a question: a declarative + specification of what is being measured, independent of how the + question is phrased. +- :class:`ProtocolContext` is the dependent *index*: everything known + about the current annotation target. Different contexts license + different questions and different phrasings. +- :class:`RealizationStrategy` is the computational *content* of the + dependent function ``Pi(ctx). Question(ctx)``: a strategy that maps + an anchor and a context to a concrete prompt string. +- :class:`DriftGuard` is the *type-checker*: it verifies that a realized + prompt still inhabits the type defined by its anchor. + +On top of these, :class:`QuestionFamily` packages an anchor with a +realization strategy and a drift guard, and :class:`AnnotationProtocol` +sequences families into the iterated dependent product +``Sigma(a_1 : Q_1(ctx)). Sigma(a_2 : Q_2(ctx, a_1)). ...``, threading +responses through the context so later questions can condition on +earlier answers. + +The :mod:`~bead.protocol.encoding` and :mod:`~bead.protocol.diagnostics` +submodules add a likelihood-agnostic response-encoding layer and an +immutable diagnostic-record system used by both the protocol layer and +downstream modeling code. +""" + +from __future__ import annotations + +from bead.protocol.anchor import ResponseSpace, SemanticAnchor +from bead.protocol.context import ( + ContextItem, + ContextPredicate, + ProtocolContext, + always, + get_context_predicate, + list_context_predicates, + register_context_predicate, +) +from bead.protocol.diagnostics import ( + ConditionalObservationValidator, + DatasetReport, + DiagnosticLevel, + DiagnosticRecord, + RecordLike, +) +from bead.protocol.drift import ( + DriftGuard, + DriftScore, + DriftValidator, + EmbeddingAdapter, + EmbeddingDriftValidator, + PerplexityAdapter, + PerplexityDriftValidator, + StructuralDriftValidator, +) +from bead.protocol.encoding import ResponseEncoding, ScaleType, encode_response_space +from bead.protocol.family import ( + AnnotationProtocol, + ApplicabilityPredicate, + QuestionFamily, + QuestionRealization, +) +from bead.protocol.items import ( + family_to_item_template, + protocol_to_item_templates, + realization_to_item, + realize_protocol_to_items, + scale_type_to_task_type, +) +from bead.protocol.realization import ( + ContextualTemplateRealization, + LMClient, + LMRealization, + RealizationStrategy, + TemplateRealization, + TemplateVariant, +) + +__all__ = [ + "AnnotationProtocol", + "ApplicabilityPredicate", + "ConditionalObservationValidator", + "ContextItem", + "ContextPredicate", + "ContextualTemplateRealization", + "DatasetReport", + "DiagnosticLevel", + "DiagnosticRecord", + "DriftGuard", + "DriftScore", + "DriftValidator", + "EmbeddingAdapter", + "EmbeddingDriftValidator", + "LMClient", + "LMRealization", + "PerplexityAdapter", + "PerplexityDriftValidator", + "ProtocolContext", + "QuestionFamily", + "QuestionRealization", + "RealizationStrategy", + "RecordLike", + "ResponseEncoding", + "ResponseSpace", + "ScaleType", + "SemanticAnchor", + "StructuralDriftValidator", + "TemplateRealization", + "TemplateVariant", + "always", + "encode_response_space", + "family_to_item_template", + "protocol_to_item_templates", + "realization_to_item", + "realize_protocol_to_items", + "scale_type_to_task_type", + "get_context_predicate", + "list_context_predicates", + "register_context_predicate", +] diff --git a/bead/protocol/anchor.py b/bead/protocol/anchor.py new file mode 100644 index 0000000..7b0ba95 --- /dev/null +++ b/bead/protocol/anchor.py @@ -0,0 +1,271 @@ +"""Semantic anchors: the type-level specification of a question. + +A :class:`SemanticAnchor` defines *what* a question measures, +independently of how it is phrased. It is the invariant that any +realization must preserve, the *type* in the dependent type +``Question(ctx)``. + +The anchor includes: + +- a *canonical prompt*: the reference phrasing +- a *response space*: the set of valid responses and their ordering +- *structural constraints*: keywords, span references, and embedding + bounds that any realization must satisfy +""" + +from __future__ import annotations + +from typing import Self + +import didactic.api as dx + +from bead.data.base import BeadBaseModel + + +class SemanticPoles(BeadBaseModel): + """Pole labels for an ordered response scale. + + Ordered scales are characterized by their two participant-facing + endpoint labels (for example ``low="definitely no"`` and + ``high="definitely yes"``). Unordered scales have no poles and use + ``None`` in place of an instance of this model. + + Attributes + ---------- + low : str + Label of the low end of the scale. + high : str + Label of the high end of the scale. + + Examples + -------- + >>> poles = SemanticPoles(low="definitely no", high="definitely yes") + >>> poles.as_tuple() + ('definitely no', 'definitely yes') + """ + + low: str + high: str + + def as_tuple(self) -> tuple[str, str]: + """Return ``(low, high)`` as a 2-tuple. + + Returns + ------- + tuple[str, str] + The pole labels as a Python tuple. + """ + return (self.low, self.high) + + +class ResponseSpace(BeadBaseModel): + """The space of valid responses for a question. + + Attributes + ---------- + options : tuple[str, ...] + Ordered response options. + is_ordered : bool + Whether the options form an ordinal scale. Defaults to + ``True``. + semantic_poles : SemanticPoles | None + Pole labels for ordered scales (for example ``low="never"``, + ``high="always"``). ``None`` for unordered (categorical) + response spaces. Defaults to ``None``. + + Examples + -------- + >>> rs = ResponseSpace( + ... options=("definitely no", "probably no", "unsure", + ... "probably yes", "definitely yes"), + ... is_ordered=True, + ... semantic_poles=SemanticPoles( + ... low="definitely no", high="definitely yes", + ... ), + ... ) + >>> len(rs) + 5 + >>> "probably yes" in rs + True + """ + + options: tuple[str, ...] + is_ordered: bool = True + semantic_poles: dx.Embed[SemanticPoles] | None = None + + def __len__(self) -> int: + """Return the number of response options.""" + return len(self.options) + + def __contains__(self, item: str) -> bool: + """Return whether ``item`` is one of the response options. + + Parameters + ---------- + item : str + Candidate response label. + + Returns + ------- + bool + ``True`` when ``item`` is a registered option. + """ + return item in self.options + + +class SemanticAnchor(BeadBaseModel): + """Type-level specification of what a question measures. + + Any realization of a question must preserve the anchor's semantic + content. The anchor is the *type*; a realized prompt string is the + *value*. + + Attributes + ---------- + name : str + Short identifier (for example ``"completion"``). + target_property : str + The property being measured (for example ``"telicity"``). + canonical_prompt : str + Reference phrasing of the question. Serves as both + documentation and the default template. + response_space : ResponseSpace + Valid responses. + required_span_labels : frozenset[str] + Span labels that must appear in any realization (for example + ``frozenset({"situation"})``). Defaults to the empty set. + required_keywords : frozenset[str] + Keywords that must appear in any realization to preserve + semantic content. Used by :class:`StructuralDriftValidator`. + Defaults to the empty set. + embedding_center : tuple[float, ...] | None + Pre-computed embedding of the canonical prompt for drift + validation via cosine distance. ``None`` means the embedding + is computed on demand by the validator. Defaults to ``None``. + max_drift : float + Maximum allowed cosine distance. + description : str + Human-readable description. + + Examples + -------- + >>> rs = ResponseSpace( + ... options=("no", "yes"), is_ordered=False + ... ) + >>> anchor = SemanticAnchor( + ... name="dynamicity", + ... target_property="dynamic", + ... canonical_prompt="Is anything changing during [[situation]]?", + ... response_space=rs, + ... required_span_labels=frozenset({"situation"}), + ... required_keywords=frozenset({"changing"}), + ... ) + >>> anchor.name + 'dynamicity' + + See Also + -------- + bead.protocol.drift.StructuralDriftValidator : Enforces + ``required_span_labels`` and ``required_keywords``. + bead.protocol.drift.EmbeddingDriftValidator : Enforces + ``embedding_center`` and ``max_drift``. + """ + + name: str + target_property: str + canonical_prompt: str + response_space: dx.Embed[ResponseSpace] + required_span_labels: frozenset[str] = dx.field(default_factory=frozenset) + required_keywords: frozenset[str] = dx.field(default_factory=frozenset) + embedding_center: tuple[float, ...] | None = None + max_drift: float = 0.3 + description: str = "" + + @classmethod + def from_response_options( + cls, + *, + name: str, + target_property: str, + canonical_prompt: str, + options: tuple[str, ...], + is_ordered: bool = True, + semantic_poles: SemanticPoles | None = None, + required_span_labels: frozenset[str] = frozenset(), + required_keywords: frozenset[str] = frozenset(), + embedding_center: tuple[float, ...] | None = None, + max_drift: float = 0.3, + description: str = "", + ) -> Self: + """Build an anchor from a flat list of response options. + + Convenience constructor for the common case in which a + :class:`ResponseSpace` is built inline from its options. + + Parameters + ---------- + name : str + Short identifier. + target_property : str + The property being measured. + canonical_prompt : str + Reference phrasing. + options : tuple[str, ...] + Ordered response options. + is_ordered : bool, optional + Whether the options form an ordinal scale. Defaults to + ``True``. + semantic_poles : SemanticPoles | None, optional + Pole labels for ordered scales. Defaults to ``None``. + required_span_labels : frozenset[str], optional + Span labels required in every realization. Defaults to the + empty set. + required_keywords : frozenset[str], optional + Keywords required in every realization. Defaults to the + empty set. + embedding_center : tuple[float, ...] | None, optional + Pre-computed canonical-prompt embedding. Defaults to + ``None``. + max_drift : float, optional + Maximum allowed cosine distance. Defaults to ``0.3``. + description : str, optional + Human-readable description. Defaults to the empty string. + + Returns + ------- + SemanticAnchor + A new anchor with an inline-constructed response space. + + Examples + -------- + >>> anchor = SemanticAnchor.from_response_options( + ... name="completion", + ... target_property="telicity", + ... canonical_prompt="Does [[situation]] reach an endpoint?", + ... options=("definitely no", "probably no", "unsure", + ... "probably yes", "definitely yes"), + ... is_ordered=True, + ... semantic_poles=SemanticPoles( + ... low="definitely no", high="definitely yes" + ... ), + ... required_span_labels=frozenset({"situation"}), + ... ) + >>> anchor.response_space.is_ordered + True + """ + space = ResponseSpace( + options=options, + is_ordered=is_ordered, + semantic_poles=semantic_poles, + ) + return cls( + name=name, + target_property=target_property, + canonical_prompt=canonical_prompt, + response_space=space, + required_span_labels=required_span_labels, + required_keywords=required_keywords, + embedding_center=embedding_center, + max_drift=max_drift, + description=description, + ) diff --git a/bead/protocol/context.py b/bead/protocol/context.py new file mode 100644 index 0000000..3cd1aa4 --- /dev/null +++ b/bead/protocol/context.py @@ -0,0 +1,336 @@ +"""Annotation contexts: dependent indices for question realization. + +A :class:`ProtocolContext` gathers everything known about the current +annotation target into a single immutable object. It is the *index* +in the dependent type ``Question(ctx)``: different contexts license +different questions and different phrasings. + +The context layer is deliberately domain-neutral. It carries +sentence-level, target-level, and dependent-level information common +to most token- or span-targeted annotation protocols, plus a +JSON-shaped ``metadata`` map (inherited from +:class:`~bead.data.base.BeadBaseModel`) for domain-specific data +that does not fit the standard fields. Domain-specific *predicates +over* the context live in the **predicate registry** documented at +the bottom of this module: callers register named predicates at +import time and refer to them by name from +:class:`~bead.protocol.realization.ContextualTemplateRealization`. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Self + +import didactic.api as dx + +from bead.data.base import BeadBaseModel + + +class ContextItem(BeadBaseModel): + """Generic per-token-or-span dependent context. + + Captures the structural properties of a single dependent (an + argument, an adjunct, a related span, ...) of the annotation + target. Domain-specific scalar attributes live in + :attr:`attributes`. + + Attributes + ---------- + node_id : str + Identifier from the upstream parse or annotation source. + Defaults to the empty string. + head_lemma : str + Lemma of the dependent head. Defaults to the empty string. + head_form : str + Surface form of the dependent head. Defaults to the empty + string. + head_upos : str + Universal POS tag of the dependent head. Defaults to the empty + string. + head_position : int + 1-based token position of the dependent head. Defaults to + ``0``. + span_text : str + Full surface span text of the dependent. Defaults to the empty + string. + span_positions : tuple[int, ...] + 1-based token positions in the dependent span. Defaults to + the empty tuple. + is_plural : bool + Whether the dependent head is morphologically plural. Defaults + to ``False``. + attributes : dict[str, float] + Domain-specific scalar attributes, keyed by attribute name + (for example ``{"definiteness": 0.7}``). Defaults to the + empty dict. + """ + + node_id: str = "" + head_lemma: str = "" + head_form: str = "" + head_upos: str = "" + head_position: int = 0 + span_text: str = "" + span_positions: tuple[int, ...] = () + is_plural: bool = False + attributes: dict[str, float] = dx.field(default_factory=dict) + + def attribute(self, name: str) -> float | None: + """Return the value of a named attribute, or ``None`` if absent. + + Parameters + ---------- + name : str + Attribute name to look up. + + Returns + ------- + float | None + The attribute value, or ``None`` when the attribute is not + present on this context item. + + Examples + -------- + >>> item = ContextItem( + ... attributes={"change_of_state": 4.2, + ... "instigation": 3.1}, + ... ) + >>> item.attribute("change_of_state") + 4.2 + >>> item.attribute("absent") is None + True + """ + return self.attributes.get(name) + + +class ProtocolContext(BeadBaseModel): + """Everything known about the current annotation target. + + This is the value that parameterizes the dependent question type. + Question families inspect the context to decide *which* question + variant to realize and *how* to phrase it. + + The :meth:`with_response` method threads an annotator response + into the context, supporting dependent products in which later + questions condition on earlier answers. + + Attributes + ---------- + sentence : str + Full sentence text. Defaults to the empty string. + tokens : tuple[str, ...] + Sentence tokens, in order. Defaults to the empty tuple. + tokens_lemma : tuple[str, ...] + Token lemmas, in order. Defaults to the empty tuple. + tokens_upos : tuple[str, ...] + Universal POS tags, in order. Defaults to the empty tuple. + target_lemma : str + Lemma of the annotation target's head. Defaults to the empty + string. + target_form : str + Surface form of the target head. Defaults to the empty string. + target_upos : str + UPOS tag of the target head. Defaults to the empty string. + target_position : int + 1-based token position of the target head. Defaults to ``0``. + target_span_text : str + Full surface span text of the target. Defaults to the empty + string. + target_span_positions : tuple[int, ...] + 1-based token positions of the target span. Defaults to the + empty tuple. + dependents : tuple[ContextItem, ...] + Structural dependents of the target. Defaults to the empty + tuple. + previous_responses : dict[str, str] + Annotator responses to earlier questions, keyed by anchor + name. Defaults to the empty dict. + target_id : str + Identifier for the annotation target, for traceability. + Defaults to the empty string. + source_id : str + Identifier for the source document or graph. Defaults to the + empty string. + + See Also + -------- + register_context_predicate : Register a named predicate over + :class:`ProtocolContext` instances. + """ + + sentence: str = "" + tokens: tuple[str, ...] = () + tokens_lemma: tuple[str, ...] = () + tokens_upos: tuple[str, ...] = () + + target_lemma: str = "" + target_form: str = "" + target_upos: str = "" + target_position: int = 0 + target_span_text: str = "" + target_span_positions: tuple[int, ...] = () + + dependents: tuple[dx.Embed[ContextItem], ...] = () + + previous_responses: dict[str, str] = dx.field(default_factory=dict) + + target_id: str = "" + source_id: str = "" + + def with_response(self, question_name: str, response: str) -> Self: + """Return a new context with one additional response recorded. + + Supports the dependent-product structure: the type of a later + question can depend on the value (response) of an earlier + question. + + Parameters + ---------- + question_name : str + Name of the anchor whose response is being recorded. + response : str + The annotator's response label. + + Returns + ------- + ProtocolContext + A new context whose :attr:`previous_responses` includes + ``{question_name: response}``. + + Examples + -------- + >>> ctx = ProtocolContext(sentence="Mary built a sandcastle.") + >>> ctx2 = ctx.with_response("dynamicity", "yes") + >>> ctx2.previous_responses + {'dynamicity': 'yes'} + >>> ctx.previous_responses + {} + """ + updated = {**self.previous_responses, question_name: response} + return self.with_(previous_responses=updated) + + def get_response(self, question_name: str) -> str | None: + """Return the recorded response for a question, or ``None``. + + Parameters + ---------- + question_name : str + The anchor name to look up. + + Returns + ------- + str | None + The recorded response label, or ``None`` if no response + has been threaded for this question. + """ + return self.previous_responses.get(question_name) + + +# --------------------------------------------------------------------------- +# Context-predicate registry +# --------------------------------------------------------------------------- +# +# A :data:`ContextPredicate` is a function ``ProtocolContext -> bool`` +# used by :class:`~bead.protocol.realization.ContextualTemplateRealization` +# to select among template variants based on context properties. +# +# The registry is intentionally module-level mutable state. It is +# populated at import time by user code and read at realization time. +# Do not mutate it from request-path code: the registry is not +# thread-safe and is not intended to carry per-request state. + +ContextPredicate = Callable[[ProtocolContext], bool] +"""Type alias for predicates over :class:`ProtocolContext`.""" + + +_PREDICATES: dict[str, ContextPredicate] = {} + + +def register_context_predicate(name: str, predicate: ContextPredicate) -> None: + """Register a named predicate over :class:`ProtocolContext`. + + Callers register their domain-specific predicates at import time. + The registered predicates are then available by name to + :class:`~bead.protocol.realization.ContextualTemplateRealization` + and other realization strategies that select among variants. + + Parameters + ---------- + name : str + Unique predicate name. Re-registering an existing name + overwrites the previous predicate. + predicate : ContextPredicate + Callable that returns ``True`` when the context matches. + + Examples + -------- + >>> def has_plural_dependent(ctx: ProtocolContext) -> bool: + ... return any(d.is_plural for d in ctx.dependents) + >>> register_context_predicate( + ... "has_plural_dependent", has_plural_dependent + ... ) + >>> get_context_predicate("has_plural_dependent") is has_plural_dependent + True + """ + _PREDICATES[name] = predicate + + +def get_context_predicate(name: str) -> ContextPredicate: + """Look up a registered predicate by name. + + Parameters + ---------- + name : str + The predicate name to look up. + + Returns + ------- + ContextPredicate + The registered predicate. + + Raises + ------ + KeyError + If no predicate with that name is registered. + """ + try: + return _PREDICATES[name] + except KeyError: + raise KeyError( + f"No context predicate registered under name {name!r}. " + f"Registered: {sorted(_PREDICATES)}" + ) from None + + +def list_context_predicates() -> tuple[str, ...]: + """Return the names of all registered context predicates, sorted. + + Returns + ------- + tuple[str, ...] + All registered predicate names in sorted order. + """ + return tuple(sorted(_PREDICATES)) + + +def always(_ctx: ProtocolContext) -> bool: + """Predicate that matches every context. + + Used as the catch-all condition for fallback template variants and + the default applicability predicate for question families. + + Parameters + ---------- + _ctx : ProtocolContext + Ignored. + + Returns + ------- + bool + Always ``True``. + """ + return True + + +register_context_predicate("always", always) diff --git a/bead/protocol/diagnostics.py b/bead/protocol/diagnostics.py new file mode 100644 index 0000000..e71ef77 --- /dev/null +++ b/bead/protocol/diagnostics.py @@ -0,0 +1,406 @@ +"""Dataset diagnostics and quality reporting for annotation protocols. + +Provides :class:`DatasetReport`, a structured immutable summary of +quality issues discovered during dataset preparation, and +:class:`ConditionalObservationValidator`, which checks that responses +to conditional questions respect the protocol's +:attr:`~bead.protocol.family.QuestionFamily.depends_on` graph. + +Diagnostic findings are immutable :class:`DiagnosticRecord` instances +collected in order of discovery. The :meth:`DatasetReport.summary` +method produces a human-readable overview suitable for logging. +""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from enum import StrEnum +from typing import Protocol, Self, runtime_checkable + +import didactic.api as dx + +from bead.data.base import BeadBaseModel +from bead.protocol.family import AnnotationProtocol + + +class DiagnosticLevel(StrEnum): + """Severity of a diagnostic finding. + + Attributes + ---------- + INFO : str + Informational message. Wire value: ``"info"``. + WARNING : str + Warning that does not prevent dataset use. Wire value: + ``"warning"``. + ERROR : str + Error that may invalidate downstream analysis. Wire value: + ``"error"``. + """ + + INFO = "info" + WARNING = "warning" + ERROR = "error" + + +class DiagnosticRecord(BeadBaseModel): + """A single diagnostic finding. + + Attributes + ---------- + level : DiagnosticLevel + Severity of the finding. + category : str + Short category tag (for example ``"missing_embedding"`` or + ``"unrecognized_label"``). + message : str + Human-readable description. + item_id : str | None + The item this finding pertains to, if applicable. Defaults to + ``None``. + question_name : str | None + The anchor name this finding pertains to, if applicable. + Defaults to ``None``. + """ + + level: DiagnosticLevel + category: str + message: str + item_id: str | None = None + question_name: str | None = None + + +class DatasetReport(BeadBaseModel): + """Immutable structured report of dataset-preparation quality. + + Mutating methods (:meth:`add`, :meth:`with_coverage`, + :meth:`with_missing_embedding`) follow the bead convention of + returning a new instance via ``.with_(...)``; the original is + unchanged. + + Attributes + ---------- + n_records_input : int + Total number of input records received. Defaults to ``0``. + n_items : int + Number of unique item ids. Defaults to ``0``. + n_records_encoded : int + Number of records successfully encoded. Defaults to ``0``. + n_records_dropped : int + Number of records dropped. Defaults to ``0``. + coverage : dict[str, float] + Per-question response-coverage rate (fraction of items with a + valid response). Defaults to the empty dict. + findings : tuple[DiagnosticRecord, ...] + All diagnostic findings, in order of discovery. Defaults to + the empty tuple. + items_missing_embeddings : tuple[str, ...] + Item ids that had no embedding provided. Defaults to the empty + tuple. + """ + + n_records_input: int = 0 + n_items: int = 0 + n_records_encoded: int = 0 + n_records_dropped: int = 0 + coverage: dict[str, float] = dx.field(default_factory=dict) + findings: tuple[dx.Embed[DiagnosticRecord], ...] = () + items_missing_embeddings: tuple[str, ...] = () + + def add( + self, + level: DiagnosticLevel, + category: str, + message: str, + *, + item_id: str | None = None, + question_name: str | None = None, + ) -> Self: + """Return a new report with one additional finding appended. + + Parameters + ---------- + level : DiagnosticLevel + Severity. + category : str + Category tag. + message : str + Description. + item_id : str | None, optional + Related item id. Defaults to ``None``. + question_name : str | None, optional + Related anchor name. Defaults to ``None``. + + Returns + ------- + DatasetReport + New report with the finding added. + """ + record = DiagnosticRecord( + level=level, + category=category, + message=message, + item_id=item_id, + question_name=question_name, + ) + return self.with_(findings=(*self.findings, record)) + + def extend(self, records: Sequence[DiagnosticRecord]) -> Self: + """Return a new report with multiple findings appended. + + Parameters + ---------- + records : Sequence[DiagnosticRecord] + Findings to append. + + Returns + ------- + DatasetReport + New report with the findings added. + """ + return self.with_(findings=(*self.findings, *records)) + + def with_coverage(self, question_name: str, rate: float) -> Self: + """Return a new report with one coverage entry set. + + Parameters + ---------- + question_name : str + Anchor name. + rate : float + Coverage rate in ``[0.0, 1.0]``. + + Returns + ------- + DatasetReport + New report with the entry set or replaced. + """ + new_coverage = dict(self.coverage) + new_coverage[question_name] = rate + return self.with_(coverage=new_coverage) + + def with_missing_embedding(self, item_id: str) -> Self: + """Return a new report flagging one item as missing an embedding. + + If ``item_id`` is already flagged the report is returned + unchanged (the missing-embedding list is a set semantically). + + Parameters + ---------- + item_id : str + The item id that lacked an embedding. + + Returns + ------- + DatasetReport + New report with the item recorded. + """ + if item_id in self.items_missing_embeddings: + return self + return self.with_( + items_missing_embeddings=(*self.items_missing_embeddings, item_id) + ) + + @property + def has_warnings(self) -> bool: + """Whether any warning-level findings exist.""" + return any(f.level == DiagnosticLevel.WARNING for f in self.findings) + + @property + def has_errors(self) -> bool: + """Whether any error-level findings exist.""" + return any(f.level == DiagnosticLevel.ERROR for f in self.findings) + + @property + def warnings(self) -> tuple[DiagnosticRecord, ...]: + """All warning-level findings, in discovery order.""" + return tuple(f for f in self.findings if f.level == DiagnosticLevel.WARNING) + + @property + def errors(self) -> tuple[DiagnosticRecord, ...]: + """All error-level findings, in discovery order.""" + return tuple(f for f in self.findings if f.level == DiagnosticLevel.ERROR) + + def by_category(self, category: str) -> tuple[DiagnosticRecord, ...]: + """Filter findings by category tag. + + Parameters + ---------- + category : str + Category tag to filter on. + + Returns + ------- + tuple[DiagnosticRecord, ...] + Matching findings, in discovery order. + """ + return tuple(f for f in self.findings if f.category == category) + + def summary(self) -> str: + """Produce a human-readable multi-line summary. + + Returns + ------- + str + A summary string suitable for logging. + """ + lines = [ + f"DatasetReport: {self.n_items} items, {self.n_records_input} records", + f" encoded: {self.n_records_encoded}, dropped: {self.n_records_dropped}", + ] + + if self.items_missing_embeddings: + lines.append( + f" items missing embeddings: {len(self.items_missing_embeddings)}" + ) + + if self.coverage: + lines.append(" coverage:") + for name, rate in sorted(self.coverage.items()): + lines.append(f" {name}: {rate:.1%}") + + n_warn = len(self.warnings) + n_err = len(self.errors) + if n_warn or n_err: + lines.append(f" warnings: {n_warn}, errors: {n_err}") + + return "\n".join(lines) + + +@runtime_checkable +class RecordLike(Protocol): + """Structural type for records consumed by the validator. + + Any object with the three attributes below conforms. The bead + :class:`~bead.evaluation.reliability.AnnotationRecord` is a + canonical example. + + Attributes + ---------- + item_id : str + Identifier of the annotation item. + response_label : str + Annotator's response label. + question_name : str + Anchor name of the question being answered. + """ + + item_id: str + response_label: str + question_name: str + + +@dataclass(frozen=True) +class ConditionalObservationValidator: + """Verify that conditional responses respect protocol dependencies. + + For every family in a protocol with non-empty + :attr:`~bead.protocol.family.QuestionFamily.depends_on`, the + validator checks two things: + + 1. *Dependency presence*: each item with a response on the + conditional question must also have a response on every + upstream question. + 2. *Dependency value* (optional): when ``conditioning_values`` is + supplied for the conditional anchor, the upstream response + must be one of the allowed labels. + + Findings are emitted as :class:`DiagnosticRecord` instances at the + :attr:`DiagnosticLevel.WARNING` level. + + Parameters + ---------- + conditioning_values : Mapping[str, set[str]] | None, optional + Per-conditional-anchor mapping from upstream label set to + validity. When omitted the validator only checks dependency + presence. Defaults to ``None``. + + Attributes + ---------- + conditioning_values : Mapping[str, set[str]] + Conditioning-value table (immutable view). + """ + + conditioning_values: Mapping[str, set[str]] = field(default_factory=dict) + + def validate( + self, + records_by_question: Mapping[str, Sequence[RecordLike]], + protocol: AnnotationProtocol, + ) -> tuple[DiagnosticRecord, ...]: + """Check conditional-observation consistency for a protocol. + + Parameters + ---------- + records_by_question : Mapping[str, Sequence[record-like]] + Records grouped by anchor name. Each record must expose + ``item_id``, ``response_label``, and ``question_name`` + attributes. + protocol : AnnotationProtocol + The protocol whose dependency edges drive the validation. + + Returns + ------- + tuple[DiagnosticRecord, ...] + Warning-level findings for any inconsistencies detected. + """ + findings: list[DiagnosticRecord] = [] + + response_lookup: dict[str, dict[str, str]] = {} + for q_name, records in records_by_question.items(): + lookup: dict[str, str] = {} + for rec in records: + lookup[rec.item_id] = rec.response_label + response_lookup[q_name] = lookup + + for family in protocol.families: + if not family.depends_on: + continue + + obs_responses = response_lookup.get(family.name, {}) + + for item_id in obs_responses: + for dep_name in family.depends_on: + dep_responses = response_lookup.get(dep_name, {}) + + if item_id not in dep_responses: + findings.append( + DiagnosticRecord( + level=DiagnosticLevel.WARNING, + category="conditional_missing_dependency", + message=( + f"conditional observation " + f"{family.name!r} has response for " + f"item {item_id!r} but conditioning " + f"observation {dep_name!r} has no " + f"response" + ), + item_id=item_id, + question_name=family.name, + ) + ) + continue + + if family.name in self.conditioning_values: + valid_vals = self.conditioning_values[family.name] + dep_label = dep_responses[item_id] + if dep_label not in valid_vals: + findings.append( + DiagnosticRecord( + level=DiagnosticLevel.WARNING, + category="conditional_inapplicable", + message=( + f"conditional observation " + f"{family.name!r} has response for " + f"item {item_id!r} but conditioning " + f"observation {dep_name!r} has value " + f"{dep_label!r} (expected one of " + f"{sorted(valid_vals)})" + ), + item_id=item_id, + question_name=family.name, + ) + ) + + return tuple(findings) diff --git a/bead/protocol/drift.py b/bead/protocol/drift.py new file mode 100644 index 0000000..66d9c5b --- /dev/null +++ b/bead/protocol/drift.py @@ -0,0 +1,550 @@ +"""Drift validation: the type-checker for realized prompts. + +A :class:`DriftGuard` verifies that a realized prompt still inhabits +the type defined by its :class:`~bead.protocol.anchor.SemanticAnchor`. +Without drift control an LM paraphraser, or even a rule-based +selector, may produce prompts that subtly change what is being +measured. + +Three validators are provided: + +- :class:`StructuralDriftValidator` checks that required span + references and keywords appear in the realization and that the + question is well-formed. +- :class:`EmbeddingDriftValidator` checks that the embedding of the + realized prompt is within a configured cosine distance of the + anchor's canonical-prompt embedding. +- :class:`PerplexityDriftValidator` flags realizations whose language- + model perplexity exceeds a configured ceiling. + +These compose under a :class:`DriftGuard`, which runs all configured +validators and aggregates their findings: a realization passes only +when every validator passes. +""" + +from __future__ import annotations + +import math +from collections.abc import Sequence +from dataclasses import dataclass, field +from typing import Protocol, runtime_checkable + +from bead.data.base import BeadBaseModel +from bead.labels import find_label_names +from bead.protocol.anchor import SemanticAnchor +from bead.protocol.context import ProtocolContext + + +@runtime_checkable +class EmbeddingAdapter(Protocol): + """Structural type for objects that can embed text. + + Conforms to bead :class:`~bead.items.adapters.ModelAdapter` and to + any other object exposing a ``get_embedding`` method that returns + a sequence of floats. + + Examples + -------- + >>> class StubEmbedder: + ... def get_embedding(self, text: str) -> Sequence[float]: + ... return (1.0, 0.0) + >>> isinstance(StubEmbedder(), EmbeddingAdapter) + True + """ + + def get_embedding(self, text: str) -> Sequence[float]: + """Embed ``text`` to a fixed-length sequence of floats. + + Parameters + ---------- + text : str + Text to embed. + + Returns + ------- + Sequence[float] + Embedding vector, treated as a flat sequence of floats. + """ + ... + + +@runtime_checkable +class PerplexityAdapter(Protocol): + """Structural type for objects that can score text perplexity. + + Conforms to bead :class:`~bead.items.adapters.ModelAdapter` and to + any other object exposing a ``compute_perplexity`` method. + """ + + def compute_perplexity(self, text: str) -> float: + """Compute the perplexity of ``text`` under the backend. + + Parameters + ---------- + text : str + Text to score. + + Returns + ------- + float + Perplexity in the open interval ``(0, +inf)``. + """ + ... + + +class DriftScore(BeadBaseModel): + """Result of one or more drift validation checks. + + Attributes + ---------- + passed : bool + Whether the realization passes the validators that produced + this score. Defaults to ``True``. + structural_ok : bool + Whether structural constraints are satisfied. Defaults to + ``True``. + embedding_distance : float | None + Cosine distance from the canonical-prompt embedding, if an + embedding validator ran. Defaults to ``None``. + perplexity : float | None + Perplexity of the realized prompt under the validating + language model, if a perplexity validator ran. Defaults to + ``None``. + findings : tuple[str, ...] + Human-readable descriptions of any issues found. Defaults to + the empty tuple. + """ + + passed: bool = True + structural_ok: bool = True + embedding_distance: float | None = None + perplexity: float | None = None + findings: tuple[str, ...] = () + + +@runtime_checkable +class DriftValidator(Protocol): + """Protocol for a single drift-validation check. + + Examples + -------- + A minimal conforming validator: + + >>> class AlwaysPasses: + ... def validate(self, realization, anchor, context): + ... return DriftScore(passed=True) + >>> isinstance(AlwaysPasses(), DriftValidator) + True + """ + + def validate( + self, + realization: str, + anchor: SemanticAnchor, + context: ProtocolContext, + ) -> DriftScore: + """Check the realization against the anchor. + + Parameters + ---------- + realization : str + The realized prompt string. + anchor : SemanticAnchor + The semantic specification. + context : ProtocolContext + The annotation context. + + Returns + ------- + DriftScore + Validation result. + """ + ... + + +@dataclass(frozen=True) +class StructuralDriftValidator: + """Validate structural properties of a realized prompt. + + Checks that: + + 1. All required span labels appear as ``[[label]]`` references. + 2. Required keywords appear somewhere in the prompt. + 3. The prompt ends with appropriate punctuation. + 4. The prompt is not trivially short. + + Parameters + ---------- + min_length : int, optional + Minimum non-whitespace character length for a valid prompt. + Defaults to ``15``. + require_question_mark : bool, optional + Whether the realization must end with ``?``. Defaults to + ``True``. + keyword_case_sensitive : bool, optional + Whether keyword checks are case-sensitive. Defaults to + ``False``. + + Attributes + ---------- + min_length : int + Minimum prompt length. + require_question_mark : bool + Whether the trailing ``?`` is required. + keyword_case_sensitive : bool + Whether keyword matches are case-sensitive. + """ + + min_length: int = 15 + require_question_mark: bool = True + keyword_case_sensitive: bool = False + + def validate( + self, + realization: str, + anchor: SemanticAnchor, + context: ProtocolContext, # noqa: ARG002 + ) -> DriftScore: + """Run the structural checks against a realization. + + Parameters + ---------- + realization : str + The realized prompt string. + anchor : SemanticAnchor + The semantic specification supplying required labels and + keywords. + context : ProtocolContext + The annotation context (unused by this validator but + required by the :class:`DriftValidator` protocol). + + Returns + ------- + DriftScore + Score with ``structural_ok`` set and any failures listed + in ``findings``. + """ + findings: list[str] = [] + structural_ok = True + + stripped = realization.strip() + + if len(stripped) < self.min_length: + findings.append( + f"Realization too short ({len(stripped)} chars, " + f"minimum {self.min_length})" + ) + structural_ok = False + + found_labels = find_label_names(realization) + for required in anchor.required_span_labels: + if required not in found_labels: + findings.append(f"Missing required span reference [[{required}]]") + structural_ok = False + + check_text = realization if self.keyword_case_sensitive else realization.lower() + for keyword in anchor.required_keywords: + check_keyword = keyword if self.keyword_case_sensitive else keyword.lower() + if check_keyword not in check_text: + findings.append(f"Missing required keyword: {keyword!r}") + structural_ok = False + + if self.require_question_mark and not stripped.endswith("?"): + findings.append("Realization does not end with '?'") + structural_ok = False + + return DriftScore( + passed=structural_ok, + structural_ok=structural_ok, + findings=tuple(findings), + ) + + +def _cosine_distance(a: tuple[float, ...], b: tuple[float, ...]) -> float: + """Compute cosine distance ``1 - cos(a, b)`` between two vectors. + + Parameters + ---------- + a, b : tuple[float, ...] + Equal-length vectors. + + Returns + ------- + float + Cosine distance in ``[0.0, 2.0]``. Returns ``1.0`` when either + vector has zero norm (treated as orthogonal). + + Raises + ------ + ValueError + If ``a`` and ``b`` have different lengths. + """ + if len(a) != len(b): + raise ValueError(f"Vector dimension mismatch: {len(a)} vs {len(b)}") + dot = sum(ai * bi for ai, bi in zip(a, b, strict=True)) + norm_a = math.sqrt(sum(ai * ai for ai in a)) + norm_b = math.sqrt(sum(bi * bi for bi in b)) + if norm_a == 0.0 or norm_b == 0.0: + return 1.0 + return 1.0 - dot / (norm_a * norm_b) + + +class EmbeddingDriftValidator: + """Validate that a realization is semantically close to the anchor. + + Computes cosine distance between the realization embedding and the + anchor's canonical-prompt embedding (either pre-computed in the + anchor or computed on demand from the canonical prompt). The + realization passes when the distance is at most the configured + maximum (or the anchor's :attr:`~SemanticAnchor.max_drift`, if no + explicit maximum is set). + + Embeddings are obtained from any object conforming to the + :class:`EmbeddingAdapter` Protocol, which includes the bead + :class:`~bead.items.adapters.ModelAdapter` family. + + Parameters + ---------- + adapter : EmbeddingAdapter + Adapter exposing ``get_embedding(text)``. The returned + sequence is treated as a flat vector and converted to a + ``tuple[float, ...]``. + max_distance : float | None, optional + Override for the anchor's ``max_drift`` value. Defaults to + ``None`` (use the anchor's own value). + + Attributes + ---------- + max_distance : float | None + Configured override, or ``None`` to defer to the anchor. + """ + + def __init__( + self, + adapter: EmbeddingAdapter, + *, + max_distance: float | None = None, + ) -> None: + self._adapter = adapter + self.max_distance = max_distance + + def _embed(self, text: str) -> tuple[float, ...]: + """Embed ``text`` via the wrapped adapter and coerce to tuple.""" + emb = self._adapter.get_embedding(text) + return tuple(float(x) for x in emb) + + def validate( + self, + realization: str, + anchor: SemanticAnchor, + context: ProtocolContext, # noqa: ARG002 + ) -> DriftScore: + """Score the realization by cosine distance from the anchor. + + Parameters + ---------- + realization : str + The realized prompt string. + anchor : SemanticAnchor + The semantic specification supplying the canonical prompt + (and optionally a pre-computed embedding center and a + ``max_drift`` value). + context : ProtocolContext + The annotation context (unused by this validator but + required by the :class:`DriftValidator` protocol). + + Returns + ------- + DriftScore + Score with ``embedding_distance`` set; ``passed`` is + ``True`` iff the distance is within the configured + maximum. + """ + if anchor.embedding_center is not None: + canonical = anchor.embedding_center + else: + canonical = self._embed(anchor.canonical_prompt) + + realization_emb = self._embed(realization) + distance = _cosine_distance(canonical, realization_emb) + + max_dist = ( + self.max_distance if self.max_distance is not None else anchor.max_drift + ) + passed = distance <= max_dist + + findings: tuple[str, ...] = () + if not passed: + findings = ( + f"Embedding distance {distance:.3f} exceeds maximum {max_dist:.3f}", + ) + + return DriftScore( + passed=passed, + embedding_distance=distance, + findings=findings, + ) + + +class PerplexityDriftValidator: + """Validate that a realization has acceptable language-model perplexity. + + Wraps any object conforming to the :class:`PerplexityAdapter` + Protocol (which includes the bead + :class:`~bead.items.adapters.ModelAdapter` family). The realization + passes when its perplexity is at most the configured ceiling. + Useful for catching ungrammatical or otherwise unnatural + LM-generated paraphrases that might still pass structural and + embedding checks. + + Parameters + ---------- + adapter : PerplexityAdapter + Adapter exposing ``compute_perplexity(text) -> float``. + max_perplexity : float + Maximum allowed perplexity. Realizations with perplexity above + this value fail. + + Attributes + ---------- + max_perplexity : float + The configured perplexity ceiling. + """ + + def __init__( + self, + adapter: PerplexityAdapter, + *, + max_perplexity: float, + ) -> None: + if max_perplexity <= 0.0: + raise ValueError("max_perplexity must be positive") + self._adapter = adapter + self.max_perplexity = max_perplexity + + def validate( + self, + realization: str, + anchor: SemanticAnchor, # noqa: ARG002 + context: ProtocolContext, # noqa: ARG002 + ) -> DriftScore: + """Score the realization by language-model perplexity. + + Parameters + ---------- + realization : str + The realized prompt string. + anchor : SemanticAnchor + The semantic specification (unused by this validator). + context : ProtocolContext + The annotation context (unused by this validator). + + Returns + ------- + DriftScore + Score with ``perplexity`` set; ``passed`` is ``True`` iff + ``perplexity <= max_perplexity``. + """ + perplexity = float(self._adapter.compute_perplexity(realization)) + passed = perplexity <= self.max_perplexity + + findings: tuple[str, ...] = () + if not passed: + findings = ( + f"Perplexity {perplexity:.2f} exceeds maximum " + f"{self.max_perplexity:.2f}", + ) + + return DriftScore( + passed=passed, + perplexity=perplexity, + findings=findings, + ) + + +@dataclass +class DriftGuard: + """Composite drift validator. + + Runs every configured validator and aggregates their results: the + aggregate :class:`DriftScore` ``passed`` field is ``True`` only + when every validator passes. Findings from all validators are + collected in order. ``embedding_distance`` and ``perplexity`` are + populated from the last validator that set them. + + Attributes + ---------- + validators : list[DriftValidator] + Mutable list of configured validators. Defaults to the empty + list; calls to :meth:`check` on a guard with no validators + always pass. + """ + + validators: list[DriftValidator] = field(default_factory=list) + + def add(self, validator: DriftValidator) -> None: + """Append a validator to the guard. + + Parameters + ---------- + validator : DriftValidator + The validator to add. + """ + self.validators.append(validator) + + def check( + self, + realization: str, + anchor: SemanticAnchor, + context: ProtocolContext, + ) -> DriftScore: + """Run every validator and return an aggregated score. + + Parameters + ---------- + realization : str + The realized prompt string. + anchor : SemanticAnchor + The semantic specification. + context : ProtocolContext + The annotation context. + + Returns + ------- + DriftScore + Aggregate score. ``passed`` is ``True`` iff every + validator passes; ``findings`` concatenates all + validator-level findings; ``embedding_distance`` and + ``perplexity`` are taken from the validators that set + them. + """ + all_findings: list[str] = [] + all_passed = True + structural_ok = True + embedding_distance: float | None = None + perplexity: float | None = None + + for validator in self.validators: + score = validator.validate(realization, anchor, context) + all_findings.extend(score.findings) + + if not score.passed: + all_passed = False + if not score.structural_ok: + structural_ok = False + if score.embedding_distance is not None: + embedding_distance = score.embedding_distance + if score.perplexity is not None: + perplexity = score.perplexity + + return DriftScore( + passed=all_passed, + structural_ok=structural_ok, + embedding_distance=embedding_distance, + perplexity=perplexity, + findings=tuple(all_findings), + ) + + def __len__(self) -> int: + """Return the number of configured validators.""" + return len(self.validators) diff --git a/bead/protocol/encoding.py b/bead/protocol/encoding.py new file mode 100644 index 0000000..8318ebe --- /dev/null +++ b/bead/protocol/encoding.py @@ -0,0 +1,262 @@ +"""Response-space encodings for probabilistic modeling. + +Bridges the annotation-side :class:`~bead.protocol.anchor.ResponseSpace` +representation and a model-ready description of a response scale, +providing the index-to-label mapping and scale-type metadata that +downstream modeling code needs. + +This module supports three response scale types: + +- *Binary*: two unordered options (for example + ``("change", "no_change")``). Naturally modeled via Bernoulli + likelihoods. +- *Ordinal*: ordered options on a Likert-like scale (for example a + five-point Likert scale from ``"definitely no"`` to ``"definitely + yes"``). Naturally modeled via cumulative-link (ordered logistic) + likelihoods. +- *Nominal*: unordered multi-option (for example a categorical choice + among unordered alternatives). Naturally modeled via softmax + categorical likelihoods. + +The encoding itself is likelihood-agnostic. It does *not* select a +likelihood family; downstream modeling code (for example +:mod:`bead.active_learning.models`) chooses the appropriate model +class based on the scale type. +""" + +from __future__ import annotations + +from enum import StrEnum +from typing import Self + +import didactic.api as dx + +from bead.data.base import BeadBaseModel +from bead.protocol.anchor import ResponseSpace, SemanticPoles + + +class ScaleType(StrEnum): + """Classification of response-scale structure. + + Attributes + ---------- + BINARY : str + Two unordered options. Wire value: ``"binary"``. + ORDINAL : str + Ordered options forming an ordinal scale. Wire value: + ``"ordinal"``. + NOMINAL : str + Unordered multi-option scale. Wire value: ``"nominal"``. + """ + + BINARY = "binary" + ORDINAL = "ordinal" + NOMINAL = "nominal" + + +class ResponseEncoding(BeadBaseModel): + """Encoding of a response space for probabilistic modeling. + + Bridges the annotation-side :class:`ResponseSpace` and a + modeling-side representation, providing the index-to-label + mapping and scale-type metadata needed by both systems. + + Attributes + ---------- + name : str + Identifier for this encoding (typically the anchor name, for + example ``"completion"``). + n_levels : int + Number of response categories. Must equal ``len(labels)``. + scale_type : ScaleType + Whether the scale is binary, ordinal, or nominal. + labels : tuple[str, ...] + Human-readable labels for each index, in order. + semantic_poles : SemanticPoles | None + The two participant-facing endpoints of the scale, if ordered + (for example ``SemanticPoles(low="definitely no", + high="definitely yes")``). Defaults to ``None``. + + Examples + -------- + >>> enc = ResponseEncoding( + ... name="completion", + ... n_levels=5, + ... scale_type=ScaleType.ORDINAL, + ... labels=("definitely no", "probably no", "unsure", + ... "probably yes", "definitely yes"), + ... semantic_poles=SemanticPoles( + ... low="definitely no", high="definitely yes" + ... ), + ... ) + >>> enc.label_to_index("probably yes") + 3 + >>> enc.index_to_label(0) + 'definitely no' + >>> enc.is_ordinal + True + + See Also + -------- + encode_response_space : Build an encoding from a + :class:`ResponseSpace`. + """ + + name: str + n_levels: int + scale_type: ScaleType + labels: tuple[str, ...] + semantic_poles: dx.Embed[SemanticPoles] | None = None + + @dx.model_validator(mode="after") + def _check_levels_match_labels(self) -> Self: + """Enforce ``n_levels == len(labels)`` and label uniqueness.""" + if self.n_levels != len(self.labels): + raise ValueError( + f"n_levels ({self.n_levels}) does not match " + f"len(labels) ({len(self.labels)}) for encoding " + f"{self.name!r}" + ) + if len(set(self.labels)) != len(self.labels): + raise ValueError( + f"Duplicate labels in encoding {self.name!r}: {self.labels}" + ) + if self.scale_type == ScaleType.BINARY and self.n_levels != 2: + raise ValueError( + f"BINARY scale must have exactly 2 levels, got " + f"{self.n_levels} in encoding {self.name!r}" + ) + return self + + @property + def is_ordinal(self) -> bool: + """Whether the response scale is ordered.""" + return self.scale_type == ScaleType.ORDINAL + + @property + def is_binary(self) -> bool: + """Whether the response scale is binary.""" + return self.scale_type == ScaleType.BINARY + + @property + def is_nominal(self) -> bool: + """Whether the response scale is unordered multi-option.""" + return self.scale_type == ScaleType.NOMINAL + + def label_to_index(self, label: str) -> int: + """Convert a response label to its integer index. + + Parameters + ---------- + label : str + The response label string. + + Returns + ------- + int + The 0-based index of the label. + + Raises + ------ + ValueError + If the label is not in the encoding. + """ + try: + return self.labels.index(label) + except ValueError: + raise ValueError( + f"Label {label!r} not found in encoding {self.name!r}. " + f"Valid labels: {self.labels}" + ) from None + + def index_to_label(self, index: int) -> str: + """Convert an integer index to its response label. + + Parameters + ---------- + index : int + The 0-based index. + + Returns + ------- + str + The response label at that index. + + Raises + ------ + IndexError + If the index is out of range for this encoding. + """ + if index < 0 or index >= len(self.labels): + raise IndexError( + f"Index {index} out of range for encoding {self.name!r} " + f"with {len(self.labels)} levels." + ) + return self.labels[index] + + +def _classify_scale(response_space: ResponseSpace) -> ScaleType: + """Determine the :class:`ScaleType` of a response space. + + A two-option, unordered space is classified as binary; otherwise an + ordered space is ordinal and an unordered space is nominal. + + Parameters + ---------- + response_space : ResponseSpace + The response space to classify. + + Returns + ------- + ScaleType + The classified scale type. + """ + if len(response_space.options) == 2 and not response_space.is_ordered: + return ScaleType.BINARY + if response_space.is_ordered: + return ScaleType.ORDINAL + return ScaleType.NOMINAL + + +def encode_response_space( + name: str, + response_space: ResponseSpace, +) -> ResponseEncoding: + """Build a :class:`ResponseEncoding` from a :class:`ResponseSpace`. + + This is the primary bridge from the protocol layer to the modeling + layer. The resulting encoding shares its labels with the response + space and inherits the space's ordering as a :class:`ScaleType`. + + Parameters + ---------- + name : str + Name for the encoding (typically the anchor name, for example + ``"completion"``). + response_space : ResponseSpace + The response space to encode. + + Returns + ------- + ResponseEncoding + The modeling-side encoding. + + Examples + -------- + >>> rs = ResponseSpace( + ... options=("no", "yes"), is_ordered=False + ... ) + >>> enc = encode_response_space("dynamicity", rs) + >>> enc.scale_type + + >>> enc.is_binary + True + """ + scale_type = _classify_scale(response_space) + return ResponseEncoding( + name=name, + n_levels=len(response_space.options), + scale_type=scale_type, + labels=response_space.options, + semantic_poles=response_space.semantic_poles, + ) diff --git a/bead/protocol/family.py b/bead/protocol/family.py new file mode 100644 index 0000000..7eda7fe --- /dev/null +++ b/bead/protocol/family.py @@ -0,0 +1,395 @@ +"""Question families and annotation protocols. + +A :class:`QuestionFamily` is a dependent function +``Pi(ctx : ProtocolContext). Question(ctx)``: for each context it +produces a valid, drift-checked +:class:`~bead.protocol.family.QuestionRealization`. + +An :class:`AnnotationProtocol` is the iterated dependent product + + Sigma(a_1 : Q_1(ctx)). Sigma(a_2 : Q_2(ctx, a_1)). ... Q_n(ctx, ...) + +a sequence of question families where later families may condition on +the responses to earlier ones. The dependency edges between families +are recorded explicitly in :attr:`QuestionFamily.depends_on`, which +:class:`~bead.protocol.diagnostics.ConditionalObservationValidator` +consults to check the integrity of conditional responses. +""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field + +import didactic.api as dx + +from bead.data.base import BeadBaseModel +from bead.protocol.anchor import SemanticAnchor +from bead.protocol.context import ProtocolContext +from bead.protocol.drift import DriftGuard, DriftScore +from bead.protocol.realization import RealizationStrategy, TemplateRealization + +ApplicabilityPredicate = Callable[[ProtocolContext], bool] +"""Type alias for predicates determining when a family applies.""" + + +def _always_applicable(_ctx: ProtocolContext) -> bool: + """Default applicability: family applies to every context.""" + return True + + +class QuestionRealization(BeadBaseModel): + """A realized question paired with its provenance. + + This is the dependent pair ``Sigma(ctx). Question(ctx)``: a + concrete prompt together with the context that produced it and + evidence of its validity (a :class:`~bead.protocol.drift.DriftScore`). + + Attributes + ---------- + prompt : str + The realized prompt string. May contain ``[[label]]`` + references for downstream rendering. + anchor : SemanticAnchor + The semantic specification this question satisfies. + context : ProtocolContext + The context that parameterized the realization. + drift_score : DriftScore | None + Result of drift validation, if a guard was applied. Defaults + to ``None``. + strategy_name : str + Name of the realization strategy that produced this question. + Defaults to the empty string. + """ + + prompt: str + anchor: dx.Embed[SemanticAnchor] + context: dx.Embed[ProtocolContext] + drift_score: dx.Embed[DriftScore] | None = None + strategy_name: str = "" + + @property + def passed_drift_check(self) -> bool: + """Whether the realization passed drift validation. + + ``True`` when no drift score is attached (no validation was + run) or when the attached score's ``passed`` flag is ``True``. + """ + if self.drift_score is None: + return True + return self.drift_score.passed + + +@dataclass +class QuestionFamily: + """Dependent function from contexts to realized questions. + + For each :class:`ProtocolContext`, a family produces a + :class:`QuestionRealization` by: + + 1. Checking applicability (is this question relevant for this + context?). + 2. Invoking the realization strategy to produce a prompt. + 3. Running drift validation, if a guard is configured. + 4. Falling back to the canonical prompt if the realization drifts + (when ``fallback_on_drift`` is enabled). + + Parameters + ---------- + anchor : SemanticAnchor + The semantic type of questions this family produces. + realization : RealizationStrategy | None, optional + Strategy producing a prompt for a given context. Defaults to + an unparameterized :class:`TemplateRealization`, which echoes + the anchor's canonical prompt. + drift_guard : DriftGuard | None, optional + Optional drift validator. Defaults to ``None``. + condition : ApplicabilityPredicate | None, optional + When to ask this question. ``None`` (the default) marks the + family as always applicable; any non-``None`` value sets + :attr:`is_always_applicable` to ``False``. + depends_on : tuple[str, ...], optional + Anchor names whose responses must precede this family in a + protocol. Read by + :class:`~bead.protocol.diagnostics.ConditionalObservationValidator`. + Defaults to the empty tuple. + fallback_on_drift : bool, optional + If ``True`` (the default), fall back to the canonical prompt + when drift validation fails. If ``False``, raise + :class:`ValueError`. + + Attributes + ---------- + anchor : SemanticAnchor + Configured anchor. + realization : RealizationStrategy + Configured realization strategy. + drift_guard : DriftGuard | None + Configured drift guard. + condition : ApplicabilityPredicate + Configured applicability predicate. + depends_on : tuple[str, ...] + Names of anchors this family depends on. + fallback_on_drift : bool + Whether to fall back on drift failure. + """ + + anchor: SemanticAnchor + realization: RealizationStrategy = field(default_factory=TemplateRealization) + drift_guard: DriftGuard | None = None + condition: ApplicabilityPredicate = field(default=_always_applicable) + depends_on: tuple[str, ...] = () + fallback_on_drift: bool = True + is_always_applicable: bool = field(init=False) + + def __post_init__(self) -> None: + """Record whether ``condition`` is the default predicate.""" + self.is_always_applicable = self.condition is _always_applicable + + @property + def name(self) -> str: + """Short name from the anchor.""" + return self.anchor.name + + def is_applicable(self, context: ProtocolContext) -> bool: + """Whether this family should be asked for the given context. + + Parameters + ---------- + context : ProtocolContext + Current annotation context. + + Returns + ------- + bool + ``True`` when the family applies. + """ + return self.condition(context) + + def realize(self, context: ProtocolContext) -> QuestionRealization: + """Produce a question for the given context. + + Parameters + ---------- + context : ProtocolContext + Current annotation context. + + Returns + ------- + QuestionRealization + The realized question with its drift score and provenance. + + Raises + ------ + ValueError + If drift validation fails and :attr:`fallback_on_drift` is + ``False``. + """ + prompt = self.realization.realize(self.anchor, context) + strategy_name = type(self.realization).__name__ + + drift_score: DriftScore | None = None + guard = self.drift_guard + if guard is not None: + score = guard.check(prompt, self.anchor, context) + if not score.passed: + if self.fallback_on_drift: + prompt = self.anchor.canonical_prompt + strategy_name = f"{strategy_name}->fallback" + score = guard.check(prompt, self.anchor, context) + else: + raise ValueError( + f"Drift validation failed for " + f"{self.anchor.name!r}: {list(score.findings)}" + ) + drift_score = score + + return QuestionRealization( + prompt=prompt, + anchor=self.anchor, + context=context, + drift_score=drift_score, + strategy_name=strategy_name, + ) + + +@dataclass +class AnnotationProtocol: + """A sequence of question families forming a complete protocol. + + Represents the iterated dependent product + + Sigma(a_1 : Q_1(ctx)). Sigma(a_2 : Q_2(ctx, a_1)). ... + + When realized, the protocol threads annotator responses through + the context so later families can condition on earlier answers. + + Parameters + ---------- + families : list[QuestionFamily] + Families in protocol order. + name : str, optional + Descriptive name for the protocol. Defaults to the empty + string. + + Attributes + ---------- + families : list[QuestionFamily] + Families in protocol order. + name : str + Descriptive name. + + Raises + ------ + ValueError + If two families share the same anchor name (anchor names must + be unique within a protocol), or if any family's + :attr:`~QuestionFamily.depends_on` references a family that + does not appear earlier in the sequence. + """ + + families: list[QuestionFamily] + name: str = "" + + def __post_init__(self) -> None: + """Validate uniqueness and forward-only ``depends_on`` edges.""" + seen: set[str] = set() + for family in self.families: + if family.name in seen: + raise ValueError(f"Duplicate anchor name in protocol: {family.name!r}") + for dep in family.depends_on: + if dep == family.name: + raise ValueError(f"Family {family.name!r} depends on itself") + if dep not in seen: + raise ValueError( + f"Family {family.name!r} depends on {dep!r}, " + f"which is not earlier in the protocol " + f"(known so far: {sorted(seen)})" + ) + seen.add(family.name) + + def append(self, family: QuestionFamily) -> None: + """Add a family to the end of the protocol. + + Parameters + ---------- + family : QuestionFamily + The family to append. + + Raises + ------ + ValueError + If a family with the same anchor name is already present, + or if any of its :attr:`~QuestionFamily.depends_on` + references a family not already in the protocol. + """ + existing = {f.name for f in self.families} + if family.name in existing: + raise ValueError(f"Duplicate anchor name in protocol: {family.name!r}") + if family.name in family.depends_on: + raise ValueError(f"Family {family.name!r} depends on itself") + for dep in family.depends_on: + if dep not in existing: + raise ValueError( + f"Family {family.name!r} depends on {dep!r}, " + f"which is not in the protocol (known: " + f"{sorted(existing)})" + ) + self.families.append(family) + + def family_by_name(self, name: str) -> QuestionFamily: + """Look up a family by its anchor name. + + Parameters + ---------- + name : str + The anchor name to look up. + + Returns + ------- + QuestionFamily + The matching family. + + Raises + ------ + KeyError + If no family with that name exists in the protocol. + """ + for family in self.families: + if family.name == name: + return family + raise KeyError( + f"No family named {name!r} in protocol " + f"(have: {[f.name for f in self.families]})" + ) + + def realize_all( + self, + context: ProtocolContext, + *, + responses: dict[str, str] | None = None, + ) -> list[QuestionRealization]: + """Realize all applicable families for a context. + + Threads responses through the context as the protocol is + traversed. When ``responses`` is provided it is injected + before any family is realized; otherwise, after each family is + realized, the first option of its response space is used as a + placeholder so downstream families can be exercised in dry-run + mode. + + Parameters + ---------- + context : ProtocolContext + Base annotation context. + responses : dict[str, str] | None, optional + Pre-supplied responses keyed by anchor name. Defaults to + ``None``. + + Returns + ------- + list[QuestionRealization] + Realized questions in protocol order, skipping families + whose :meth:`QuestionFamily.is_applicable` returns + ``False`` for the running context. + + Raises + ------ + ValueError + If ``responses`` references an anchor not in the protocol. + """ + if responses: + unknown = set(responses) - {f.name for f in self.families} + if unknown: + raise ValueError( + f"Responses reference unknown anchors: {sorted(unknown)}" + ) + + running_ctx = context + if responses: + for family in self.families: + if family.name in responses: + running_ctx = running_ctx.with_response( + family.name, responses[family.name] + ) + + results: list[QuestionRealization] = [] + for family in self.families: + if not family.is_applicable(running_ctx): + continue + + realization = family.realize(running_ctx) + results.append(realization) + + if family.anchor.name not in running_ctx.previous_responses: + options = family.anchor.response_space.options + if options: + running_ctx = running_ctx.with_response( + family.anchor.name, options[0] + ) + + return results + + def __len__(self) -> int: + """Return the number of families in the protocol.""" + return len(self.families) diff --git a/bead/protocol/items.py b/bead/protocol/items.py new file mode 100644 index 0000000..786e1ae --- /dev/null +++ b/bead/protocol/items.py @@ -0,0 +1,384 @@ +"""Bridge from the protocol layer to bead's item-construction layer. + +A :class:`~bead.protocol.QuestionFamily` declares the type-level shape +of a question; a :class:`~bead.protocol.QuestionRealization` is one +realization of that question for a particular +:class:`~bead.protocol.ProtocolContext`. To deploy realizations through +bead's experimental pipeline they must be packaged as +:class:`~bead.items.item_template.ItemTemplate` and +:class:`~bead.items.item.Item` instances. + +This module is the canonical bridge. It defines two mappings: + +- :func:`scale_type_to_task_type` — the single canonical translation + from :class:`~bead.protocol.ScaleType` to the + :class:`~bead.items.item_template.TaskType` literal used by item + templates and active-learning model selection. +- :func:`family_to_item_template` — build the per-family + :class:`ItemTemplate` (one template per anchor; the same template is + reused for every realization of that family). +- :func:`realization_to_item` — package a single + :class:`QuestionRealization` as an :class:`Item` bound to the + family's template, with sentence text and span metadata derived + from the realization's :class:`ProtocolContext`. +- :func:`protocol_to_item_templates` — return a name-keyed dict of + templates for an entire protocol. + +The mapping is total: every supported :class:`ScaleType` corresponds +to exactly one :class:`TaskType`, and every protocol family produces +exactly one :class:`ItemTemplate`. There is no per-task-type factory +in the protocol layer; the family + realization pair is the single +canonical way to build items for a protocol. +""" + +from __future__ import annotations + +from typing import Final + +from bead.items.item import Item +from bead.items.item_template import ( + ItemElement, + ItemTemplate, + JudgmentType, + PresentationSpec, + ScaleBounds, + ScalePointLabel, + TaskSpec, + TaskType, +) +from bead.items.spans import Span, SpanLabel, SpanSegment +from bead.protocol.context import ContextItem, ProtocolContext +from bead.protocol.encoding import ScaleType, encode_response_space +from bead.protocol.family import ( + AnnotationProtocol, + QuestionFamily, + QuestionRealization, +) + +_SCALE_TO_TASK: Final[dict[ScaleType, TaskType]] = { + ScaleType.BINARY: "binary", + ScaleType.ORDINAL: "ordinal_scale", + ScaleType.NOMINAL: "categorical", +} +"""The canonical :class:`ScaleType` → :class:`TaskType` mapping.""" + + +def scale_type_to_task_type(scale_type: ScaleType) -> TaskType: + """Translate a :class:`ScaleType` to its :class:`TaskType`. + + This is the single canonical mapping used by every part of bead + that bridges between protocol-layer encodings and item-layer task + types (item construction, active-learning model selection, jsPsych + deployment). + + Parameters + ---------- + scale_type : ScaleType + Protocol-layer scale type. + + Returns + ------- + TaskType + The matching :class:`TaskType` literal. + + Examples + -------- + >>> from bead.protocol.encoding import ScaleType + >>> scale_type_to_task_type(ScaleType.ORDINAL) + 'ordinal_scale' + """ + return _SCALE_TO_TASK[scale_type] + + +def family_to_item_template( + family: QuestionFamily, + *, + judgment_type: JudgmentType, + presentation_spec: PresentationSpec | None = None, +) -> ItemTemplate: + """Build the :class:`ItemTemplate` for a :class:`QuestionFamily`. + + The template's ``task_type`` is derived from the anchor's response + space via :func:`scale_type_to_task_type`. Ordinal scales + populate :attr:`TaskSpec.scale_bounds` (``0`` to ``n_levels - 1``) + and :attr:`TaskSpec.scale_labels` (one + :class:`ScalePointLabel` per option). Binary and nominal scales + populate :attr:`TaskSpec.options`. + + The ``prompt`` field of the template's :class:`TaskSpec` is the + anchor's canonical prompt (with ``[[label]]`` references intact); + individual realizations override the prompt at item-construction + time via the ``prompt`` rendered-element on the resulting + :class:`Item`. + + Parameters + ---------- + family : QuestionFamily + The family to bridge. + judgment_type : JudgmentType + Semantic property being measured (caller-supplied because + bead's :class:`JudgmentType` taxonomy is broader than + :class:`~bead.protocol.encoding.ScaleType`). + presentation_spec : PresentationSpec | None, optional + Custom presentation spec. Defaults to a fresh + :class:`PresentationSpec` with mode ``"static"``. + + Returns + ------- + ItemTemplate + Template with ``name`` set to the anchor name, ``task_type`` + derived from the scale, and ``elements`` covering ``"text"`` + (the sentence) and ``"prompt"`` (the realized question). + """ + encoding = encode_response_space(family.anchor.name, family.anchor.response_space) + task_type = scale_type_to_task_type(encoding.scale_type) + + if encoding.is_ordinal: + scale_bounds: ScaleBounds | None = ScaleBounds(min=0, max=encoding.n_levels - 1) + scale_labels = tuple( + ScalePointLabel(point=i, label=label) + for i, label in enumerate(encoding.labels) + ) + options: tuple[str, ...] | None = None + else: + scale_bounds = None + scale_labels = () + options = encoding.labels + + task_spec = TaskSpec( + prompt=family.anchor.canonical_prompt, + scale_bounds=scale_bounds, + scale_labels=scale_labels, + options=options, + ) + elements = ( + ItemElement( + element_type="text", + element_name="text", + content="", + order=0, + ), + ItemElement( + element_type="text", + element_name="prompt", + content=family.anchor.canonical_prompt, + order=1, + ), + ) + return ItemTemplate( + name=family.anchor.name, + description=family.anchor.description or None, + judgment_type=judgment_type, + task_type=task_type, + task_spec=task_spec, + presentation_spec=presentation_spec or PresentationSpec(), + elements=elements, + ) + + +def _spans_from_context( + realization: QuestionRealization, +) -> tuple[Span, ...]: + """Extract :class:`Span` objects from a realization's context. + + Builds one span per ``required_span_label`` on the anchor, taking + its token positions from the matching field on + :class:`ProtocolContext` (target span for the anchor's primary + label, dependent span for any other required label that matches a + dependent's ``head_lemma``). + + Parameters + ---------- + realization : QuestionRealization + The realized question carrying the context. + + Returns + ------- + tuple[Span, ...] + Spans keyed to required label names. Empty when the anchor + has no required span labels. + """ + anchor = realization.anchor + context = realization.context + if not anchor.required_span_labels: + return () + + spans: list[Span] = [] + label_to_dependent: dict[str, ContextItem] = { + d.head_lemma: d for d in context.dependents if d.head_lemma + } + + for label in sorted(anchor.required_span_labels): + if label_to_dependent and label in label_to_dependent: + dep = label_to_dependent[label] + if dep.span_positions: + spans.append( + Span( + span_id=f"{anchor.name}-{label}", + segments=( + SpanSegment( + element_name="text", + indices=tuple(i - 1 for i in dep.span_positions), + ), + ), + label=SpanLabel(label=label), + head_index=( + dep.head_position - 1 if dep.head_position > 0 else None + ), + ) + ) + continue + + if context.target_span_positions: + spans.append( + Span( + span_id=f"{anchor.name}-{label}", + segments=( + SpanSegment( + element_name="text", + indices=tuple(i - 1 for i in context.target_span_positions), + ), + ), + label=SpanLabel(label=label), + head_index=( + context.target_position - 1 + if context.target_position > 0 + else None + ), + ) + ) + + return tuple(spans) + + +def realization_to_item( + realization: QuestionRealization, + *, + item_template: ItemTemplate, +) -> Item: + """Package a :class:`QuestionRealization` as an :class:`Item`. + + The resulting :class:`Item` references ``item_template`` by id, + rendering the realization's prompt as the ``"prompt"`` element and + the context's sentence as the ``"text"`` element. Tokenized + elements are populated from + :attr:`ProtocolContext.tokens` when present; spans for the + anchor's ``required_span_labels`` are derived via + :func:`_spans_from_context`. + + Parameters + ---------- + realization : QuestionRealization + A realized question produced by + :meth:`QuestionFamily.realize`. + item_template : ItemTemplate + The template returned by :func:`family_to_item_template` for + the originating family. The bridge does not validate that the + template was produced from the same family — the caller is + responsible for matching them. + + Returns + ------- + Item + Item bound to the template, with the realization's prompt and + context materialized. + """ + context = realization.context + rendered_elements = { + "text": context.sentence, + "prompt": realization.prompt, + } + tokenized_elements: dict[str, tuple[str, ...]] = {} + if context.tokens: + tokenized_elements["text"] = context.tokens + + spans = _spans_from_context(realization) + + return Item( + item_template_id=item_template.id, + rendered_elements=rendered_elements, + tokenized_elements=tokenized_elements, + spans=spans, + ) + + +def protocol_to_item_templates( + protocol: AnnotationProtocol, + *, + judgment_type: JudgmentType, + presentation_spec: PresentationSpec | None = None, +) -> dict[str, ItemTemplate]: + """Build one :class:`ItemTemplate` per family in the protocol. + + Parameters + ---------- + protocol : AnnotationProtocol + The protocol whose families to translate. + judgment_type : JudgmentType + Common judgment type to assign to every template. + presentation_spec : PresentationSpec | None, optional + Common presentation spec; defaults to a fresh one per call. + + Returns + ------- + dict[str, ItemTemplate] + Mapping from family / anchor name to its :class:`ItemTemplate`. + """ + return { + family.name: family_to_item_template( + family, + judgment_type=judgment_type, + presentation_spec=presentation_spec, + ) + for family in protocol.families + } + + +def realize_protocol_to_items( + protocol: AnnotationProtocol, + context: ProtocolContext, + *, + judgment_type: JudgmentType, + item_templates: dict[str, ItemTemplate] | None = None, + responses: dict[str, str] | None = None, + presentation_spec: PresentationSpec | None = None, +) -> tuple[tuple[QuestionRealization, Item], ...]: + """Realize a protocol against one context, packaging items. + + Each applicable family is realized in protocol order; each + :class:`QuestionRealization` is paired with the :class:`Item` + produced by :func:`realization_to_item`. + + Parameters + ---------- + protocol : AnnotationProtocol + Protocol to realize. + context : ProtocolContext + Base context for realization. + judgment_type : JudgmentType + Judgment type assigned to every template. + item_templates : dict[str, ItemTemplate] | None, optional + Pre-built templates; built via + :func:`protocol_to_item_templates` when ``None``. + responses : dict[str, str] | None, optional + Pre-supplied responses threaded into the context. Defaults to + ``None``. + presentation_spec : PresentationSpec | None, optional + Common presentation spec when templates are built fresh. + + Returns + ------- + tuple[tuple[QuestionRealization, Item], ...] + For each applicable family, the ``(realization, item)`` pair + in protocol order. + """ + templates = item_templates or protocol_to_item_templates( + protocol, + judgment_type=judgment_type, + presentation_spec=presentation_spec, + ) + realizations = protocol.realize_all(context, responses=responses) + return tuple( + (r, realization_to_item(r, item_template=templates[r.anchor.name])) + for r in realizations + ) diff --git a/bead/protocol/realization.py b/bead/protocol/realization.py new file mode 100644 index 0000000..88c0214 --- /dev/null +++ b/bead/protocol/realization.py @@ -0,0 +1,483 @@ +"""Realization strategies: how dependent functions are computed. + +A :class:`RealizationStrategy` maps a +:class:`~bead.protocol.anchor.SemanticAnchor` and a +:class:`~bead.protocol.context.ProtocolContext` to a concrete prompt +string. It is the computational content of the dependent function +``Pi(ctx). Question(ctx)``. + +Three strategies are provided: + +- :class:`TemplateRealization`: a fixed template (the simplest + strategy and a safe fallback). +- :class:`ContextualTemplateRealization`: rule-based selection from + ranked template variants. +- :class:`LMRealization`: prompts a language model to paraphrase the + canonical question for the specific context. Should always be paired + with a :class:`~bead.protocol.drift.DriftGuard` to validate that the + paraphrase preserves semantic content. + +These classes carry callable fields (predicates, LM clients) so they +are plain frozen Python classes rather than +:class:`~bead.data.base.BeadBaseModel` subclasses; didactic Models do +not accept :class:`~collections.abc.Callable` field types. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +from bead.protocol.anchor import SemanticAnchor +from bead.protocol.context import ContextPredicate, ProtocolContext, always + +if TYPE_CHECKING: + from bead.items.cache import ModelOutputCache + + +@runtime_checkable +class RealizationStrategy(Protocol): + """Protocol for question realization. + + A realization strategy is the computational content of the + dependent function ``Pi(ctx). Question(ctx)``: it produces a + prompt string for a given anchor-and-context pair. + + Examples + -------- + A minimal conforming implementation: + + >>> class EchoCanonical: + ... def realize( + ... self, anchor, context + ... ): + ... return anchor.canonical_prompt + >>> isinstance(EchoCanonical(), RealizationStrategy) + True + """ + + def realize( + self, + anchor: SemanticAnchor, + context: ProtocolContext, + ) -> str: + """Produce a prompt string for the given anchor and context. + + Parameters + ---------- + anchor : SemanticAnchor + The semantic invariant to preserve. + context : ProtocolContext + The context to condition on. + + Returns + ------- + str + A prompt string, possibly containing ``[[label]]`` or + ``[[label|transform]]`` references. + """ + ... + + +@dataclass(frozen=True) +class TemplateVariant: + """A context-conditioned question template. + + Parameters + ---------- + template : str + Question template, possibly containing ``[[label]]`` or + ``[[label|transform]]`` references. + condition : ContextPredicate, optional + Returns ``True`` when this variant is appropriate for the + context. Variants are evaluated in priority order; the first + match wins. Defaults to :func:`always`. + priority : int, optional + Higher-priority variants are tried first. Use this to order + more-specific variants before less-specific ones. Defaults to + ``0``. + description : str, optional + Human-readable description for experimenters. Defaults to the + empty string. + + Attributes + ---------- + template : str + The template string. + condition : ContextPredicate + Variant-applicability predicate. + priority : int + Selection priority. + description : str + Human-readable description. + """ + + template: str + condition: ContextPredicate = field(default=always) + priority: int = 0 + description: str = "" + + +@dataclass(frozen=True) +class TemplateRealization: + """Fixed-template realization. + + Always returns the same template string regardless of context. The + simplest strategy and a safe fallback when context-dependent + phrasing is not needed. + + Parameters + ---------- + template : str | None, optional + Template string. When ``None``, the anchor's canonical prompt + is used at realization time. Defaults to ``None``. + + Attributes + ---------- + template : str | None + The configured template, or ``None`` to defer to the anchor. + """ + + template: str | None = None + + def realize( + self, + anchor: SemanticAnchor, + context: ProtocolContext, # noqa: ARG002 + ) -> str: + """Return the configured template or the canonical prompt. + + Parameters + ---------- + anchor : SemanticAnchor + The semantic invariant. Its ``canonical_prompt`` is used + when this strategy was constructed without an explicit + template. + context : ProtocolContext + The annotation context (unused by this strategy but + required by the :class:`RealizationStrategy` protocol). + + Returns + ------- + str + The realized prompt string. + """ + return self.template if self.template is not None else anchor.canonical_prompt + + +@dataclass(frozen=True) +class ContextualTemplateRealization: + """Rule-based selection from ranked template variants. + + Evaluates variant conditions in descending priority order and + returns the template of the first matching variant. Falls back to + a configurable fallback template (or the anchor's canonical prompt + if none is configured) when no variant matches. + + This is the recommended strategy for production use: it gives + experimenters fine-grained control over how questions adapt to + context while guaranteeing the output is one of a pre-approved set + of templates. + + Parameters + ---------- + variants : tuple[TemplateVariant, ...] + Candidate templates. They are evaluated in descending priority + order; ties are broken by registration order. + fallback : str | None, optional + Template used when no variant matches. When ``None``, the + anchor's canonical prompt is used. Defaults to ``None``. + + Attributes + ---------- + variants : tuple[TemplateVariant, ...] + The configured variants, sorted by descending priority. + fallback : str | None + Fallback template, or ``None`` to defer to the anchor. + """ + + variants: tuple[TemplateVariant, ...] + fallback: str | None = None + + def __post_init__(self) -> None: + """Sort variants by descending priority, stable on ties.""" + sorted_variants = tuple( + sorted(self.variants, key=lambda v: v.priority, reverse=True) + ) + object.__setattr__(self, "variants", sorted_variants) + + def realize( + self, + anchor: SemanticAnchor, + context: ProtocolContext, + ) -> str: + """Return the first matching variant's template, or the fallback. + + Parameters + ---------- + anchor : SemanticAnchor + The semantic invariant. + context : ProtocolContext + The annotation context tested against each variant's + condition. + + Returns + ------- + str + The template of the highest-priority matching variant, the + configured fallback if none match, or + ``anchor.canonical_prompt`` when no fallback is configured. + """ + for variant in self.variants: + if variant.condition(context): + return variant.template + return self.fallback if self.fallback is not None else anchor.canonical_prompt + + +@runtime_checkable +class LMClient(Protocol): + """Protocol for language-model completion. + + Any object with a ``complete`` method matching this signature can + serve as an LM backend for :class:`LMRealization`. The keyword + parameters ``temperature`` and ``max_tokens`` are required, since + :class:`LMRealization` always supplies them. + + Examples + -------- + A minimal stub for testing: + + >>> class StubClient: + ... def complete( + ... self, prompt: str, *, + ... temperature: float, max_tokens: int, + ... ) -> str: + ... return "Did the event reach an endpoint?" + >>> isinstance(StubClient(), LMClient) + True + """ + + def complete( + self, + prompt: str, + *, + temperature: float, + max_tokens: int, + ) -> str: + """Generate a completion for the given prompt. + + Parameters + ---------- + prompt : str + Full prompt including any system context. + temperature : float + Sampling temperature. + max_tokens : int + Maximum response length in tokens. + + Returns + ------- + str + Generated text. + """ + ... + + +_DEFAULT_SYSTEM_PROMPT = ( + "You are helping design annotation questions for a linguistics " + "experiment. You will be given a sentence, information about a " + "highlighted target, and a canonical question about a specific " + "linguistic property.\n\n" + "Your task: rephrase the canonical question so it is natural, " + "clear, and easy for a non-linguist to answer, while preserving:\n" + "1. The same semantic target (the same property is being " + "measured)\n" + "2. The same response scale\n" + "3. References to the highlighted target using [[label]] or " + "[[label|transform]] syntax wherever they appear in the canonical " + "question.\n\n" + "Output ONLY the rephrased question, nothing else." +) +"""Default system prompt for :class:`LMRealization`. + +Tuned to preserve the response scale and ``[[label]]`` references +required by structural drift validation. +""" + + +class LMRealization: + """LM-based question paraphrasing. + + Prompts a language model to rephrase the canonical question for + the specific annotation context. The LM receives the sentence, + target information, and canonical question as context, and + produces a paraphrase that should be more natural for the specific + sentence. + + This strategy should always be paired with a + :class:`~bead.protocol.drift.DriftGuard` to validate that the + paraphrase preserves semantic content. + + When ``cache`` is supplied (a :class:`~bead.items.cache.ModelOutputCache`), + realized prompts are stored under the + ``(model_name, "lm_completion", prompt=full_prompt)`` key. Repeated + calls with the same anchor-and-context pair avoid redundant LM + calls. The cache is the single canonical caching surface across + bead; this class does not maintain its own. + + Parameters + ---------- + client : LMClient + Language-model backend. + model_name : str + Identifier for the model behind ``client``. Used as the cache + key prefix. + cache : ModelOutputCache | None, optional + Output cache shared with the rest of bead. Pass ``None`` to + disable caching. Defaults to ``None``. + system_prompt : str, optional + System prompt controlling paraphrase behavior. Defaults to + :data:`_DEFAULT_SYSTEM_PROMPT`. + temperature : float, optional + Sampling temperature. Lower values are more conservative. + Defaults to ``0.3``. + max_tokens : int, optional + Maximum response length in tokens. Defaults to ``200``. + """ + + def __init__( + self, + client: LMClient, + *, + model_name: str, + cache: ModelOutputCache | None = None, + system_prompt: str = _DEFAULT_SYSTEM_PROMPT, + temperature: float = 0.3, + max_tokens: int = 200, + ) -> None: + self._client = client + self._model_name = model_name + self._cache = cache + self._system_prompt = system_prompt + self._temperature = temperature + self._max_tokens = max_tokens + + def _build_user_prompt( + self, anchor: SemanticAnchor, context: ProtocolContext + ) -> str: + """Construct the user-facing portion of the LM prompt. + + Parameters + ---------- + anchor : SemanticAnchor + The semantic invariant. + context : ProtocolContext + The annotation context. + + Returns + ------- + str + A multi-line string summarizing the context and the + canonical question. + """ + parts: list[str] = [ + f'Sentence: "{context.sentence}"', + f'Highlighted target: "{context.target_span_text}"', + f'Target lemma: "{context.target_lemma}"', + ] + + if context.dependents: + dep_strs = [ + f' - {d.head_lemma} ({d.head_upos}): "{d.span_text}"' + for d in context.dependents + ] + parts.append("Dependents:\n" + "\n".join(dep_strs)) + + parts.extend( + [ + f"Semantic target: {anchor.target_property}", + f"Description: {anchor.description}", + f'Canonical question: "{anchor.canonical_prompt}"', + f"Response scale: {list(anchor.response_space.options)}", + ] + ) + + if anchor.required_span_labels: + parts.append( + f"Required span references: {sorted(anchor.required_span_labels)}" + ) + + return "\n".join(parts) + + def realize( + self, + anchor: SemanticAnchor, + context: ProtocolContext, + ) -> str: + """Generate a context-adapted question via the LM. + + When a cache was supplied at construction time and a cached + result exists for the same prompt, the cached value is + returned without calling the LM. + + Parameters + ---------- + anchor : SemanticAnchor + Semantic specification. + context : ProtocolContext + Current annotation context. + + Returns + ------- + str + LM-generated prompt string. Surrounding quotes and + whitespace are stripped, and a trailing ``?`` is appended + when missing. + + Raises + ------ + RuntimeError + If the LM backend raises, or if the LM returns an empty + response. + """ + full_prompt = ( + f"{self._system_prompt}\n\n" + f"{self._build_user_prompt(anchor, context)}\n\n" + f"Rephrased question:" + ) + + if self._cache is not None: + cached = self._cache.get( + self._model_name, "lm_completion", prompt=full_prompt + ) + if isinstance(cached, str): + return cached + + try: + raw = self._client.complete( + full_prompt, + temperature=self._temperature, + max_tokens=self._max_tokens, + ) + except Exception as exc: + raise RuntimeError( + f"LM realization failed for anchor {anchor.name!r}: {exc}" + ) from exc + + cleaned = raw.strip().strip("\"'").strip() + if not cleaned: + raise RuntimeError( + f"LM realization returned an empty response for anchor {anchor.name!r}" + ) + if not cleaned.endswith("?"): + cleaned = f"{cleaned}?" + + if self._cache is not None: + self._cache.set( + self._model_name, + "lm_completion", + cleaned, + prompt=full_prompt, + ) + + return cleaned diff --git a/docs/api/active_learning.md b/docs/api/active_learning.md index 6f086c9..ae57ba8 100644 --- a/docs/api/active_learning.md +++ b/docs/api/active_learning.md @@ -26,6 +26,17 @@ Stage 6 of the bead pipeline: active learning with GLMM support and convergence show_root_heading: true show_source: false +## Model Registry + +Single canonical task-type → model-class and task-type → config-class +mapping used by the CLI training commands and the protocol-encoding +factory. + +::: bead.active_learning.models.registry + options: + show_root_heading: true + show_source: false + ## Base Model Interface ::: bead.active_learning.models.base diff --git a/docs/api/config.md b/docs/api/config.md index dcc1268..86cef67 100644 --- a/docs/api/config.md +++ b/docs/api/config.md @@ -1,11 +1,23 @@ # bead.config -Configuration system with Pydantic models for YAML-based pipeline orchestration. +Configuration system: didactic models for TOML/YAML pipeline +orchestration. -All configuration modules are documented here. See the [Configuration Guide](../user-guide/configuration.md) for usage examples. +All configuration modules are documented here. See the [Configuration +Guide](../user-guide/configuration.md) for usage examples. ::: bead.config options: show_root_heading: true show_source: false members_order: alphabetical + +## Annotation Protocol Configuration + +Declarative TOML/YAML configuration for the annotation-protocol layer: +anchor specs, family specs, drift settings, and protocol composition. + +::: bead.config.protocol + options: + show_root_heading: true + show_source: false diff --git a/docs/api/data_collection.md b/docs/api/data_collection.md index 3890603..bee6f11 100644 --- a/docs/api/data_collection.md +++ b/docs/api/data_collection.md @@ -22,3 +22,14 @@ Data retrieval and integration from JATOS and Prolific platforms. options: show_root_heading: true show_source: false + +## Annotation-Record Bridge + +Single canonical conversion from raw JATOS results to bead +:class:`~bead.evaluation.AnnotationRecord` instances, the input shape +consumed by every reliability and inter-annotator-agreement check. + +::: bead.data_collection.records + options: + show_root_heading: true + show_source: false diff --git a/docs/api/deployment.md b/docs/api/deployment.md index 93c5737..5a00b28 100644 --- a/docs/api/deployment.md +++ b/docs/api/deployment.md @@ -42,3 +42,15 @@ Stage 5 of the bead pipeline: jsPsych 8.x batch experiment generation for JATOS. options: show_root_heading: true show_source: false + +## Protocol-Layer Bridge + +Single canonical bridge from a configured +:class:`~bead.protocol.AnnotationProtocol` and a sequence of +:class:`~bead.protocol.ProtocolContext` records to a flat list of +jsPsych trial dicts. + +::: bead.deployment.protocol_trials + options: + show_root_heading: true + show_source: false diff --git a/docs/api/evaluation.md b/docs/api/evaluation.md index d5aaaad..d9eb048 100644 --- a/docs/api/evaluation.md +++ b/docs/api/evaluation.md @@ -15,3 +15,10 @@ Metrics and evaluation utilities for convergence detection and inter-annotator a options: show_root_heading: true show_source: false + +## Per-Annotator Reliability + +::: bead.evaluation.reliability + options: + show_root_heading: true + show_source: false diff --git a/docs/api/labels.md b/docs/api/labels.md new file mode 100644 index 0000000..9e21039 --- /dev/null +++ b/docs/api/labels.md @@ -0,0 +1,10 @@ +# bead.labels + +Single canonical home for the prompt-template label-reference syntax +(`[[label]]`, `[[label:text]]`, `[[label|transform]]`) used by drift +validation, item-construction, and jsPsych deployment. + +::: bead.labels + options: + show_root_heading: true + show_source: false diff --git a/docs/api/protocol.md b/docs/api/protocol.md new file mode 100644 index 0000000..698c0ce --- /dev/null +++ b/docs/api/protocol.md @@ -0,0 +1,67 @@ +# bead.protocol + +Annotation-protocol primitives: anchors as types, contexts as +dependent indices, realization strategies as computational content, +and drift validators as type-checkers. On top of these, question +families and protocols compose into a sequenced, conditional +annotation pipeline. + +## Anchors and Response Spaces + +::: bead.protocol.anchor + options: + show_root_heading: true + show_source: false + +## Encoding + +::: bead.protocol.encoding + options: + show_root_heading: true + show_source: false + +## Contexts + +::: bead.protocol.context + options: + show_root_heading: true + show_source: false + +## Realization Strategies + +::: bead.protocol.realization + options: + show_root_heading: true + show_source: false + +## Drift Validation + +::: bead.protocol.drift + options: + show_root_heading: true + show_source: false + +## Question Families and Protocols + +::: bead.protocol.family + options: + show_root_heading: true + show_source: false + +## Diagnostics + +::: bead.protocol.diagnostics + options: + show_root_heading: true + show_source: false + +## Item-Layer Bridge + +Single canonical bridge from a realized question to a fully-populated +:class:`~bead.items.item.Item` and from a configured protocol to the +per-family :class:`~bead.items.item_template.ItemTemplate` collection. + +::: bead.protocol.items + options: + show_root_heading: true + show_source: false diff --git a/docs/developer-guide/architecture.md b/docs/developer-guide/architecture.md index 877e29e..2c47166 100644 --- a/docs/developer-guide/architecture.md +++ b/docs/developer-guide/architecture.md @@ -19,6 +19,15 @@ bead implements a 6-stage pipeline for constructing, deploying, and analyzing la Each stage reads data from the previous stage using UUID references, processes it, adds metadata, and writes new data with its own UUIDs. This creates an unbroken chain of provenance from lexical resources to trained models. +The `bead.protocol` package sits *across* the 6-stage pipeline, not +inside it. Anchors, contexts, realization strategies, and drift +guards together define *what question is being asked* of an +annotator and *how it is phrased*; the resulting prompt strings flow +into Stage 3 item construction, and annotator responses flow back to +Stage 6 training and evaluation. The protocol layer is intentionally +domain-neutral and pipeline-orthogonal so that any annotation domain +can reuse the same anchor / drift / realization machinery. + ### Data Flow Example A typical experiment follows this data flow: @@ -158,6 +167,35 @@ bead consists of 17 top-level modules organized by function: - `convergence.py`: ConvergenceDetector (Krippendorff's alpha) - `interannotator.py`: InterAnnotatorMetrics (Cohen, Fleiss, Krippendorff) +- `reliability.py`: AnnotationRecord, AnnotatorReliability, + per-annotator Shannon-entropy diagnostics, and + `low_entropy_annotators` flagger + +**bead/protocol/** - Annotation-protocol primitives (cross-cutting layer) + +- `anchor.py`: SemanticAnchor (the *type* of a question), ResponseSpace, + SemanticPoles +- `context.py`: ProtocolContext (the dependent *index*), ContextItem, + context-predicate registry (`register_context_predicate`, + `get_context_predicate`, `list_context_predicates`) +- `realization.py`: RealizationStrategy Protocol with three + implementations (TemplateRealization, + ContextualTemplateRealization, LMRealization), TemplateVariant, + LMClient Protocol +- `drift.py`: DriftScore, DriftValidator Protocol, three concrete + validators (StructuralDriftValidator, EmbeddingDriftValidator, + PerplexityDriftValidator), DriftGuard composite, plus + EmbeddingAdapter and PerplexityAdapter Protocols for backends +- `family.py`: QuestionFamily (Pi(ctx). Question(ctx)), + AnnotationProtocol (the iterated dependent product, with + `depends_on` graph validation), QuestionRealization +- `encoding.py`: ScaleType (binary / ordinal / nominal), + ResponseEncoding (likelihood-agnostic, with invariant validators + for `n_levels == len(labels)`, label uniqueness, and + BINARY-must-have-2-levels), `encode_response_space` bridge +- `diagnostics.py`: DiagnosticLevel, DiagnosticRecord, DatasetReport + (immutable; `with_*` mutators), ConditionalObservationValidator + (drives off `QuestionFamily.depends_on`), RecordLike Protocol **bead/simulation/** - Simulation framework diff --git a/docs/index.md b/docs/index.md index 1e655f2..04099e9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -17,11 +17,12 @@ bead implements a 6-stage pipeline for linguistic experiment design: - **Stand-off annotation** with UUID-based references for provenance tracking - **9 task types**: forced-choice, ordinal scale, binary, categorical, multi-select, magnitude, free text, cloze, span labeling +- **Annotation protocols**: type-theoretic stack of anchors, contexts, realization strategies, and drift guards, composed into conditional protocols ([overview](user-guide/protocols.md)) - **GLMM support**: Generalized Linear Mixed Models with random effects - **Batch deployment**: server-side list distribution via JATOS batch sessions - **Language-agnostic**: works with any language supported by UniMorph - **Configuration-first**: single YAML file orchestrates entire pipeline -- **Type-safe**: full Python 3.13 type hints with Pydantic v2 validation +- **Type-safe**: full Python 3.14 type hints with didactic validation ## Quick Links diff --git a/docs/user-guide/concepts.md b/docs/user-guide/concepts.md index d9033ef..367ab77 100644 --- a/docs/user-guide/concepts.md +++ b/docs/user-guide/concepts.md @@ -143,6 +143,34 @@ Bead distinguishes between **task types** (UI presentation) and **judgment types The same judgment type may use different task types depending on experimental goals. Acceptability can use ordinal scales (rate sentence naturalness) or forced choice (which sentence is more natural). +## Annotation Protocols + +Above the task / judgment distinction sits a separate type-theoretic +layer for *what* a question measures and *how* it is phrased. The +[`bead.protocol`](protocols.md) package factors annotation design +into four roles: + +- A `SemanticAnchor` is the *type* of a question: a declarative + specification of the property being measured, the response space, + and the structural constraints any phrasing must preserve. +- A `ProtocolContext` is the dependent *index*: everything known + about the current annotation target, including responses already + recorded for earlier questions. +- A `RealizationStrategy` is the *computational content* of the + dependent function `Pi(ctx). Question(ctx)`. Three strategies are + shipped: a fixed template, a context-conditional template selector, + and an LM paraphraser. +- A `DriftGuard` is the *type-checker* over realized prompts; it + composes structural, embedding, and perplexity validators. + +`QuestionFamily` packages an anchor with a realization strategy and a +drift guard; `AnnotationProtocol` sequences families into the +iterated dependent product +`Sigma(a_1 : Q_1(ctx)). Sigma(a_2 : Q_2(ctx, a_1)). ...`, threading +each response into the context so later questions can condition on +earlier answers. See the [protocols user guide](protocols.md) for +the full walkthrough. + ## Configuration-First Design Bead orchestrates the entire pipeline from a single YAML configuration file. The config specifies paths, strategies, constraints, and parameters for all six stages. diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md index 6215243..7a81587 100644 --- a/docs/user-guide/index.md +++ b/docs/user-guide/index.md @@ -75,6 +75,9 @@ Before using either approach, familiarize yourself with these concepts: - [Pipeline Architecture](concepts.md): 6-stage experimental pipeline - [Configuration System](configuration.md): YAML-based project configuration - [Stand-off Annotation](concepts.md#stand-off-annotation): UUID-based data provenance +- [Annotation Protocols](protocols.md): anchors as types, contexts as + dependent indices, realization strategies as computational content, + and drift guards as type-checkers ## Quick Start diff --git a/docs/user-guide/protocols.md b/docs/user-guide/protocols.md new file mode 100644 index 0000000..fa4e17e --- /dev/null +++ b/docs/user-guide/protocols.md @@ -0,0 +1,635 @@ +# Annotation Protocols + +The `bead.protocol` package gives you a type-theoretic stack for +defining annotation protocols. Four roles work together: + +- A **semantic anchor** is the *type* of a question: a declarative + specification of what is being measured. +- A **protocol context** is the dependent *index*: everything known + about the current target. +- A **realization strategy** is the computational *content* of the + dependent function `Pi(ctx). Question(ctx)`: it produces the + prompt string a participant will see. +- A **drift guard** is the *type-checker*: it verifies that a realized + prompt still inhabits the type defined by its anchor. + +`QuestionFamily` packages these together; `AnnotationProtocol` +sequences families into the iterated dependent product +`Sigma(a_1 : Q_1(ctx)). Sigma(a_2 : Q_2(ctx, a_1)). ...`, threading +responses through the context so later questions can condition on +earlier answers. + +## Why a protocol layer? + +Without a separation between the question's type and its phrasing, +two questions that elicit different responses can look identical, and +two phrasings of the *same* question can look different. The protocol +layer makes the invariants explicit: + +- The anchor declares which property is measured, the response space, + required keywords, and required span references. +- The realization can vary by context (template variants, LM + paraphrase) but must preserve the anchor's invariants. +- The drift guard catches realizations that fail to preserve them. + +## Defining an anchor + +```python +from bead.protocol import ResponseSpace, SemanticAnchor +from bead.protocol.anchor import SemanticPoles + +response_space = ResponseSpace( + options=( + "definitely no", + "probably no", + "unsure", + "probably yes", + "definitely yes", + ), + is_ordered=True, + semantic_poles=SemanticPoles( + low="definitely no", high="definitely yes", + ), +) + +completion = SemanticAnchor( + name="completion", + target_property="telicity", + canonical_prompt="Does [[situation]] reach a definite endpoint?", + response_space=response_space, + required_span_labels=frozenset({"situation"}), + required_keywords=frozenset({"endpoint"}), + description="Whether the event reaches a culmination.", +) +``` + +Use `SemanticAnchor.from_response_options` for the common case of an +anchor whose response space is built inline: + +```python +completion = SemanticAnchor.from_response_options( + name="completion", + target_property="telicity", + canonical_prompt="Does [[situation]] reach an endpoint?", + options=("no", "yes"), + is_ordered=False, + required_span_labels=frozenset({"situation"}), +) +``` + +## Building a context + +`ProtocolContext` carries sentence-level, target-level, and +dependent-level information common to most annotation protocols. +Domain-specific data lives in the inherited `metadata` map (a JSON +dict from `BeadBaseModel`): + +```python +from bead.protocol import ContextItem, ProtocolContext + +ctx = ProtocolContext( + sentence="Mary built a sandcastle.", + target_lemma="build", + target_form="built", + target_upos="VERB", + target_position=2, + target_span_text="built a sandcastle", + target_span_positions=(2, 3, 4), + dependents=( + ContextItem( + head_lemma="Mary", head_upos="PROPN", + head_position=1, span_text="Mary", + ), + ContextItem( + head_lemma="sandcastle", head_upos="NOUN", + head_position=4, span_text="a sandcastle", + attributes={"definiteness": 0.0}, + ), + ), +) +``` + +### Domain-specific dependent attributes + +Each `ContextItem` carries an `attributes: dict[str, float]` map for +domain-specific scalar properties (semantic-role probabilities, +definiteness scores, frequency, ...). `ContextItem.attribute(name)` +returns `None` when the attribute is absent so callers do not have +to handle a separate `KeyError`: + +```python +mary, sandcastle = ctx.dependents +sandcastle.attribute("definiteness") # 0.0 +sandcastle.attribute("missing") is None # True +``` + +### The context-predicate registry + +`ContextualTemplateRealization` and other strategies look predicates +up by name from a module-level registry rather than passing functions +directly. Register at import time, look up at realization time: + +```python +from bead.protocol import ( + register_context_predicate, get_context_predicate, + list_context_predicates, ProtocolContext, +) + +def has_plural_dependent(ctx: ProtocolContext) -> bool: + return any(d.is_plural for d in ctx.dependents) + +register_context_predicate("has_plural_dependent", has_plural_dependent) + +assert get_context_predicate("has_plural_dependent") is has_plural_dependent +assert "has_plural_dependent" in list_context_predicates() +``` + +The protocol layer ships one predicate, `always`, which is also the +default condition for `TemplateVariant`. The registry is global +mutable state, populated at import time and read at realization +time; it is not designed for per-request mutation. + +## Threading dependent responses + +`ProtocolContext.with_response` returns a new context with one +additional response recorded; the original is unchanged. + +```python +ctx2 = ctx.with_response("change", "yes") +ctx3 = ctx2.with_response("completion", "probably yes") +ctx3.previous_responses +# {'change': 'yes', 'completion': 'probably yes'} +``` + +## Realization strategies + +`TemplateRealization` returns a fixed template (or the anchor's +canonical prompt when no template is configured): + +```python +from bead.protocol import TemplateRealization + +tr = TemplateRealization() # echoes anchor.canonical_prompt +``` + +`ContextualTemplateRealization` selects from ranked variants: + +```python +from bead.protocol import ContextualTemplateRealization, TemplateVariant + +contextual = ContextualTemplateRealization( + variants=( + TemplateVariant( + template="Does [[situation]] end at a specific point?", + condition=lambda ctx: ctx.target_upos == "VERB", + priority=10, + ), + TemplateVariant( + template="Does [[situation]] have a specific end?", + priority=0, + ), + ), +) +``` + +`LMRealization` paraphrases the canonical prompt via a language-model +client. Always pair it with a drift guard, and pass a +`bead.items.cache.ModelOutputCache` so realizations participate in +bead's single canonical caching surface: + +```python +from bead.items.cache import ModelOutputCache +from bead.protocol import LMClient, LMRealization + +class StubClient: + def complete( + self, prompt: str, *, temperature: float, max_tokens: int, + ) -> str: + return "Did the event reach an endpoint?" + +assert isinstance(StubClient(), LMClient) +cache = ModelOutputCache(backend="memory") +lm = LMRealization( + StubClient(), + model_name="stub-paraphraser", + cache=cache, + temperature=0.3, + max_tokens=200, +) +``` + +`LMClient` is a `typing.Protocol`: any object with a `complete(prompt, +*, temperature, max_tokens) -> str` method conforms. Cache entries +key off `(model_name, "lm_completion", prompt=full_prompt)`; passing +the same `ModelOutputCache` to multiple `LMRealization`s with +different `model_name` values keeps their entries isolated. Without a +cache (`cache=None`), every `realize()` call hits the backend. +`LMRealization.realize` raises `RuntimeError` on backend failures or +empty / whitespace-only responses, so a misbehaving LM cannot silently +pollute the cache. + +## Drift validation + +The drift guard composes structural, embedding, and perplexity +validators. + +```python +from bead.protocol import ( + DriftGuard, + EmbeddingDriftValidator, + PerplexityDriftValidator, + StructuralDriftValidator, +) + +guard = DriftGuard( + validators=[ + StructuralDriftValidator(min_length=15), + EmbeddingDriftValidator(adapter, max_distance=0.4), + PerplexityDriftValidator(adapter, max_perplexity=80.0), + ], +) +``` + +`StructuralDriftValidator` checks `[[label]]` references, required +keywords, length, and trailing `?`. `EmbeddingDriftValidator` runs on +the embedding adapter; if the anchor sets `embedding_center` and +`max_drift`, those are used as the cosine ceiling. +`PerplexityDriftValidator` flags realizations whose perplexity +exceeds a configured ceiling. + +The embedding and perplexity validators consume narrow +`typing.Protocol`s, so any object with the right shape can serve as +the backend: + +```python +from bead.protocol import EmbeddingAdapter, PerplexityAdapter + +# Conforms structurally: +class MyBackend: + def get_embedding(self, text: str) -> Sequence[float]: ... + def compute_perplexity(self, text: str) -> float: ... + +assert isinstance(MyBackend(), EmbeddingAdapter) +assert isinstance(MyBackend(), PerplexityAdapter) +``` + +Bead's `bead.items.adapters.ModelAdapter` family conforms out of the +box. + +## Composing a protocol + +```python +from bead.protocol import AnnotationProtocol, QuestionFamily + +change = QuestionFamily( + anchor=change_anchor, + realization=contextual, + drift_guard=guard, +) + +uniformity = QuestionFamily( + anchor=uniformity_anchor, + realization=TemplateRealization(), + drift_guard=guard, + condition=( + lambda ctx: ctx.previous_responses.get("change") == "yes" + ), + depends_on=("change",), +) + +protocol = AnnotationProtocol( + families=[change, uniformity], name="aspect-protocol", +) + +realizations = protocol.realize_all( + ctx, responses={"change": "yes"}, +) +``` + +`realize_all` threads each response into the context before evaluating +the next family; non-applicable families are skipped. When a response +is not pre-supplied, the first option of the family's response space +is used as a placeholder so downstream conditional families can still +be exercised in dry-run mode. + +## Constructor invariants + +The protocol layer enforces a small set of structural invariants at +construction time so configuration errors fail loudly instead of +manifesting as confusing behavior at realization time: + +- `ResponseEncoding` requires `n_levels == len(labels)`, rejects + duplicate labels, and rejects `BINARY` scales with anything other + than 2 levels. (`encode_response_space` derives all three from the + source `ResponseSpace` so it never produces an invalid encoding.) +- `AnnotationProtocol` rejects duplicate anchor names, families that + depend on themselves, and families whose `depends_on` references a + family that is not present *earlier* in the sequence (forward + references and unknown references are both refused). The same + validation runs on `AnnotationProtocol.append`. +- `LMRealization(client, max_cache_size=...)` requires + `max_cache_size > 0`. +- `PerplexityDriftValidator(..., max_perplexity=...)` requires + `max_perplexity > 0`. + +Together with the drift validators that fire at realization time, +these invariants make the construction of an +`AnnotationProtocol` a complete static check: if construction +succeeds, every realization is well-formed up to the LM's behavior, +and any LM misbehavior is caught by the drift guard. + +## Bridging to the modeling layer + +`encode_response_space` converts a `ResponseSpace` into a +likelihood-agnostic `ResponseEncoding`: + +```python +from bead.protocol import encode_response_space + +encoding = encode_response_space("change", change_anchor.response_space) +encoding.is_binary # True for two-option, unordered spaces +encoding.label_to_index("yes") +encoding.index_to_label(0) +``` + +`bead.active_learning.models` registers the canonical +`ScaleType` → model-class mapping. To pick the active-learning model +class for an encoding: + +```python +from bead.active_learning.models import ( + config_class_for_encoding, + model_class_for_encoding, +) + +ModelClass = model_class_for_encoding(encoding) # e.g. BinaryModel +ConfigClass = config_class_for_encoding(encoding) # e.g. BinaryModelConfig +model = ModelClass(ConfigClass(model_name="bert-base-uncased")) +``` + +The same registry (`MODEL_CLASSES` / `CONFIG_CLASSES`) drives the +`bead models train-model` and `bead training` CLI commands, so there +is exactly one mapping from task type to model class across the +codebase. + +## Bridging to item construction + +`bead.protocol.items` is the single canonical bridge from a realized +question to a fully-populated `bead.items.Item`: + +```python +from bead.protocol import ( + family_to_item_template, + realization_to_item, + realize_protocol_to_items, +) + +# Per-family templates (one per anchor): +template = family_to_item_template( + family_change, judgment_type="acceptability", +) + +# Per-context realization → Item: +realization = family_change.realize(ctx) +item = realization_to_item(realization, item_template=template) + +# Whole-protocol convenience: +pairs = realize_protocol_to_items( + protocol, ctx, judgment_type="acceptability", +) +for realization, item in pairs: + ... # downstream item processing +``` + +`scale_type_to_task_type` is the canonical translation used here and +in the active-learning registry. There is no other mapping: every +protocol family produces exactly one `ItemTemplate`. + +## Bridging to deployment + +`bead.deployment.protocol_trials.protocol_to_jspsych_trials` is the +single canonical bridge from a configured protocol and a sequence of +contexts to a flat list of jsPsych trial dicts ready for batch +deployment: + +```python +from bead.deployment.protocol_trials import protocol_to_jspsych_trials + +trials = protocol_to_jspsych_trials( + protocol, + contexts, + experiment_config=experiment_config, + judgment_type="acceptability", + rating_config=rating_config, # for ordinal scales + choice_config=choice_config, # for binary / categorical +) +``` + +Each context is realized through every applicable family; each +resulting realization is packaged as an `Item`, bound to its +family's `ItemTemplate`, and fed through +`bead.deployment.jspsych.trials.create_trial`. Trials are returned +in `(context_order, family_order)` with consecutive `trial_number` +fields. + +## Bridging back from JATOS + +After deployment, `bead.data_collection.jatos_results_to_annotation_records` +is the single canonical conversion from raw JATOS results to +`bead.evaluation.AnnotationRecord` instances: + +```python +from bead.data_collection import ( + JATOSDataCollector, + jatos_results_to_annotation_records, +) +from bead.evaluation import annotator_reliability + +results = JATOSDataCollector(...).download_results(Path("results.jsonl")) +records = jatos_results_to_annotation_records(results) +profiles = annotator_reliability(records) +``` + +The bridge looks up the annotator id in `urlQueryParameters` +(`"PROLIFIC_PID"` by default; configurable), then walks each result's +trial array picking the trials with `item_id` and `template_name` +fields populated by the jsPsych deployment layer. Trials missing +those fields (instructions, consent, demographics) are skipped. +Numeric responses are stringified so the resulting +`response_label` matches the encoding's labels for the corresponding +family. + +## Configuration-driven workflow + +`bead.config.protocol.ProtocolConfig` is the single canonical +declarative form of a protocol. It plugs into `BeadConfig` as the +`protocol` section, and a complete protocol is materialized via +`ProtocolConfig.build()`: + +```yaml +# bead.yaml +protocol: + name: aspect-protocol + drift: + min_length: 15 + require_question_mark: true + lm_model_name: gpt-4o-mini + lm_temperature: 0.3 + families: + - anchor: + name: change + target_property: dynamicity + canonical_prompt: "Is anything changing in [[situation]] over time?" + options: ["no", "yes"] + is_ordered: false + required_span_labels: [situation] + realization_kind: contextual + variants: + - template: "Is [[situation]] something that is changing?" + condition_name: always + priority: 0 + - anchor: + name: completion + target_property: telicity + canonical_prompt: "Does [[situation]] reach a definite endpoint?" + options: ["definitely no", "probably no", "unsure", + "probably yes", "definitely yes"] + is_ordered: true + semantic_pole_low: "definitely no" + semantic_pole_high: "definitely yes" + required_span_labels: [situation] + realization_kind: lm + condition_name: always + depends_on: [change] +``` + +```python +from bead.config import load_config + +config = load_config("bead.yaml") +protocol = config.protocol.build( + lm_client=my_lm_client, + cache=ModelOutputCache(backend="filesystem"), +) +``` + +Predicates (`condition_name`) are looked up by name from the registry +documented above. Every realization strategy (`template`, +`contextual`, `lm`) and drift validator +(`StructuralDriftValidator` always on; `EmbeddingDriftValidator` and +`PerplexityDriftValidator` opt-in via `drift.enable_embedding` and +`drift.enable_perplexity`) is reachable from configuration without +writing Python. + +## CLI + +The `bead protocol` subcommand drives the configuration-loaded +protocol from the shell: + +```bash +# Validate the protocol config and report each family's scale + deps +bead protocol validate + +# Realize prompts for every context in contexts.jsonl +bead protocol realize contexts.jsonl realizations.jsonl + +# Realize and emit fully-populated Items (skip the realization step) +bead protocol realize contexts.jsonl items.jsonl --emit-items + +# Emit per-family ItemTemplates +bead protocol items templates.jsonl --judgment-type acceptability +``` + +Every CLI command reads the same `BeadConfig` as the Python API, so +configuration is the single source of truth. + +## Diagnostics + +`DatasetReport` accumulates immutable diagnostic findings. Every +mutating method returns a new instance. + +```python +from bead.protocol import DatasetReport, DiagnosticLevel + +report = ( + DatasetReport(n_records_input=42, n_items=20) + .with_coverage("change", 0.95) + .add(DiagnosticLevel.WARNING, "missing_response", "item i12 has no response") +) +print(report.summary()) +``` + +`ConditionalObservationValidator` inspects records against the +protocol's `depends_on` graph: + +```python +from bead.protocol import ConditionalObservationValidator +from bead.evaluation import AnnotationRecord + +records = { + "change": [ + AnnotationRecord( + annotator_id="a1", item_id="i1", + question_name="change", response_label="yes", + ), + ], + "uniformity": [ + AnnotationRecord( + annotator_id="a1", item_id="i1", + question_name="uniformity", response_label="yes", + ), + ], +} + +validator = ConditionalObservationValidator( + conditioning_values={"uniformity": {"yes"}}, +) +findings = validator.validate(records, protocol) +``` + +`ConditionalObservationValidator` accepts any record type conforming +to the `RecordLike` Protocol (anything with `item_id`, +`response_label`, and `question_name` string attributes), so callers +are not bound to `bead.evaluation.AnnotationRecord` specifically. + +## Reliability + +`bead.evaluation.reliability` complements +`bead.evaluation.InterAnnotatorMetrics` with per-annotator entropy: + +```python +from bead.evaluation import ( + annotator_reliability, low_entropy_annotators, +) + +profiles = annotator_reliability(records_flat) +flagged = low_entropy_annotators(profiles, threshold=0.5) +``` + +Low entropy means the annotator is collapsing the response space +(always picking the same label, always picking the midpoint, ...). + +`annotator_reliability(records, encodings=...)` accepts an optional +`Mapping[str, ResponseEncoding]` keyed by anchor name. When supplied, +response labels not present in the encoding for a question are +silently skipped, which is useful after schema evolution invalidates +some legacy labels. + +`low_entropy_annotators` accepts two refinements: + +- `question_name="..."` restricts the threshold check to a single + question's entropy (otherwise the *minimum* per-question entropy + is checked). +- `require_min_responses=N` skips annotators with fewer than `N` + recorded responses, so an annotator who answered only one or two + items is not flagged purely on small-sample entropy. + +## Bridging to bead's item layer + +`QuestionRealization.prompt` is a string. It can be passed straight +into bead's existing item-construction pipeline (`ItemTemplate`, +`ItemConstructor`, ...) where the `[[label]]` markers in the +realization are resolved against the item's spans. The protocol layer +deliberately does *not* perform that resolution itself: the anchor +and the realization stay agnostic to bead's `{slot}` template syntax, +so the same realization can be reused across runtimes. diff --git a/mkdocs.yml b/mkdocs.yml index f44081e..d04486c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -60,6 +60,7 @@ nav: - Overview: user-guide/index.md - Core Concepts: user-guide/concepts.md - Configuration: user-guide/configuration.md + - Annotation Protocols: user-guide/protocols.md - CLI Guide: - Overview: user-guide/cli/index.md - Resources: user-guide/cli/resources.md @@ -89,6 +90,8 @@ nav: - bead.active_learning: api/active_learning.md - bead.simulation: api/simulation.md - bead.evaluation: api/evaluation.md + - bead.labels: api/labels.md + - bead.protocol: api/protocol.md - bead.dsl: api/dsl.md - bead.behavioral: api/behavioral.md - bead.participants: api/participants.md diff --git a/pyproject.toml b/pyproject.toml index a748997..474111e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "bead" -version = "0.3.0" +version = "0.4.0" description = "Lexicon and Template Collection Construction Pipeline for Acceptability and Inference Judgment Data" authors = [{name = "Aaron Steven White", email = "aaron.white@rochester.edu"}] readme = "README.md" @@ -70,6 +70,8 @@ dev = [ "ruff>=0.1.0", "pyright>=1.1.0", "pandas-stubs>=2.0.0", + "spacy>=3.7", + "stanza>=1.8", ] api = [ "openai>=1.0.0", diff --git a/tests/cli/test_models.py b/tests/cli/test_models.py index bb0dcfc..f732da6 100644 --- a/tests/cli/test_models.py +++ b/tests/cli/test_models.py @@ -132,10 +132,12 @@ def sample_participant_ids(tmp_path: Path) -> Path: class TestTrainModelCommand: """Tests for train-model command.""" - @patch("bead.cli.models._import_class") + @patch("bead.cli.models.config_class_for_task_type") + @patch("bead.cli.models.model_class_for_task_type") def test_train_forced_choice_fixed_mode( self, - mock_import: MagicMock, + mock_model_factory: MagicMock, + mock_config_factory: MagicMock, runner: CliRunner, tmp_path: Path, sample_forced_choice_items: Path, @@ -158,7 +160,8 @@ def test_train_forced_choice_fixed_mode( } mock_config_class.return_value = mock_config_instance - mock_import.side_effect = [mock_model_class, mock_config_class] + mock_model_factory.return_value = mock_model_class + mock_config_factory.return_value = mock_config_class output_dir = tmp_path / "model" @@ -190,10 +193,12 @@ def test_train_forced_choice_fixed_mode( # Verify model was saved assert (output_dir / "model.pt").exists() or mock_model_instance.save.called - @patch("bead.cli.models._import_class") + @patch("bead.cli.models.config_class_for_task_type") + @patch("bead.cli.models.model_class_for_task_type") def test_train_ordinal_scale_random_intercepts( self, - mock_import: MagicMock, + mock_model_factory: MagicMock, + mock_config_factory: MagicMock, runner: CliRunner, tmp_path: Path, sample_ordinal_scale_items: Path, @@ -215,7 +220,8 @@ def test_train_ordinal_scale_random_intercepts( } mock_config_class.return_value = mock_config_instance - mock_import.side_effect = [mock_model_class, mock_config_class] + mock_model_factory.return_value = mock_model_class + mock_config_factory.return_value = mock_config_class output_dir = tmp_path / "model" @@ -242,10 +248,12 @@ def test_train_ordinal_scale_random_intercepts( assert result.exit_code == 0, f"Command failed: {result.output}" assert "Training ordinal_scale model" in result.output - @patch("bead.cli.models._import_class") + @patch("bead.cli.models.config_class_for_task_type") + @patch("bead.cli.models.model_class_for_task_type") def test_train_with_lora( self, - mock_import: MagicMock, + mock_model_factory: MagicMock, + mock_config_factory: MagicMock, runner: CliRunner, tmp_path: Path, sample_forced_choice_items: Path, @@ -267,7 +275,8 @@ def test_train_with_lora( } mock_config_class.return_value = mock_config_instance - mock_import.side_effect = [mock_model_class, mock_config_class] + mock_model_factory.return_value = mock_model_class + mock_config_factory.return_value = mock_config_class output_dir = tmp_path / "model" @@ -294,10 +303,12 @@ def test_train_with_lora( assert result.exit_code == 0, f"Command failed: {result.output}" - @patch("bead.cli.models._import_class") + @patch("bead.cli.models.config_class_for_task_type") + @patch("bead.cli.models.model_class_for_task_type") def test_train_with_validation_data( self, - mock_import: MagicMock, + mock_model_factory: MagicMock, + mock_config_factory: MagicMock, runner: CliRunner, tmp_path: Path, sample_forced_choice_items: Path, @@ -317,7 +328,8 @@ def test_train_with_validation_data( } mock_config_class.return_value = mock_config_instance - mock_import.side_effect = [mock_model_class, mock_config_class] + mock_model_factory.return_value = mock_model_class + mock_config_factory.return_value = mock_config_class output_dir = tmp_path / "model" @@ -403,10 +415,12 @@ def test_train_invalid_task_type( class TestPredictCommand: """Tests for predict command.""" - @patch("bead.cli.models._import_class") + @patch("bead.cli.models.config_class_for_task_type") + @patch("bead.cli.models.model_class_for_task_type") def test_predict_basic( self, - mock_import: MagicMock, + mock_model_factory: MagicMock, + mock_config_factory: MagicMock, runner: CliRunner, tmp_path: Path, sample_forced_choice_items: Path, @@ -445,7 +459,8 @@ def test_predict_basic( # Mock config class (not needed for predict, but imported) mock_config_class = MagicMock() - mock_import.side_effect = [mock_model_class, mock_config_class] + mock_model_factory.return_value = mock_model_class + mock_config_factory.return_value = mock_config_class output_file = tmp_path / "predictions.jsonl" @@ -470,10 +485,12 @@ def test_predict_basic( # Verify predictions file was created assert output_file.exists() - @patch("bead.cli.models._import_class") + @patch("bead.cli.models.config_class_for_task_type") + @patch("bead.cli.models.model_class_for_task_type") def test_predict_missing_config( self, - mock_import: MagicMock, + mock_model_factory: MagicMock, + mock_config_factory: MagicMock, runner: CliRunner, tmp_path: Path, sample_forced_choice_items: Path, @@ -508,10 +525,12 @@ def test_predict_missing_config( class TestPredictProbaCommand: """Tests for predict-proba command.""" - @patch("bead.cli.models._import_class") + @patch("bead.cli.models.config_class_for_task_type") + @patch("bead.cli.models.model_class_for_task_type") def test_predict_proba_basic( self, - mock_import: MagicMock, + mock_model_factory: MagicMock, + mock_config_factory: MagicMock, runner: CliRunner, tmp_path: Path, sample_forced_choice_items: Path, @@ -543,7 +562,8 @@ def test_predict_proba_basic( mock_model_class.return_value = mock_model_instance mock_config_class = MagicMock() - mock_import.side_effect = [mock_model_class, mock_config_class] + mock_model_factory.return_value = mock_model_class + mock_config_factory.return_value = mock_config_class output_file = tmp_path / "probabilities.json" @@ -614,10 +634,12 @@ class TestAllTaskTypes: "cloze", ], ) - @patch("bead.cli.models._import_class") + @patch("bead.cli.models.config_class_for_task_type") + @patch("bead.cli.models.model_class_for_task_type") def test_train_all_task_types( self, - mock_import: MagicMock, + mock_model_factory: MagicMock, + mock_config_factory: MagicMock, task_type: str, runner: CliRunner, tmp_path: Path, @@ -638,7 +660,8 @@ def test_train_all_task_types( } mock_config_class.return_value = mock_config_instance - mock_import.side_effect = [mock_model_class, mock_config_class] + mock_model_factory.return_value = mock_model_class + mock_config_factory.return_value = mock_config_class output_dir = tmp_path / f"model_{task_type}" @@ -671,10 +694,12 @@ class TestAllMixedEffectsModes: "mode", ["fixed", "random_intercepts", "random_slopes"], ) - @patch("bead.cli.models._import_class") + @patch("bead.cli.models.config_class_for_task_type") + @patch("bead.cli.models.model_class_for_task_type") def test_train_all_modes( self, - mock_import: MagicMock, + mock_model_factory: MagicMock, + mock_config_factory: MagicMock, mode: str, runner: CliRunner, tmp_path: Path, @@ -696,7 +721,8 @@ def test_train_all_modes( } mock_config_class.return_value = mock_config_instance - mock_import.side_effect = [mock_model_class, mock_config_class] + mock_model_factory.return_value = mock_model_class + mock_config_factory.return_value = mock_config_class output_dir = tmp_path / f"model_{mode}" diff --git a/tests/cli/test_protocol.py b/tests/cli/test_protocol.py new file mode 100644 index 0000000..3d88d96 --- /dev/null +++ b/tests/cli/test_protocol.py @@ -0,0 +1,185 @@ +"""Integration tests for the ``bead protocol`` CLI.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest +import yaml +from click.testing import CliRunner + +from bead.cli.protocol import protocol + + +def _project(tmp_path: Path, *, with_families: bool = True) -> Path: + """Write a minimal bead.toml-equivalent YAML config and return its path.""" + cfg: dict[str, object] = { + "profile": "default", + "paths": { + "data_dir": str(tmp_path), + "output_dir": str(tmp_path / "out"), + "cache_dir": str(tmp_path / ".cache"), + }, + "protocol": { + "name": "test-protocol", + "drift": { + "min_length": 5, + "require_question_mark": True, + }, + "families": ( + [ + { + "anchor": { + "name": "completion", + "target_property": "telicity", + "canonical_prompt": ( + "Does [[situation]] reach an endpoint?" + ), + "options": ["no", "yes"], + "is_ordered": False, + "required_span_labels": ["situation"], + }, + "realization_kind": "template", + } + ] + if with_families + else [] + ), + }, + } + config_path = tmp_path / "bead.yaml" + config_path.write_text(yaml.safe_dump(cfg)) + (tmp_path / "out").mkdir(exist_ok=True) + return config_path + + +@pytest.fixture +def runner() -> CliRunner: + return CliRunner() + + +def test_validate_reports_families(runner: CliRunner, tmp_path: Path) -> None: + config_path = _project(tmp_path) + result = runner.invoke(protocol, ["validate", "--config-file", str(config_path)]) + assert result.exit_code == 0, result.output + assert "test-protocol" in result.output + assert "completion" in result.output + + +def test_validate_empty_protocol_still_passes( + runner: CliRunner, tmp_path: Path +) -> None: + config_path = _project(tmp_path, with_families=False) + result = runner.invoke(protocol, ["validate", "--config-file", str(config_path)]) + assert result.exit_code == 0 + assert "0 families" in result.output + + +def test_realize_writes_realizations(runner: CliRunner, tmp_path: Path) -> None: + config_path = _project(tmp_path) + contexts_file = tmp_path / "contexts.jsonl" + contexts = [ + { + "sentence": f"Mary built sandcastle {i}.", + "target_lemma": "build", + "target_form": "built", + "target_upos": "VERB", + "target_position": 2, + "target_span_text": f"built sandcastle {i}", + "target_span_positions": [2, 3, 4], + "target_id": f"item-{i}", + } + for i in range(3) + ] + contexts_file.write_text("\n".join(json.dumps(c) for c in contexts) + "\n") + output_file = tmp_path / "realizations.jsonl" + result = runner.invoke( + protocol, + [ + "realize", + str(contexts_file), + str(output_file), + "--config-file", + str(config_path), + ], + ) + assert result.exit_code == 0, result.output + lines = output_file.read_text().strip().splitlines() + assert len(lines) == 3 + parsed = [json.loads(line) for line in lines] + for record in parsed: + assert "prompt" in record + assert "[[situation]]" in record["prompt"] + + +def test_realize_emit_items(runner: CliRunner, tmp_path: Path) -> None: + config_path = _project(tmp_path) + contexts_file = tmp_path / "contexts.jsonl" + contexts_file.write_text( + json.dumps( + { + "sentence": "Mary built a sandcastle.", + "target_lemma": "build", + "target_form": "built", + "target_upos": "VERB", + "target_position": 2, + "target_span_text": "built a sandcastle", + "target_span_positions": [2, 3, 4], + "target_id": "item-0", + } + ) + + "\n" + ) + output_file = tmp_path / "items.jsonl" + result = runner.invoke( + protocol, + [ + "realize", + str(contexts_file), + str(output_file), + "--config-file", + str(config_path), + "--emit-items", + ], + ) + assert result.exit_code == 0, result.output + lines = output_file.read_text().strip().splitlines() + assert len(lines) == 1 + record = json.loads(lines[0]) + assert "item_template_id" in record + assert "spans" in record + + +def test_items_writes_templates(runner: CliRunner, tmp_path: Path) -> None: + config_path = _project(tmp_path) + output_file = tmp_path / "templates.jsonl" + result = runner.invoke( + protocol, + ["items", str(output_file), "--config-file", str(config_path)], + ) + assert result.exit_code == 0, result.output + lines = output_file.read_text().strip().splitlines() + assert len(lines) == 1 + record = json.loads(lines[0]) + assert record["name"] == "completion" + assert record["task_type"] == "binary" + + +def test_realize_empty_protocol_errors(runner: CliRunner, tmp_path: Path) -> None: + config_path = _project(tmp_path, with_families=False) + contexts_file = tmp_path / "contexts.jsonl" + contexts_file.write_text("") + output_file = tmp_path / "out.jsonl" + result = runner.invoke( + protocol, + [ + "realize", + str(contexts_file), + str(output_file), + "--config-file", + str(config_path), + ], + ) + assert result.exit_code == 1 + assert "empty" in result.output.lower() diff --git a/tests/cli/test_training.py b/tests/cli/test_training.py index 2219a3c..a1306c6 100644 --- a/tests/cli/test_training.py +++ b/tests/cli/test_training.py @@ -135,7 +135,10 @@ def test_evaluate_basic(self, cli_runner: CliRunner, tmp_path: Path) -> None: f.write("0\n") # Mock model loading and prediction - with patch("bead.cli.training._import_class") as mock_import: + with ( + patch("bead.cli.training.model_class_for_task_type") as mock_import, + patch("bead.cli.training.config_class_for_task_type") as _, + ): mock_model_class = MagicMock() mock_model = MagicMock() mock_model.predict.return_value = [0] * 10 # Perfect predictions @@ -191,7 +194,10 @@ def test_evaluate_with_output(self, cli_runner: CliRunner, tmp_path: Path) -> No output_file = tmp_path / "results.json" - with patch("bead.cli.training._import_class") as mock_import: + with ( + patch("bead.cli.training.model_class_for_task_type") as mock_import, + patch("bead.cli.training.config_class_for_task_type") as _, + ): mock_model_class = MagicMock() mock_model = MagicMock() mock_model.predict.return_value = [1, 1, 1, 0, 1] @@ -297,7 +303,10 @@ def test_cross_validate_basic(self, cli_runner: CliRunner, tmp_path: Path) -> No ) # Mock model training and prediction - with patch("bead.cli.training._import_class") as mock_import: + with ( + patch("bead.cli.training.model_class_for_task_type") as mock_import, + patch("bead.cli.training.config_class_for_task_type") as _, + ): mock_model_class = MagicMock() mock_model = MagicMock() @@ -364,7 +373,10 @@ def test_cross_validate_with_stratification( output_file = tmp_path / "cv_results.json" - with patch("bead.cli.training._import_class") as mock_import: + with ( + patch("bead.cli.training.model_class_for_task_type") as mock_import, + patch("bead.cli.training.config_class_for_task_type") as _, + ): mock_model_class = MagicMock() mock_model = MagicMock() @@ -443,7 +455,10 @@ def test_learning_curve_basic(self, cli_runner: CliRunner, tmp_path: Path) -> No ) ) - with patch("bead.cli.training._import_class") as mock_import: + with ( + patch("bead.cli.training.model_class_for_task_type") as mock_import, + patch("bead.cli.training.config_class_for_task_type") as _, + ): mock_model_class = MagicMock() mock_model = MagicMock() mock_model.predict.side_effect = lambda items, **kwargs: [0] * len(items) @@ -500,7 +515,10 @@ def test_learning_curve_with_output( output_file = tmp_path / "learning_curve.json" - with patch("bead.cli.training._import_class") as mock_import: + with ( + patch("bead.cli.training.model_class_for_task_type") as mock_import, + patch("bead.cli.training.config_class_for_task_type") as _, + ): mock_model_class = MagicMock() mock_model = MagicMock() mock_model.predict.side_effect = lambda items, **kwargs: [0] * len(items) diff --git a/tests/config/test_protocol_config.py b/tests/config/test_protocol_config.py new file mode 100644 index 0000000..1a5ea71 --- /dev/null +++ b/tests/config/test_protocol_config.py @@ -0,0 +1,319 @@ +"""Tests for :mod:`bead.config.protocol`.""" + +from __future__ import annotations + +from collections.abc import Sequence +from pathlib import Path + +import pytest + +from bead.config.protocol import ( + AnchorSpec, + DriftConfig, + FamilySpec, + ProtocolConfig, + TemplateVariantSpec, +) +from bead.config.serialization import to_yaml +from bead.protocol import ( + AnnotationProtocol, + ContextualTemplateRealization, + LMRealization, + SemanticAnchor, + StructuralDriftValidator, + TemplateRealization, + register_context_predicate, +) +from bead.protocol.context import ProtocolContext + + +def _is_verb(ctx: ProtocolContext) -> bool: + return ctx.target_upos == "VERB" + + +register_context_predicate("test_is_verb", _is_verb) + + +class TestAnchorSpec: + """Tests for :class:`AnchorSpec`.""" + + def test_build_minimal(self) -> None: + spec = AnchorSpec( + name="completion", + target_property="telicity", + canonical_prompt="Does [[situation]] end?", + options=("no", "yes"), + is_ordered=False, + required_span_labels=frozenset({"situation"}), + ) + anchor = spec.build() + assert isinstance(anchor, SemanticAnchor) + assert anchor.name == "completion" + assert anchor.response_space.options == ("no", "yes") + assert anchor.response_space.semantic_poles is None + + def test_build_with_poles(self) -> None: + spec = AnchorSpec( + name="freq", + target_property="frequency", + canonical_prompt="How often does [[s]] happen?", + options=("never", "sometimes", "always"), + is_ordered=True, + semantic_pole_low="never", + semantic_pole_high="always", + required_span_labels=frozenset({"s"}), + ) + anchor = spec.build() + poles = anchor.response_space.semantic_poles + assert poles is not None + assert poles.as_tuple() == ("never", "always") + + def test_partial_poles_rejected(self) -> None: + spec = AnchorSpec( + name="x", + target_property="x", + canonical_prompt="Q?", + options=("a", "b"), + semantic_pole_low="a", + ) + with pytest.raises(ValueError, match="only one pole"): + spec.build() + + +class TestDriftConfig: + """Tests for :class:`DriftConfig`.""" + + def test_default_builds_structural_only(self) -> None: + guard = DriftConfig().build() + assert len(guard) == 1 + assert isinstance(guard.validators[0], StructuralDriftValidator) + + def test_embedding_requires_adapter(self) -> None: + cfg = DriftConfig(enable_embedding=True) + with pytest.raises(ValueError, match="embedding_adapter"): + cfg.build() + + def test_perplexity_requires_adapter(self) -> None: + cfg = DriftConfig(enable_perplexity=True) + with pytest.raises(ValueError, match="perplexity_adapter"): + cfg.build() + + def test_with_adapters(self) -> None: + class Adapter: + def get_embedding(self, text: str) -> Sequence[float]: + del text + return (1.0, 0.0) + + def compute_perplexity(self, text: str) -> float: + del text + return 25.0 + + adapter = Adapter() + cfg = DriftConfig( + enable_embedding=True, + enable_perplexity=True, + max_perplexity=50.0, + ) + guard = cfg.build( + embedding_adapter=adapter, + perplexity_adapter=adapter, + ) + assert len(guard) == 3 + + +class TestFamilySpec: + """Tests for :class:`FamilySpec`.""" + + def _anchor(self, name: str = "x") -> AnchorSpec: + return AnchorSpec( + name=name, + target_property=name, + canonical_prompt="Question for [[s]]?", + options=("no", "yes"), + is_ordered=False, + required_span_labels=frozenset({"s"}), + ) + + def test_template_realization(self) -> None: + spec = FamilySpec( + anchor=self._anchor(), + realization_kind="template", + template="Did [[s]] happen?", + ) + family = spec.build( + drift_guard=DriftConfig().build(), + lm_client=None, + lm_model_name="", + cache=None, + lm_temperature=0.3, + lm_max_tokens=200, + ) + assert isinstance(family.realization, TemplateRealization) + + def test_contextual_requires_variants(self) -> None: + spec = FamilySpec( + anchor=self._anchor(), + realization_kind="contextual", + variants=(), + ) + with pytest.raises(ValueError, match="variants is empty"): + spec.build( + drift_guard=DriftConfig().build(), + lm_client=None, + lm_model_name="", + cache=None, + lm_temperature=0.3, + lm_max_tokens=200, + ) + + def test_contextual_realization(self) -> None: + spec = FamilySpec( + anchor=self._anchor(), + realization_kind="contextual", + variants=( + TemplateVariantSpec( + template="V: [[s]]?", + condition_name="test_is_verb", + priority=10, + ), + TemplateVariantSpec(template="Generic [[s]]?", priority=0), + ), + ) + family = spec.build( + drift_guard=DriftConfig().build(), + lm_client=None, + lm_model_name="", + cache=None, + lm_temperature=0.3, + lm_max_tokens=200, + ) + assert isinstance(family.realization, ContextualTemplateRealization) + + def test_lm_requires_client(self) -> None: + spec = FamilySpec( + anchor=self._anchor(), + realization_kind="lm", + ) + with pytest.raises(ValueError, match="no lm_client"): + spec.build( + drift_guard=DriftConfig().build(), + lm_client=None, + lm_model_name="x", + cache=None, + lm_temperature=0.3, + lm_max_tokens=200, + ) + + def test_lm_realization(self) -> None: + class Client: + def complete( + self, prompt: str, *, temperature: float, max_tokens: int + ) -> str: + del prompt, temperature, max_tokens + return "Did [[s]] happen?" + + spec = FamilySpec( + anchor=self._anchor(), + realization_kind="lm", + ) + family = spec.build( + drift_guard=DriftConfig().build(), + lm_client=Client(), + lm_model_name="stub", + cache=None, + lm_temperature=0.5, + lm_max_tokens=128, + ) + assert isinstance(family.realization, LMRealization) + + def test_condition_name_resolved(self) -> None: + spec = FamilySpec( + anchor=self._anchor("y"), + condition_name="test_is_verb", + depends_on=(), + ) + family = spec.build( + drift_guard=DriftConfig().build(), + lm_client=None, + lm_model_name="", + cache=None, + lm_temperature=0.3, + lm_max_tokens=200, + ) + assert family.is_always_applicable is False + assert family.is_applicable(ProtocolContext(target_upos="VERB")) + assert not family.is_applicable(ProtocolContext(target_upos="NOUN")) + + +class TestProtocolConfig: + """Tests for :class:`ProtocolConfig`.""" + + def test_empty_default(self) -> None: + cfg = ProtocolConfig() + proto = cfg.build() + assert isinstance(proto, AnnotationProtocol) + assert len(proto) == 0 + + def test_two_family_protocol(self) -> None: + cfg = ProtocolConfig( + name="aspect", + families=( + FamilySpec( + anchor=AnchorSpec( + name="change", + target_property="dynamicity", + canonical_prompt="Changing [[s]]?", + options=("no", "yes"), + is_ordered=False, + required_span_labels=frozenset({"s"}), + ), + ), + FamilySpec( + anchor=AnchorSpec( + name="completion", + target_property="telicity", + canonical_prompt="Endpoint [[s]]?", + options=("no", "yes"), + is_ordered=False, + required_span_labels=frozenset({"s"}), + ), + depends_on=("change",), + ), + ), + ) + proto = cfg.build() + assert proto.name == "aspect" + assert [f.name for f in proto.families] == ["change", "completion"] + assert proto.family_by_name("completion").depends_on == ("change",) + + def test_yaml_round_trip(self, tmp_path: Path) -> None: + """Round-trip through YAML preserves the protocol structure.""" + del tmp_path # only used for type assertion below + cfg = ProtocolConfig( + name="rt", + families=( + FamilySpec( + anchor=AnchorSpec( + name="q1", + target_property="q1", + canonical_prompt="[[s]]?", + options=("no", "yes"), + is_ordered=False, + required_span_labels=frozenset({"s"}), + ), + ), + ), + ) + yaml_text = to_yaml(cfg, include_defaults=False) + assert "rt" in yaml_text + assert "q1" in yaml_text + + +def test_protocol_config_in_bead_config() -> None: + """ProtocolConfig is wired into BeadConfig.protocol.""" + from bead.config import BeadConfig # noqa: PLC0415 + + config = BeadConfig() + assert isinstance(config.protocol, ProtocolConfig) + assert config.protocol.name == "" + assert len(config.protocol.families) == 0 diff --git a/tests/data_collection/test_records.py b/tests/data_collection/test_records.py new file mode 100644 index 0000000..9305323 --- /dev/null +++ b/tests/data_collection/test_records.py @@ -0,0 +1,183 @@ +"""Tests for :mod:`bead.data_collection.records`.""" + +from __future__ import annotations + +from bead.data_collection.records import jatos_results_to_annotation_records +from bead.evaluation import annotator_reliability + + +def _result( + *, + annotator: str = "P001", + trials: list[dict[str, object]], + annotator_key: str = "PROLIFIC_PID", +) -> dict[str, object]: + return { + "urlQueryParameters": {annotator_key: annotator}, + "worker_id": "worker-9", + "data": trials, + } + + +def test_basic_conversion() -> None: + results = [ + _result( + annotator="P001", + trials=[ + { + "item_id": "i1", + "template_name": "completion", + "response": "yes", + }, + { + "item_id": "i2", + "template_name": "completion", + "response": "no", + }, + ], + ), + _result( + annotator="P002", + trials=[ + { + "item_id": "i1", + "template_name": "completion", + "response": "yes", + }, + ], + ), + ] + records = jatos_results_to_annotation_records(results) + assert len(records) == 3 + assert records[0].annotator_id == "P001" + assert records[0].item_id == "i1" + assert records[0].question_name == "completion" + assert records[0].response_label == "yes" + assert records[2].annotator_id == "P002" + + +def test_falls_back_to_worker_id() -> None: + results = [ + { + "urlQueryParameters": {}, + "worker_id": "worker-42", + "data": [ + { + "item_id": "i1", + "template_name": "q", + "response": "yes", + }, + ], + }, + ] + records = jatos_results_to_annotation_records(results) + assert len(records) == 1 + assert records[0].annotator_id == "worker-42" + + +def test_skips_non_question_trials() -> None: + results = [ + _result( + trials=[ + {"trial_type": "instructions", "response": None}, + { + "item_id": "i1", + "template_name": "q", + "response": "yes", + }, + ], + ), + ] + records = jatos_results_to_annotation_records(results) + assert len(records) == 1 + + +def test_numeric_response_coerced_to_str() -> None: + results = [ + _result( + trials=[ + { + "item_id": "i1", + "template_name": "rating", + "response": 5, + }, + { + "item_id": "i2", + "template_name": "rating", + "response": 4.5, + }, + ], + ), + ] + records = jatos_results_to_annotation_records(results) + assert records[0].response_label == "5" + assert records[1].response_label == "4.5" + + +def test_response_object_with_response_key() -> None: + results = [ + _result( + trials=[ + { + "item_id": "i1", + "template_name": "q", + "response": {"response": "yes", "rt": 100}, + }, + ], + ), + ] + records = jatos_results_to_annotation_records(results) + assert records[0].response_label == "yes" + + +def test_missing_annotator_skipped() -> None: + results = [ + { + "urlQueryParameters": {}, + "data": [ + {"item_id": "i1", "template_name": "q", "response": "yes"}, + ], + }, + ] + records = jatos_results_to_annotation_records(results) + assert records == () + + +def test_custom_annotator_key() -> None: + results = [ + _result( + annotator="C42", + annotator_key="custom_id", + trials=[ + {"item_id": "i1", "template_name": "q", "response": "yes"}, + ], + ), + ] + records = jatos_results_to_annotation_records(results, annotator_id_key="custom_id") + assert records[0].annotator_id == "C42" + + +def test_pipes_into_annotator_reliability() -> None: + """The bridge output composes with annotator_reliability end-to-end.""" + results = [ + _result( + annotator="A", + trials=[ + {"item_id": "i1", "template_name": "q", "response": "yes"}, + {"item_id": "i2", "template_name": "q", "response": "no"}, + ], + ), + _result( + annotator="B", + trials=[ + {"item_id": "i1", "template_name": "q", "response": "yes"}, + {"item_id": "i2", "template_name": "q", "response": "yes"}, + ], + ), + ] + records = jatos_results_to_annotation_records(results) + profiles = annotator_reliability(records) + by_id = {p.annotator_id: p for p in profiles} + # A used both labels → entropy = 1.0; B used one label → entropy = 0.0 + assert by_id["A"].entropy("q") == 1.0 + assert by_id["B"].entropy("q") == 0.0 diff --git a/tests/deployment/jspsych/test_trials.py b/tests/deployment/jspsych/test_trials.py index f260d95..2e5cf8c 100644 --- a/tests/deployment/jspsych/test_trials.py +++ b/tests/deployment/jspsych/test_trials.py @@ -26,7 +26,6 @@ SpanColorMap, _assign_span_colors, _generate_stimulus_html, - _parse_prompt_references, _resolve_prompt_references, create_completion_trial, create_consent_trial, @@ -42,6 +41,7 @@ TaskSpec, ) from bead.items.spans import Span, SpanLabel, SpanSegment +from bead.labels import parse_label_refs class TestCreateTrial: @@ -499,17 +499,17 @@ def test_completion_trial_custom_message(self) -> None: class TestParsePromptReferences: - """Tests for _parse_prompt_references().""" + """Tests for parse_label_refs().""" def test_no_references(self) -> None: - """Plain text without references returns an empty list.""" - refs = _parse_prompt_references("How natural is this sentence?") + """Plain text without references returns an empty tuple.""" + refs = parse_label_refs("How natural is this sentence?") - assert refs == [] + assert refs == () def test_auto_fill_reference(self) -> None: """Single auto-fill reference is parsed with label and no display_text.""" - refs = _parse_prompt_references("How natural is [[agent]]?") + refs = parse_label_refs("How natural is [[agent]]?") assert len(refs) == 1 assert refs[0].label == "agent" @@ -517,7 +517,7 @@ def test_auto_fill_reference(self) -> None: def test_explicit_text_reference(self) -> None: """Explicit text reference is parsed with both label and display_text.""" - refs = _parse_prompt_references("Did [[event:the breaking]] happen?") + refs = parse_label_refs("Did [[event:the breaking]] happen?") assert len(refs) == 1 assert refs[0].label == "event" @@ -525,7 +525,7 @@ def test_explicit_text_reference(self) -> None: def test_multiple_references(self) -> None: """Multiple references are parsed in order of appearance.""" - refs = _parse_prompt_references("Did [[agent]] cause [[event:the breaking]]?") + refs = parse_label_refs("Did [[agent]] cause [[event:the breaking]]?") assert len(refs) == 2 assert refs[0].label == "agent" diff --git a/tests/evaluation/test_reliability.py b/tests/evaluation/test_reliability.py new file mode 100644 index 0000000..0f5e2a3 --- /dev/null +++ b/tests/evaluation/test_reliability.py @@ -0,0 +1,166 @@ +"""Tests for :mod:`bead.evaluation.reliability`.""" + +from __future__ import annotations + +import math + +import pytest + +from bead.evaluation.reliability import ( + AnnotationRecord, + AnnotatorReliability, + annotator_reliability, + low_entropy_annotators, +) +from bead.protocol.encoding import ResponseEncoding, ScaleType + + +def _record( + annotator: str, item: str, label: str, question: str = "q" +) -> AnnotationRecord: + return AnnotationRecord( + annotator_id=annotator, + item_id=item, + question_name=question, + response_label=label, + ) + + +class TestAnnotatorReliability: + """Tests for :class:`AnnotatorReliability`.""" + + def test_entropy_lookup(self) -> None: + rel = AnnotatorReliability( + annotator_id="a1", + n_responses=4, + response_distribution={"q": {"yes": 2, "no": 2}}, + entropy_per_question={"q": 1.0}, + ) + assert rel.entropy("q") == pytest.approx(1.0) + assert rel.entropy("missing") is None + + +class TestAnnotatorReliabilityFunction: + """Tests for :func:`annotator_reliability`.""" + + def test_uniform_distribution_max_entropy(self) -> None: + records = [ + _record("a1", "i1", "yes"), + _record("a1", "i2", "no"), + ] + profiles = annotator_reliability(records) + assert len(profiles) == 1 + assert profiles[0].entropy("q") == pytest.approx(1.0) + + def test_constant_response_zero_entropy(self) -> None: + records = [ + _record("a1", "i1", "yes"), + _record("a1", "i2", "yes"), + _record("a1", "i3", "yes"), + ] + profiles = annotator_reliability(records) + assert profiles[0].entropy("q") == pytest.approx(0.0) + assert profiles[0].n_responses == 3 + + def test_three_way_uniform(self) -> None: + records = [ + _record("a", "i1", "a"), + _record("a", "i2", "b"), + _record("a", "i3", "c"), + ] + profiles = annotator_reliability(records) + # Shannon entropy of uniform 3-way = log2(3) + assert profiles[0].entropy("q") == pytest.approx(math.log2(3)) + + def test_grouped_by_question(self) -> None: + records = [ + _record("a1", "i1", "yes", question="q1"), + _record("a1", "i2", "no", question="q1"), + _record("a1", "i1", "always", question="q2"), + _record("a1", "i2", "always", question="q2"), + ] + profiles = annotator_reliability(records) + assert profiles[0].entropy("q1") == pytest.approx(1.0) + assert profiles[0].entropy("q2") == pytest.approx(0.0) + + def test_filters_unknown_labels_with_encoding(self) -> None: + encoding = ResponseEncoding( + name="q", + n_levels=2, + scale_type=ScaleType.BINARY, + labels=("no", "yes"), + ) + records = [ + _record("a1", "i1", "yes"), + _record("a1", "i2", "maybe"), # unknown label + _record("a1", "i3", "no"), + ] + profiles = annotator_reliability(records, {"q": encoding}) + assert profiles[0].n_responses == 2 + assert profiles[0].response_distribution["q"] == {"yes": 1, "no": 1} + + def test_sorted_by_annotator_id(self) -> None: + records = [ + _record("c", "i1", "yes"), + _record("a", "i1", "yes"), + _record("b", "i1", "yes"), + ] + profiles = annotator_reliability(records) + assert [p.annotator_id for p in profiles] == ["a", "b", "c"] + + +class TestLowEntropyAnnotators: + """Tests for :func:`low_entropy_annotators`.""" + + def _profiles(self) -> tuple[AnnotatorReliability, ...]: + return ( + AnnotatorReliability( + annotator_id="lazy", + n_responses=10, + entropy_per_question={"q1": 0.0, "q2": 0.5}, + ), + AnnotatorReliability( + annotator_id="diligent", + n_responses=10, + entropy_per_question={"q1": 0.95, "q2": 1.5}, + ), + AnnotatorReliability( + annotator_id="newcomer", + n_responses=1, + entropy_per_question={"q1": 0.0}, + ), + ) + + def test_global_min_threshold(self) -> None: + flagged = low_entropy_annotators(self._profiles(), threshold=0.5) + # 'lazy' has min 0.0 (<= 0.5); 'newcomer' min 0.0 but min responses + # is 1 by default so it's still flagged + assert flagged == ("lazy", "newcomer") + + def test_per_question_threshold(self) -> None: + flagged = low_entropy_annotators( + self._profiles(), + threshold=0.5, + question_name="q1", + ) + assert flagged == ("lazy", "newcomer") + + def test_min_responses_filter(self) -> None: + flagged = low_entropy_annotators( + self._profiles(), + threshold=0.5, + require_min_responses=5, + ) + assert flagged == ("lazy",) # newcomer dropped + + def test_no_matches(self) -> None: + flagged = low_entropy_annotators(self._profiles(), threshold=-0.1) + assert flagged == () + + def test_unknown_question_returns_empty(self) -> None: + flagged = low_entropy_annotators( + self._profiles(), + threshold=0.5, + question_name="missing", + ) + assert flagged == () diff --git a/tests/protocol/__init__.py b/tests/protocol/__init__.py new file mode 100644 index 0000000..a05ceeb --- /dev/null +++ b/tests/protocol/__init__.py @@ -0,0 +1 @@ +"""Tests for the bead.protocol package.""" diff --git a/tests/protocol/test_anchor.py b/tests/protocol/test_anchor.py new file mode 100644 index 0000000..b8a9f10 --- /dev/null +++ b/tests/protocol/test_anchor.py @@ -0,0 +1,91 @@ +"""Tests for :mod:`bead.protocol.anchor`.""" + +from __future__ import annotations + +import pytest + +from bead.protocol.anchor import ResponseSpace, SemanticAnchor, SemanticPoles + + +class TestResponseSpace: + """Tests for :class:`ResponseSpace`.""" + + def test_construction_defaults(self) -> None: + rs = ResponseSpace(options=("no", "yes")) + assert rs.options == ("no", "yes") + assert rs.is_ordered is True + assert rs.semantic_poles is None + + def test_membership_and_length(self) -> None: + rs = ResponseSpace( + options=("definitely no", "unsure", "definitely yes"), + is_ordered=True, + semantic_poles=SemanticPoles(low="definitely no", high="definitely yes"), + ) + assert len(rs) == 3 + assert "unsure" in rs + assert "absent" not in rs + + def test_frozen_with_round_trip(self) -> None: + rs = ResponseSpace(options=("a", "b")) + rs2 = rs.with_(options=("a", "b", "c")) + assert rs.options == ("a", "b") + assert rs2.options == ("a", "b", "c") + assert rs.id == rs2.id # with_ preserves identity + + +class TestSemanticPoles: + """Tests for :class:`SemanticPoles`.""" + + def test_as_tuple(self) -> None: + poles = SemanticPoles(low="never", high="always") + assert poles.as_tuple() == ("never", "always") + + +class TestSemanticAnchor: + """Tests for :class:`SemanticAnchor`.""" + + def _build(self) -> SemanticAnchor: + rs = ResponseSpace( + options=("no", "yes"), + is_ordered=False, + ) + return SemanticAnchor( + name="completion", + target_property="telicity", + canonical_prompt="Does [[situation]] reach an endpoint?", + response_space=rs, + required_span_labels=frozenset({"situation"}), + required_keywords=frozenset({"endpoint"}), + description="Whether the event culminates.", + ) + + def test_construction(self) -> None: + anchor = self._build() + assert anchor.name == "completion" + assert anchor.target_property == "telicity" + assert anchor.required_span_labels == frozenset({"situation"}) + assert anchor.required_keywords == frozenset({"endpoint"}) + assert anchor.max_drift == pytest.approx(0.3) + + def test_from_response_options(self) -> None: + anchor = SemanticAnchor.from_response_options( + name="freq", + target_property="frequency", + canonical_prompt="How often does [[situation]] happen?", + options=("never", "sometimes", "always"), + is_ordered=True, + semantic_poles=SemanticPoles(low="never", high="always"), + required_span_labels=frozenset({"situation"}), + ) + assert anchor.response_space.is_ordered is True + poles = anchor.response_space.semantic_poles + assert poles is not None + assert poles.as_tuple() == ("never", "always") + assert anchor.required_span_labels == frozenset({"situation"}) + + def test_with_round_trip(self) -> None: + anchor = self._build() + anchor2 = anchor.with_(max_drift=0.5) + assert anchor.max_drift == pytest.approx(0.3) + assert anchor2.max_drift == pytest.approx(0.5) diff --git a/tests/protocol/test_context.py b/tests/protocol/test_context.py new file mode 100644 index 0000000..b96019f --- /dev/null +++ b/tests/protocol/test_context.py @@ -0,0 +1,105 @@ +"""Tests for :mod:`bead.protocol.context`.""" + +from __future__ import annotations + +import pytest + +from bead.protocol import context as context_module +from bead.protocol.context import ( + ContextItem, + ProtocolContext, + get_context_predicate, + list_context_predicates, + register_context_predicate, +) + + +class TestContextItem: + """Tests for :class:`ContextItem`.""" + + def test_attribute_lookup(self) -> None: + item = ContextItem( + head_lemma="ball", + attributes={"animacy": 0.0, "definiteness": 0.7}, + ) + assert item.attribute("animacy") == pytest.approx(0.0) + assert item.attribute("definiteness") == pytest.approx(0.7) + assert item.attribute("absent") is None + + def test_defaults(self) -> None: + item = ContextItem() + assert item.node_id == "" + assert item.span_positions == () + assert item.attributes == {} + + +class TestProtocolContext: + """Tests for :class:`ProtocolContext`.""" + + def test_with_response_threads(self) -> None: + ctx = ProtocolContext(sentence="Mary ran fast.") + ctx2 = ctx.with_response("dynamicity", "yes") + ctx3 = ctx2.with_response("completion", "no") + + assert ctx.previous_responses == {} + assert ctx2.previous_responses == {"dynamicity": "yes"} + assert ctx3.previous_responses == { + "dynamicity": "yes", + "completion": "no", + } + + def test_get_response(self) -> None: + ctx = ProtocolContext().with_response("q", "yes") + assert ctx.get_response("q") == "yes" + assert ctx.get_response("absent") is None + + def test_with_dependents(self) -> None: + dep = ContextItem(head_lemma="ball", head_upos="NOUN") + ctx = ProtocolContext( + sentence="Mary kicked the ball.", + dependents=(dep,), + ) + assert len(ctx.dependents) == 1 + assert ctx.dependents[0].head_lemma == "ball" + + +class TestPredicateRegistry: + """Tests for the context-predicate registry.""" + + def test_always_is_pre_registered(self) -> None: + always = get_context_predicate("always") + assert always(ProtocolContext()) is True + + def test_register_and_lookup(self) -> None: + def has_dependents(ctx: ProtocolContext) -> bool: + return len(ctx.dependents) > 0 + + register_context_predicate("has_dependents_test", has_dependents) + try: + fetched = get_context_predicate("has_dependents_test") + assert fetched is has_dependents + assert "has_dependents_test" in list_context_predicates() + assert fetched(ProtocolContext()) is False + ctx = ProtocolContext(dependents=(ContextItem(),)) + assert fetched(ctx) is True + finally: + # cleanup so other tests don't see this predicate + context_module._PREDICATES.pop("has_dependents_test", None) + + def test_unknown_predicate_raises(self) -> None: + with pytest.raises(KeyError, match="No context predicate"): + get_context_predicate("nonexistent_predicate_xyz") + + def test_re_registration_overwrites(self) -> None: + def first(_ctx: ProtocolContext) -> bool: + return True + + def second(_ctx: ProtocolContext) -> bool: + return False + + register_context_predicate("override_test", first) + register_context_predicate("override_test", second) + try: + assert get_context_predicate("override_test") is second + finally: + context_module._PREDICATES.pop("override_test", None) diff --git a/tests/protocol/test_deployment_bridge.py b/tests/protocol/test_deployment_bridge.py new file mode 100644 index 0000000..e58f864 --- /dev/null +++ b/tests/protocol/test_deployment_bridge.py @@ -0,0 +1,101 @@ +"""Tests for :mod:`bead.protocol.deployment`.""" + +from __future__ import annotations + +from bead.deployment.distribution import ListDistributionStrategy +from bead.deployment.jspsych.config import ( + ChoiceConfig, + ExperimentConfig, + InstructionsConfig, +) +from bead.deployment.protocol_trials import protocol_to_jspsych_trials +from bead.protocol import ( + AnnotationProtocol, + ProtocolContext, + QuestionFamily, + ResponseSpace, + SemanticAnchor, +) + + +def _binary_anchor() -> SemanticAnchor: + return SemanticAnchor( + name="completion", + target_property="telicity", + canonical_prompt="Does [[situation]] reach an endpoint?", + response_space=ResponseSpace(options=("no", "yes"), is_ordered=False), + required_span_labels=frozenset({"situation"}), + ) + + +def _experiment_config() -> ExperimentConfig: + return ExperimentConfig( + experiment_type="binary_choice", + title="test", + description="test", + instructions=InstructionsConfig.from_text("Click yes or no."), + distribution_strategy=ListDistributionStrategy( + strategy_type="random", + ), + ) + + +def test_protocol_to_jspsych_trials_emits_one_per_realization() -> None: + proto = AnnotationProtocol(families=[QuestionFamily(anchor=_binary_anchor())]) + contexts = [ + ProtocolContext( + sentence=f"Mary built sandcastle {i}.", + target_position=2, + target_span_text=f"built sandcastle {i}", + target_span_positions=(2, 3, 4), + ) + for i in range(3) + ] + trials = protocol_to_jspsych_trials( + proto, + contexts, + experiment_config=_experiment_config(), + judgment_type="acceptability", + choice_config=ChoiceConfig(), + ) + assert len(trials) == 3 + for trial in trials: + assert "type" in trial or "stimulus" in trial + + +def test_protocol_to_jspsych_trials_skips_non_applicable_families() -> None: + second = SemanticAnchor( + name="follow_up", + target_property="follow_up", + canonical_prompt="Did [[situation]] have a follow-up?", + response_space=ResponseSpace(options=("no", "yes"), is_ordered=False), + required_span_labels=frozenset({"situation"}), + ) + proto = AnnotationProtocol( + families=[ + QuestionFamily(anchor=_binary_anchor()), + QuestionFamily( + anchor=second, + condition=( + lambda ctx: ctx.previous_responses.get("completion") == "yes" + ), + depends_on=("completion",), + ), + ] + ) + ctx = ProtocolContext( + sentence="Mary built a sandcastle.", + target_position=2, + target_span_text="built a sandcastle", + target_span_positions=(2, 3, 4), + ) + # The placeholder threading injects "no" for completion (first option), + # so follow_up's condition fails and only "completion" fires. + trials = protocol_to_jspsych_trials( + proto, + [ctx], + experiment_config=_experiment_config(), + judgment_type="acceptability", + choice_config=ChoiceConfig(), + ) + assert len(trials) == 1 diff --git a/tests/protocol/test_diagnostics.py b/tests/protocol/test_diagnostics.py new file mode 100644 index 0000000..c593d59 --- /dev/null +++ b/tests/protocol/test_diagnostics.py @@ -0,0 +1,170 @@ +"""Tests for :mod:`bead.protocol.diagnostics`.""" + +from __future__ import annotations + +from bead.protocol.anchor import ResponseSpace, SemanticAnchor +from bead.protocol.diagnostics import ( + ConditionalObservationValidator, + DatasetReport, + DiagnosticLevel, + DiagnosticRecord, + RecordLike, +) +from bead.protocol.family import AnnotationProtocol, QuestionFamily + + +def _anchor(name: str) -> SemanticAnchor: + return SemanticAnchor( + name=name, + target_property=name, + canonical_prompt=f"Question for [[situation]] ({name})?", + response_space=ResponseSpace(options=("no", "yes"), is_ordered=False), + required_span_labels=frozenset({"situation"}), + ) + + +class _Record: + """Concrete RecordLike for tests.""" + + def __init__(self, item_id: str, response_label: str, question_name: str) -> None: + self.item_id = item_id + self.response_label = response_label + self.question_name = question_name + + +class TestDiagnosticLevel: + """Tests for :class:`DiagnosticLevel`.""" + + def test_str_values(self) -> None: + assert DiagnosticLevel.INFO.value == "info" + assert DiagnosticLevel.WARNING.value == "warning" + assert DiagnosticLevel.ERROR.value == "error" + + +class TestDatasetReport: + """Tests for :class:`DatasetReport`.""" + + def test_immutable_add(self) -> None: + r0 = DatasetReport(n_records_input=10) + r1 = r0.add(DiagnosticLevel.WARNING, "cat", "msg") + assert len(r0.findings) == 0 + assert len(r1.findings) == 1 + assert r1.findings[0].level == DiagnosticLevel.WARNING + assert r1.has_warnings is True + assert r1.has_errors is False + + def test_extend(self) -> None: + rec1 = DiagnosticRecord(level=DiagnosticLevel.ERROR, category="c", message="m1") + rec2 = DiagnosticRecord(level=DiagnosticLevel.INFO, category="c", message="m2") + report = DatasetReport().extend([rec1, rec2]) + assert len(report.findings) == 2 + + def test_with_coverage(self) -> None: + report = DatasetReport().with_coverage("q1", 0.95) + report = report.with_coverage("q2", 0.5) + assert report.coverage == {"q1": 0.95, "q2": 0.5} + + def test_with_missing_embedding_dedups(self) -> None: + report = DatasetReport() + report = report.with_missing_embedding("i1") + report = report.with_missing_embedding("i2") + report = report.with_missing_embedding("i1") # duplicate + assert report.items_missing_embeddings == ("i1", "i2") + + def test_filters(self) -> None: + report = ( + DatasetReport() + .add(DiagnosticLevel.WARNING, "missing", "m1") + .add(DiagnosticLevel.ERROR, "schema", "m2") + .add(DiagnosticLevel.WARNING, "schema", "m3") + ) + assert len(report.warnings) == 2 + assert len(report.errors) == 1 + assert len(report.by_category("schema")) == 2 + + def test_summary(self) -> None: + report = ( + DatasetReport( + n_records_input=10, + n_items=5, + n_records_encoded=8, + n_records_dropped=2, + ) + .with_coverage("completion", 0.8) + .add(DiagnosticLevel.WARNING, "missing", "msg") + ) + text = report.summary() + assert "5 items" in text + assert "completion: 80.0%" in text + assert "warnings: 1" in text + + +class TestRecordLike: + """Tests for :class:`RecordLike` Protocol.""" + + def test_record_conforms(self) -> None: + rec = _Record(item_id="i1", response_label="yes", question_name="q1") + assert isinstance(rec, RecordLike) + + +class TestConditionalObservationValidator: + """Tests for :class:`ConditionalObservationValidator`.""" + + def _protocol(self) -> AnnotationProtocol: + change = QuestionFamily(anchor=_anchor("change")) + uniformity = QuestionFamily( + anchor=_anchor("uniformity"), + depends_on=("change",), + ) + return AnnotationProtocol(families=[change, uniformity]) + + def test_passes_when_dependency_present(self) -> None: + proto = self._protocol() + records = { + "change": [_Record("i1", "yes", "change")], + "uniformity": [_Record("i1", "yes", "uniformity")], + } + validator = ConditionalObservationValidator() + findings = validator.validate(records, proto) + assert findings == () + + def test_warns_on_missing_dependency(self) -> None: + proto = self._protocol() + records = { + "uniformity": [_Record("i1", "yes", "uniformity")], + } + validator = ConditionalObservationValidator() + findings = validator.validate(records, proto) + assert len(findings) == 1 + assert findings[0].category == "conditional_missing_dependency" + assert findings[0].item_id == "i1" + + def test_warns_on_inapplicable_value(self) -> None: + proto = self._protocol() + records = { + "change": [_Record("i1", "no", "change")], + "uniformity": [_Record("i1", "yes", "uniformity")], + } + validator = ConditionalObservationValidator( + conditioning_values={"uniformity": {"yes"}}, + ) + findings = validator.validate(records, proto) + assert len(findings) == 1 + assert findings[0].category == "conditional_inapplicable" + + def test_skips_unconditional_families(self) -> None: + proto = AnnotationProtocol(families=[QuestionFamily(anchor=_anchor("solo"))]) + records = { + "solo": [_Record("i1", "yes", "solo")], + } + validator = ConditionalObservationValidator() + assert validator.validate(records, proto) == () + + +def test_dataset_report_round_trip_through_with() -> None: + """``with_`` preserves identity (UUID) but allows attribute updates.""" + r0 = DatasetReport(n_records_input=5) + r1 = r0.with_(n_records_input=10) + assert r0.n_records_input == 5 + assert r1.n_records_input == 10 + assert r0.id == r1.id diff --git a/tests/protocol/test_drift.py b/tests/protocol/test_drift.py new file mode 100644 index 0000000..c62bddf --- /dev/null +++ b/tests/protocol/test_drift.py @@ -0,0 +1,257 @@ +"""Tests for :mod:`bead.protocol.drift`.""" + +from __future__ import annotations + +from collections.abc import Sequence + +import pytest + +from bead.protocol.anchor import ResponseSpace, SemanticAnchor +from bead.protocol.context import ProtocolContext +from bead.protocol.drift import ( + DriftGuard, + DriftScore, + DriftValidator, + EmbeddingAdapter, + EmbeddingDriftValidator, + PerplexityAdapter, + PerplexityDriftValidator, + StructuralDriftValidator, +) + + +def _anchor( + *, + required_span_labels: frozenset[str] = frozenset({"situation"}), + required_keywords: frozenset[str] = frozenset(), + embedding_center: tuple[float, ...] | None = None, + max_drift: float = 0.3, +) -> SemanticAnchor: + return SemanticAnchor( + name="completion", + target_property="telicity", + canonical_prompt="Does [[situation]] reach an endpoint?", + response_space=ResponseSpace(options=("no", "yes"), is_ordered=False), + required_span_labels=required_span_labels, + required_keywords=required_keywords, + embedding_center=embedding_center, + max_drift=max_drift, + ) + + +class TestStructuralDriftValidator: + """Tests for :class:`StructuralDriftValidator`.""" + + def test_passes_well_formed(self) -> None: + validator = StructuralDriftValidator() + score = validator.validate( + "Does [[situation]] end with an endpoint?", + _anchor(required_keywords=frozenset({"endpoint"})), + ProtocolContext(), + ) + assert score.passed is True + assert score.findings == () + + def test_missing_span_label(self) -> None: + validator = StructuralDriftValidator() + score = validator.validate( + "Does it reach an endpoint?", + _anchor(), + ProtocolContext(), + ) + assert score.passed is False + assert any("[[situation]]" in f for f in score.findings) + + def test_missing_keyword_case_insensitive(self) -> None: + validator = StructuralDriftValidator(keyword_case_sensitive=False) + score = validator.validate( + "Does [[situation]] reach a stopping point?", + _anchor(required_keywords=frozenset({"Endpoint"})), + ProtocolContext(), + ) + assert score.passed is False + assert any("endpoint" in f.lower() for f in score.findings) + + def test_missing_question_mark(self) -> None: + validator = StructuralDriftValidator() + score = validator.validate( + "Does [[situation]] reach an endpoint", + _anchor(), + ProtocolContext(), + ) + assert score.passed is False + assert any("'?'" in f for f in score.findings) + + def test_too_short(self) -> None: + validator = StructuralDriftValidator(min_length=30) + score = validator.validate( + "Short [[situation]]?", + _anchor(), + ProtocolContext(), + ) + assert score.passed is False + assert any("too short" in f for f in score.findings) + + def test_label_with_transform_recognized(self) -> None: + validator = StructuralDriftValidator() + score = validator.validate( + "Did [[situation|gerund]] reach completion?", + _anchor(), + ProtocolContext(), + ) + assert score.passed is True + + +class _StubAdapter: + """Stub adapter exposing get_embedding and compute_perplexity. + + Conforms to :class:`EmbeddingAdapter` and :class:`PerplexityAdapter`. + """ + + def __init__( + self, + *, + embed_map: dict[str, tuple[float, ...]] | None = None, + default_embedding: tuple[float, ...] = (1.0, 0.0, 0.0), + perplexity: float = 30.0, + ) -> None: + self._embed_map = embed_map or {} + self._default = default_embedding + self._perplexity = perplexity + + def get_embedding(self, text: str) -> Sequence[float]: + return self._embed_map.get(text, self._default) + + def compute_perplexity(self, text: str) -> float: + del text # unused: stub returns a fixed value + return self._perplexity + + +class TestEmbeddingDriftValidator: + """Tests for :class:`EmbeddingDriftValidator`.""" + + def test_stub_adapter_conforms_to_protocol(self) -> None: + adapter = _StubAdapter() + assert isinstance(adapter, EmbeddingAdapter) + assert isinstance(adapter, PerplexityAdapter) + + def test_passes_under_max_drift(self) -> None: + adapter = _StubAdapter( + embed_map={ + "Does [[situation]] reach an endpoint?": (1.0, 0.0, 0.0), + "Did [[situation]] finish?": (0.99, 0.05, 0.0), + } + ) + validator = EmbeddingDriftValidator(adapter) + anchor = _anchor() + score = validator.validate( + "Did [[situation]] finish?", anchor, ProtocolContext() + ) + assert score.passed is True + assert score.embedding_distance is not None + assert score.embedding_distance < 0.3 + + def test_fails_over_max_drift(self) -> None: + adapter = _StubAdapter( + embed_map={ + "Does [[situation]] reach an endpoint?": (1.0, 0.0, 0.0), + "Banana cake": (0.0, 1.0, 0.0), + } + ) + validator = EmbeddingDriftValidator(adapter, max_distance=0.1) + anchor = _anchor() + score = validator.validate("Banana cake", anchor, ProtocolContext()) + assert score.passed is False + assert any("Embedding distance" in f for f in score.findings) + + def test_uses_anchor_embedding_center_when_present(self) -> None: + # When the anchor carries a center, the adapter is not asked + # to embed the canonical prompt. + seen: list[str] = [] + + class TrackingAdapter(_StubAdapter): + def get_embedding(self, text: str) -> Sequence[float]: + seen.append(text) + return super().get_embedding(text) + + adapter = TrackingAdapter( + embed_map={"realization": (1.0, 0.0, 0.0)}, + ) + anchor = _anchor(embedding_center=(1.0, 0.0, 0.0)) + validator = EmbeddingDriftValidator(adapter) + score = validator.validate("realization", anchor, ProtocolContext()) + assert score.passed is True + assert seen == ["realization"] + + +class TestPerplexityDriftValidator: + """Tests for :class:`PerplexityDriftValidator`.""" + + def test_passes_within_ceiling(self) -> None: + adapter = _StubAdapter(perplexity=30.0) + validator = PerplexityDriftValidator(adapter, max_perplexity=50.0) + score = validator.validate( + "Did [[situation]] finish?", _anchor(), ProtocolContext() + ) + assert score.passed is True + assert score.perplexity == pytest.approx(30.0) + + def test_fails_when_exceeds_ceiling(self) -> None: + adapter = _StubAdapter(perplexity=200.0) + validator = PerplexityDriftValidator(adapter, max_perplexity=50.0) + score = validator.validate("Garbled output", _anchor(), ProtocolContext()) + assert score.passed is False + assert any("Perplexity" in f for f in score.findings) + + def test_invalid_max_perplexity(self) -> None: + adapter = _StubAdapter() + with pytest.raises(ValueError, match="positive"): + PerplexityDriftValidator(adapter, max_perplexity=0.0) + + +class TestDriftGuard: + """Tests for :class:`DriftGuard`.""" + + def test_empty_guard_always_passes(self) -> None: + guard = DriftGuard() + score = guard.check( + "any string?", + _anchor(required_span_labels=frozenset()), + ProtocolContext(), + ) + assert score.passed is True + + def test_aggregates_findings(self) -> None: + guard = DriftGuard() + guard.add(StructuralDriftValidator()) + adapter = _StubAdapter( + embed_map={ + "Does [[situation]] reach an endpoint?": (1.0, 0.0, 0.0), + "Bad [[situation]]?": (0.0, 1.0, 0.0), + } + ) + guard.add(EmbeddingDriftValidator(adapter, max_distance=0.1)) + score = guard.check( + "Bad [[situation]]?", + _anchor(), + ProtocolContext(), + ) + # Embedding fails, structural passes + assert score.passed is False + assert score.structural_ok is True + assert score.embedding_distance is not None + + def test_drift_validator_protocol(self) -> None: + validator: DriftValidator = StructuralDriftValidator() + score = validator.validate( + "Does [[situation]] reach an endpoint?", + _anchor(), + ProtocolContext(), + ) + assert isinstance(score, DriftScore) + + def test_len(self) -> None: + guard = DriftGuard() + guard.add(StructuralDriftValidator()) + guard.add(StructuralDriftValidator()) + assert len(guard) == 2 diff --git a/tests/protocol/test_encoding.py b/tests/protocol/test_encoding.py new file mode 100644 index 0000000..17b7f56 --- /dev/null +++ b/tests/protocol/test_encoding.py @@ -0,0 +1,127 @@ +"""Tests for :mod:`bead.protocol.encoding`.""" + +from __future__ import annotations + +import pytest + +from bead.protocol.anchor import ResponseSpace, SemanticPoles +from bead.protocol.encoding import ( + ResponseEncoding, + ScaleType, + encode_response_space, +) + + +class TestScaleType: + """Tests for :class:`ScaleType`.""" + + def test_str_values(self) -> None: + assert ScaleType.BINARY.value == "binary" + assert ScaleType.ORDINAL.value == "ordinal" + assert ScaleType.NOMINAL.value == "nominal" + + +class TestResponseEncoding: + """Tests for :class:`ResponseEncoding`.""" + + def _build(self) -> ResponseEncoding: + return ResponseEncoding( + name="completion", + n_levels=5, + scale_type=ScaleType.ORDINAL, + labels=( + "definitely no", + "probably no", + "unsure", + "probably yes", + "definitely yes", + ), + semantic_poles=SemanticPoles(low="definitely no", high="definitely yes"), + ) + + def test_label_index_round_trip(self) -> None: + enc = self._build() + for i, label in enumerate(enc.labels): + assert enc.label_to_index(label) == i + assert enc.index_to_label(i) == label + + def test_label_to_index_unknown_raises(self) -> None: + enc = self._build() + with pytest.raises(ValueError, match="not found"): + enc.label_to_index("absent") + + def test_index_out_of_range(self) -> None: + enc = self._build() + with pytest.raises(IndexError): + enc.index_to_label(-1) + with pytest.raises(IndexError): + enc.index_to_label(5) + + def test_scale_predicates(self) -> None: + enc = self._build() + assert enc.is_ordinal is True + assert enc.is_binary is False + assert enc.is_nominal is False + + def test_n_levels_must_match_labels(self) -> None: + with pytest.raises(Exception, match="n_levels"): + ResponseEncoding( + name="bad", + n_levels=3, + scale_type=ScaleType.NOMINAL, + labels=("a", "b"), + ) + + def test_duplicate_labels_rejected(self) -> None: + with pytest.raises(Exception, match="Duplicate"): + ResponseEncoding( + name="dup", + n_levels=3, + scale_type=ScaleType.NOMINAL, + labels=("a", "b", "a"), + ) + + def test_binary_must_have_two_levels(self) -> None: + with pytest.raises(Exception, match="BINARY"): + ResponseEncoding( + name="b", + n_levels=3, + scale_type=ScaleType.BINARY, + labels=("a", "b", "c"), + ) + + +class TestEncodeResponseSpace: + """Tests for :func:`encode_response_space`.""" + + def test_binary_classification(self) -> None: + rs = ResponseSpace(options=("no", "yes"), is_ordered=False) + enc = encode_response_space("dynamicity", rs) + assert enc.scale_type == ScaleType.BINARY + assert enc.n_levels == 2 + assert enc.is_binary is True + + def test_ordinal_classification(self) -> None: + rs = ResponseSpace( + options=("low", "med", "high"), + is_ordered=True, + semantic_poles=SemanticPoles(low="low", high="high"), + ) + enc = encode_response_space("intensity", rs) + assert enc.scale_type == ScaleType.ORDINAL + assert enc.n_levels == 3 + assert enc.semantic_poles is not None + assert enc.semantic_poles.as_tuple() == ("low", "high") + + def test_nominal_classification(self) -> None: + rs = ResponseSpace(options=("a", "b", "c"), is_ordered=False) + enc = encode_response_space("category", rs) + assert enc.scale_type == ScaleType.NOMINAL + assert enc.is_nominal is True + + def test_two_options_ordered_is_ordinal_not_binary(self) -> None: + # A two-option *ordered* space is ordinal, only unordered + # two-option spaces are classified as binary. + rs = ResponseSpace(options=("low", "high"), is_ordered=True) + enc = encode_response_space("polarity", rs) + assert enc.scale_type == ScaleType.ORDINAL diff --git a/tests/protocol/test_end_to_end.py b/tests/protocol/test_end_to_end.py new file mode 100644 index 0000000..ca93356 --- /dev/null +++ b/tests/protocol/test_end_to_end.py @@ -0,0 +1,289 @@ +"""End-to-end integration test of the protocol layer. + +Builds a three-question protocol with a conditional family, realizes +through every strategy with a stub LM, validates via a composite +:class:`DriftGuard`, runs reliability metrics over simulated +responses, and verifies the resulting :class:`DatasetReport` is +well-formed. +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from bead.evaluation.reliability import ( + AnnotationRecord, + annotator_reliability, + low_entropy_annotators, +) +from bead.protocol import ( + AnnotationProtocol, + ConditionalObservationValidator, + ContextItem, + ContextualTemplateRealization, + DatasetReport, + DiagnosticLevel, + DriftGuard, + EmbeddingDriftValidator, + LMRealization, + PerplexityDriftValidator, + ProtocolContext, + QuestionFamily, + ResponseSpace, + SemanticAnchor, + StructuralDriftValidator, + TemplateRealization, + TemplateVariant, + encode_response_space, +) +from bead.protocol.anchor import SemanticPoles + + +class _StubLMClient: + def __init__(self) -> None: + self.calls = 0 + + def complete( + self, + prompt: str, + *, + temperature: float, + max_tokens: int, + ) -> str: + del prompt, temperature, max_tokens + self.calls += 1 + return "Does anything change in [[situation]] that has an endpoint?" + + +class _StubAdapter: + def get_embedding(self, text: str) -> Sequence[float]: + # Two-cluster deterministic embedding: texts containing + # "endpoint" map to one direction, everything else to the + # orthogonal direction. The change anchor uses "changing", so + # its canonical and realizations share the (0, 1, 0) cluster; + # the completion and uniformity anchors share the (1, 0, 0) + # cluster via "endpoint" / "moments". + if "endpoint" in text or "moments" in text: + return (1.0, 0.0, 0.0) + return (0.0, 1.0, 0.0) + + def compute_perplexity(self, text: str) -> float: + del text + return 25.0 + + +def _build_anchors() -> tuple[SemanticAnchor, SemanticAnchor, SemanticAnchor]: + binary = ResponseSpace(options=("no", "yes"), is_ordered=False) + likert = ResponseSpace( + options=( + "definitely no", + "probably no", + "unsure", + "probably yes", + "definitely yes", + ), + is_ordered=True, + semantic_poles=SemanticPoles( + low="definitely no", + high="definitely yes", + ), + ) + + change = SemanticAnchor( + name="change", + target_property="dynamicity", + canonical_prompt="Is anything changing in [[situation]] over time?", + response_space=binary, + required_span_labels=frozenset({"situation"}), + required_keywords=frozenset({"changing"}), + ) + completion = SemanticAnchor( + name="completion", + target_property="telicity", + canonical_prompt=("Does [[situation]] reach a definite endpoint?"), + response_space=likert, + required_span_labels=frozenset({"situation"}), + required_keywords=frozenset({"endpoint"}), + ) + uniformity = SemanticAnchor( + name="uniformity", + target_property="homogeneity", + canonical_prompt=( + "Are different moments of [[situation]] qualitatively similar?" + ), + response_space=binary, + required_span_labels=frozenset({"situation"}), + required_keywords=frozenset({"moments"}), + ) + return change, completion, uniformity + + +def test_protocol_end_to_end() -> None: + change_anchor, completion_anchor, uniformity_anchor = _build_anchors() + + adapter = _StubAdapter() + guard = DriftGuard( + validators=[ + StructuralDriftValidator(), + EmbeddingDriftValidator(adapter, max_distance=0.5), + PerplexityDriftValidator(adapter, max_perplexity=100.0), + ] + ) + + # Family A: contextual templates by target UPOS + contextual = ContextualTemplateRealization( + variants=( + TemplateVariant( + template=( + "Does anything happen during [[situation]] that is changing?" + ), + condition=lambda ctx: ctx.target_upos == "VERB", + priority=10, + ), + TemplateVariant( + template=("Is [[situation]] something that is changing in any way?"), + priority=0, + ), + ), + ) + family_change = QuestionFamily( + anchor=change_anchor, + realization=contextual, + drift_guard=guard, + ) + + # Family B: LM realization, conditional on change=='yes' + family_completion = QuestionFamily( + anchor=completion_anchor, + realization=LMRealization(_StubLMClient(), model_name="stub-lm"), + drift_guard=guard, + condition=lambda ctx: ctx.previous_responses.get("change") == "yes", + depends_on=("change",), + ) + + # Family C: plain template, conditional on change=='yes' + family_uniformity = QuestionFamily( + anchor=uniformity_anchor, + realization=TemplateRealization(), + drift_guard=guard, + condition=lambda ctx: ctx.previous_responses.get("change") == "yes", + depends_on=("change",), + ) + + protocol = AnnotationProtocol( + families=[family_change, family_completion, family_uniformity], + name="aspect-protocol", + ) + + ctx = ProtocolContext( + sentence="Mary built a sandcastle.", + target_form="built", + target_lemma="build", + target_upos="VERB", + target_position=2, + target_span_text="built a sandcastle", + target_span_positions=(2, 3, 4), + dependents=( + ContextItem( + head_lemma="Mary", + head_form="Mary", + head_upos="PROPN", + head_position=1, + span_text="Mary", + ), + ContextItem( + head_lemma="sandcastle", + head_form="sandcastle", + head_upos="NOUN", + head_position=4, + span_text="a sandcastle", + attributes={"definiteness": 0.0}, + ), + ), + ) + + # All three fire when 'change' was answered 'yes'. + realizations = protocol.realize_all(ctx, responses={"change": "yes"}) + assert [r.anchor.name for r in realizations] == [ + "change", + "completion", + "uniformity", + ] + for r in realizations: + assert r.passed_drift_check, r.drift_score + + # Only 'change' fires when no responses are pre-supplied (placeholder + # for change is its first option, "no", which fails both conditions). + only_change = protocol.realize_all(ctx) + assert [r.anchor.name for r in only_change] == ["change"] + + # Encoding round-trip + encoding = encode_response_space("change", change_anchor.response_space) + assert encoding.is_binary + encoding2 = encode_response_space("completion", completion_anchor.response_space) + assert encoding2.is_ordinal + assert encoding2.n_levels == 5 + + # Reliability over simulated responses + records = [ + AnnotationRecord( + annotator_id="a1", + item_id="i1", + question_name="change", + response_label="yes", + ), + AnnotationRecord( + annotator_id="a1", + item_id="i2", + question_name="change", + response_label="no", + ), + AnnotationRecord( + annotator_id="a2", + item_id="i1", + question_name="change", + response_label="yes", + ), + AnnotationRecord( + annotator_id="a2", + item_id="i2", + question_name="change", + response_label="yes", + ), + ] + profiles = annotator_reliability(records) + assert len(profiles) == 2 + flagged = low_entropy_annotators(profiles, threshold=0.5) + assert flagged == ("a2",) + + # Conditional dependency check: completion has a response for an + # item that lacks a 'change' response, which should warn. + cond_records = { + "completion": [ + AnnotationRecord( + annotator_id="a1", + item_id="i_orphan", + question_name="completion", + response_label="probably yes", + ), + ], + } + cond_validator = ConditionalObservationValidator() + findings = cond_validator.validate(cond_records, protocol) + assert len(findings) == 1 + assert findings[0].category == "conditional_missing_dependency" + + # Build a final report + report = ( + DatasetReport( + n_records_input=len(records), + n_items=2, + n_records_encoded=len(records), + ) + .with_coverage("change", 1.0) + .extend(findings) + .add(DiagnosticLevel.INFO, "summary", "end-to-end test complete") + ) + summary = report.summary() + assert "2 items" in summary + assert "warnings" in summary diff --git a/tests/protocol/test_family.py b/tests/protocol/test_family.py new file mode 100644 index 0000000..98f984a --- /dev/null +++ b/tests/protocol/test_family.py @@ -0,0 +1,206 @@ +"""Tests for :mod:`bead.protocol.family`.""" + +from __future__ import annotations + +import pytest + +from bead.protocol.anchor import ResponseSpace, SemanticAnchor +from bead.protocol.context import ProtocolContext +from bead.protocol.drift import DriftGuard, StructuralDriftValidator +from bead.protocol.family import ( + AnnotationProtocol, + QuestionFamily, + QuestionRealization, +) +from bead.protocol.realization import TemplateRealization + + +def _anchor( + name: str, + *, + canonical: str = "Does [[situation]] reach an endpoint?", +) -> SemanticAnchor: + return SemanticAnchor( + name=name, + target_property=name, + canonical_prompt=canonical, + response_space=ResponseSpace(options=("no", "yes"), is_ordered=False), + required_span_labels=frozenset({"situation"}), + ) + + +class TestQuestionFamily: + """Tests for :class:`QuestionFamily`.""" + + def test_default_realization_uses_canonical(self) -> None: + family = QuestionFamily(anchor=_anchor("completion")) + ctx = ProtocolContext() + result = family.realize(ctx) + assert isinstance(result, QuestionRealization) + assert result.prompt == family.anchor.canonical_prompt + assert result.passed_drift_check is True + assert result.strategy_name == "TemplateRealization" + + def test_drift_failure_falls_back(self) -> None: + anchor = _anchor("completion") + family = QuestionFamily( + anchor=anchor, + realization=TemplateRealization(template="Bad realization."), + drift_guard=DriftGuard(validators=[StructuralDriftValidator()]), + fallback_on_drift=True, + ) + result = family.realize(ProtocolContext()) + # Fell back to canonical, which has the [[situation]] tag + assert result.prompt == anchor.canonical_prompt + assert "fallback" in result.strategy_name + + def test_drift_failure_raises_when_no_fallback(self) -> None: + anchor = _anchor("completion") + family = QuestionFamily( + anchor=anchor, + realization=TemplateRealization(template="Bad realization."), + drift_guard=DriftGuard(validators=[StructuralDriftValidator()]), + fallback_on_drift=False, + ) + with pytest.raises(ValueError, match="Drift validation failed"): + family.realize(ProtocolContext()) + + def test_is_always_applicable_default(self) -> None: + family = QuestionFamily(anchor=_anchor("a")) + assert family.is_always_applicable is True + + def test_explicit_condition_marks_conditional(self) -> None: + family = QuestionFamily( + anchor=_anchor("a"), + condition=lambda ctx: ctx.target_upos == "VERB", + ) + assert family.is_always_applicable is False + assert family.is_applicable(ProtocolContext(target_upos="VERB")) is True + assert family.is_applicable(ProtocolContext(target_upos="NOUN")) is False + + def test_depends_on_recorded(self) -> None: + family = QuestionFamily( + anchor=_anchor("uniformity"), + depends_on=("change",), + ) + assert family.depends_on == ("change",) + + +class TestAnnotationProtocol: + """Tests for :class:`AnnotationProtocol`.""" + + def test_construction_records_families(self) -> None: + a = QuestionFamily(anchor=_anchor("a")) + b = QuestionFamily(anchor=_anchor("b")) + proto = AnnotationProtocol(families=[a, b], name="demo") + assert len(proto) == 2 + assert proto.name == "demo" + + def test_duplicate_anchor_names_rejected(self) -> None: + a = QuestionFamily(anchor=_anchor("dup")) + b = QuestionFamily(anchor=_anchor("dup")) + with pytest.raises(ValueError, match="Duplicate"): + AnnotationProtocol(families=[a, b]) + + def test_append_rejects_duplicate(self) -> None: + a = QuestionFamily(anchor=_anchor("a")) + proto = AnnotationProtocol(families=[a]) + with pytest.raises(ValueError, match="Duplicate"): + proto.append(QuestionFamily(anchor=_anchor("a"))) + + def test_family_by_name_lookup(self) -> None: + a = QuestionFamily(anchor=_anchor("a")) + b = QuestionFamily(anchor=_anchor("b")) + proto = AnnotationProtocol(families=[a, b]) + assert proto.family_by_name("b") is b + with pytest.raises(KeyError): + proto.family_by_name("missing") + + def test_realize_all_threads_responses(self) -> None: + """Second family conditioned on the first's response.""" + first = QuestionFamily(anchor=_anchor("change")) + + def is_dynamic(ctx: ProtocolContext) -> bool: + return ctx.previous_responses.get("change") == "yes" + + second = QuestionFamily( + anchor=_anchor("uniformity"), + condition=is_dynamic, + depends_on=("change",), + ) + proto = AnnotationProtocol(families=[first, second]) + + # With responses={'change': 'yes'} both questions fire + results = proto.realize_all( + ProtocolContext(), + responses={"change": "yes"}, + ) + assert [r.anchor.name for r in results] == ["change", "uniformity"] + + # Without an explicit response, the placeholder is the first + # option (`"no"`), so the second family's condition is false. + results2 = proto.realize_all(ProtocolContext()) + assert [r.anchor.name for r in results2] == ["change"] + + def test_realize_all_rejects_unknown_response(self) -> None: + proto = AnnotationProtocol(families=[QuestionFamily(anchor=_anchor("a"))]) + with pytest.raises(ValueError, match="unknown anchors"): + proto.realize_all(ProtocolContext(), responses={"missing": "yes"}) + + def test_self_dependency_rejected_at_construction(self) -> None: + with pytest.raises(ValueError, match="depends on itself"): + AnnotationProtocol( + families=[ + QuestionFamily( + anchor=_anchor("a"), + depends_on=("a",), + ), + ], + ) + + def test_forward_dependency_rejected_at_construction(self) -> None: + with pytest.raises(ValueError, match="not earlier"): + AnnotationProtocol( + families=[ + QuestionFamily( + anchor=_anchor("a"), + depends_on=("b",), + ), + QuestionFamily(anchor=_anchor("b")), + ], + ) + + def test_unknown_dependency_rejected_at_construction(self) -> None: + with pytest.raises(ValueError, match="not earlier"): + AnnotationProtocol( + families=[ + QuestionFamily( + anchor=_anchor("a"), + depends_on=("ghost",), + ), + ], + ) + + def test_append_rejects_unknown_dependency(self) -> None: + proto = AnnotationProtocol( + families=[QuestionFamily(anchor=_anchor("first"))], + ) + with pytest.raises(ValueError, match="not in the protocol"): + proto.append( + QuestionFamily( + anchor=_anchor("second"), + depends_on=("ghost",), + ), + ) + + def test_append_rejects_self_dependency(self) -> None: + proto = AnnotationProtocol( + families=[QuestionFamily(anchor=_anchor("first"))], + ) + with pytest.raises(ValueError, match="depends on itself"): + proto.append( + QuestionFamily( + anchor=_anchor("second"), + depends_on=("second",), + ), + ) diff --git a/tests/protocol/test_items_bridge.py b/tests/protocol/test_items_bridge.py new file mode 100644 index 0000000..ae2ec94 --- /dev/null +++ b/tests/protocol/test_items_bridge.py @@ -0,0 +1,233 @@ +"""Tests for :mod:`bead.protocol.items`.""" + +from __future__ import annotations + +import pytest + +from bead.items.item_template import ItemTemplate, PresentationSpec +from bead.protocol import ( + AnnotationProtocol, + ContextItem, + ProtocolContext, + QuestionFamily, + ResponseSpace, + ScaleType, + SemanticAnchor, + family_to_item_template, + protocol_to_item_templates, + realization_to_item, + realize_protocol_to_items, + scale_type_to_task_type, +) + + +class TestScaleTypeToTaskType: + """Tests for :func:`scale_type_to_task_type`.""" + + def test_binary_maps(self) -> None: + assert scale_type_to_task_type(ScaleType.BINARY) == "binary" + + def test_ordinal_maps(self) -> None: + assert scale_type_to_task_type(ScaleType.ORDINAL) == "ordinal_scale" + + def test_nominal_maps(self) -> None: + assert scale_type_to_task_type(ScaleType.NOMINAL) == "categorical" + + +def _build_binary_anchor() -> SemanticAnchor: + return SemanticAnchor( + name="completion", + target_property="telicity", + canonical_prompt="Does [[situation]] reach an endpoint?", + response_space=ResponseSpace(options=("no", "yes"), is_ordered=False), + required_span_labels=frozenset({"situation"}), + required_keywords=frozenset({"endpoint"}), + description="Telicity probe.", + ) + + +def _build_ordinal_anchor() -> SemanticAnchor: + return SemanticAnchor( + name="confidence", + target_property="confidence", + canonical_prompt="How confident is [[situation]]?", + response_space=ResponseSpace( + options=("low", "medium", "high"), + is_ordered=True, + ), + required_span_labels=frozenset({"situation"}), + ) + + +class TestFamilyToItemTemplate: + """Tests for :func:`family_to_item_template`.""" + + def test_binary_family(self) -> None: + family = QuestionFamily(anchor=_build_binary_anchor()) + template = family_to_item_template(family, judgment_type="acceptability") + assert isinstance(template, ItemTemplate) + assert template.task_type == "binary" + assert template.task_spec.options == ("no", "yes") + assert template.task_spec.scale_bounds is None + assert template.task_spec.scale_labels == () + # Two elements: text + prompt + assert {e.element_name for e in template.elements} == {"text", "prompt"} + + def test_ordinal_family(self) -> None: + family = QuestionFamily(anchor=_build_ordinal_anchor()) + template = family_to_item_template(family, judgment_type="acceptability") + assert template.task_type == "ordinal_scale" + assert template.task_spec.scale_bounds is not None + assert template.task_spec.scale_bounds.min == 0 + assert template.task_spec.scale_bounds.max == 2 + assert len(template.task_spec.scale_labels) == 3 + labels = {(p.point, p.label) for p in template.task_spec.scale_labels} + assert labels == {(0, "low"), (1, "medium"), (2, "high")} + + def test_custom_presentation_spec(self) -> None: + family = QuestionFamily(anchor=_build_binary_anchor()) + spec = PresentationSpec(mode="self_paced") + template = family_to_item_template( + family, + judgment_type="comprehension", + presentation_spec=spec, + ) + assert template.presentation_spec.mode == "self_paced" + + def test_judgment_type_propagates(self) -> None: + family = QuestionFamily(anchor=_build_binary_anchor()) + template = family_to_item_template(family, judgment_type="inference") + assert template.judgment_type == "inference" + + +class TestRealizationToItem: + """Tests for :func:`realization_to_item`.""" + + def test_basic_round_trip(self) -> None: + family = QuestionFamily(anchor=_build_binary_anchor()) + template = family_to_item_template(family, judgment_type="acceptability") + ctx = ProtocolContext( + sentence="Mary built a sandcastle.", + tokens=("Mary", "built", "a", "sandcastle", "."), + target_position=2, + target_span_text="built a sandcastle", + target_span_positions=(2, 3, 4), + ) + realization = family.realize(ctx) + item = realization_to_item(realization, item_template=template) + assert item.item_template_id == template.id + assert item.rendered_elements["text"] == "Mary built a sandcastle." + assert item.rendered_elements["prompt"] == realization.prompt + assert item.tokenized_elements["text"] == ctx.tokens + # The required_span_label "situation" yields a span anchored + # to the target's positions (translated to 0-based indexing). + assert len(item.spans) == 1 + span = item.spans[0] + assert span.label is not None + assert span.label.label == "situation" + assert span.segments[0].element_name == "text" + assert span.segments[0].indices == (1, 2, 3) + assert span.head_index == 1 + + def test_no_required_labels_no_spans(self) -> None: + anchor = SemanticAnchor( + name="dummy", + target_property="dummy", + canonical_prompt="Question?", + response_space=ResponseSpace(options=("no", "yes"), is_ordered=False), + ) + family = QuestionFamily(anchor=anchor) + template = family_to_item_template(family, judgment_type="acceptability") + ctx = ProtocolContext(sentence="Plain text.") + realization = family.realize(ctx) + item = realization_to_item(realization, item_template=template) + assert item.spans == () + + def test_dependent_span_picked_when_label_matches_lemma(self) -> None: + anchor = SemanticAnchor( + name="distrib", + target_property="distributivity", + canonical_prompt=( + "Did [[situation]] involve [[participant]] one at a time?" + ), + response_space=ResponseSpace(options=("no", "yes"), is_ordered=False), + required_span_labels=frozenset({"situation", "participant"}), + ) + family = QuestionFamily(anchor=anchor) + template = family_to_item_template(family, judgment_type="acceptability") + ctx = ProtocolContext( + sentence="The kids ran.", + tokens=("The", "kids", "ran", "."), + target_position=3, + target_span_text="ran", + target_span_positions=(3,), + dependents=( + ContextItem( + head_lemma="participant", # matches required label + head_position=2, + span_text="The kids", + span_positions=(1, 2), + ), + ), + ) + realization = family.realize(ctx) + item = realization_to_item(realization, item_template=template) + # Two spans; one for each required label + labels_to_indices = { + s.label.label: s.segments[0].indices + for s in item.spans + if s.label is not None + } + # situation → target span (3) → 0-based index 2 + assert labels_to_indices["situation"] == (2,) + # participant → dependent span (1, 2) → 0-based (0, 1) + assert labels_to_indices["participant"] == (0, 1) + + +class TestProtocolToItemTemplates: + """Tests for :func:`protocol_to_item_templates` and :func:`realize_protocol_to_items`.""" + + def test_protocol_to_templates(self) -> None: + proto = AnnotationProtocol( + families=[ + QuestionFamily(anchor=_build_binary_anchor()), + QuestionFamily(anchor=_build_ordinal_anchor()), + ] + ) + templates = protocol_to_item_templates(proto, judgment_type="acceptability") + assert set(templates) == {"completion", "confidence"} + assert templates["completion"].task_type == "binary" + assert templates["confidence"].task_type == "ordinal_scale" + + def test_realize_protocol_to_items(self) -> None: + proto = AnnotationProtocol( + families=[QuestionFamily(anchor=_build_binary_anchor())] + ) + ctx = ProtocolContext( + sentence="Mary built a sandcastle.", + target_span_text="built a sandcastle", + target_span_positions=(2, 3, 4), + ) + pairs = realize_protocol_to_items(proto, ctx, judgment_type="acceptability") + assert len(pairs) == 1 + realization, item = pairs[0] + assert realization.anchor.name == "completion" + assert item.rendered_elements["prompt"] == realization.prompt + + +def test_unknown_template_in_protocol_realize( + _pytest_request: object | None = None, +) -> None: + """``realize_protocol_to_items`` raises if a family has no template.""" + proto = AnnotationProtocol(families=[QuestionFamily(anchor=_build_binary_anchor())]) + other = QuestionFamily(anchor=_build_ordinal_anchor()) + other_templates = { + "confidence": family_to_item_template(other, judgment_type="acceptability") + } + with pytest.raises(KeyError): + realize_protocol_to_items( + proto, + ProtocolContext(), + judgment_type="acceptability", + item_templates=other_templates, + ) diff --git a/tests/protocol/test_realization.py b/tests/protocol/test_realization.py new file mode 100644 index 0000000..57ade2f --- /dev/null +++ b/tests/protocol/test_realization.py @@ -0,0 +1,205 @@ +"""Tests for :mod:`bead.protocol.realization`.""" + +from __future__ import annotations + +import pytest + +from bead.items.cache import ModelOutputCache +from bead.protocol.anchor import ResponseSpace, SemanticAnchor +from bead.protocol.context import ProtocolContext +from bead.protocol.realization import ( + ContextualTemplateRealization, + LMRealization, + RealizationStrategy, + TemplateRealization, + TemplateVariant, + always, +) + + +def _build_anchor(canonical: str = "Does [[situation]] end?") -> SemanticAnchor: + return SemanticAnchor( + name="completion", + target_property="telicity", + canonical_prompt=canonical, + response_space=ResponseSpace(options=("no", "yes"), is_ordered=False), + required_span_labels=frozenset({"situation"}), + ) + + +class TestTemplateRealization: + """Tests for :class:`TemplateRealization`.""" + + def test_default_uses_canonical(self) -> None: + anchor = _build_anchor() + ctx = ProtocolContext() + tr = TemplateRealization() + assert tr.realize(anchor, ctx) == anchor.canonical_prompt + + def test_explicit_template_overrides(self) -> None: + anchor = _build_anchor() + ctx = ProtocolContext() + tr = TemplateRealization(template="Did [[situation]] finish?") + assert tr.realize(anchor, ctx) == "Did [[situation]] finish?" + + def test_conforms_to_protocol(self) -> None: + assert isinstance(TemplateRealization(), RealizationStrategy) + + +class TestContextualTemplateRealization: + """Tests for :class:`ContextualTemplateRealization`.""" + + def test_priority_ordering(self) -> None: + anchor = _build_anchor() + ctx = ProtocolContext(target_upos="VERB") + + def is_verb(c: ProtocolContext) -> bool: + return c.target_upos == "VERB" + + verb_variant = TemplateVariant( + template="VERB-specific [[situation]]?", + condition=is_verb, + priority=10, + ) + fallback_variant = TemplateVariant( + template="Generic [[situation]]?", + condition=always, + priority=0, + ) + ctr = ContextualTemplateRealization(variants=(fallback_variant, verb_variant)) + assert ctr.realize(anchor, ctx) == "VERB-specific [[situation]]?" + + def test_fallback_when_no_match(self) -> None: + anchor = _build_anchor() + ctx = ProtocolContext(target_upos="NOUN") + + def is_verb(c: ProtocolContext) -> bool: + return c.target_upos == "VERB" + + ctr = ContextualTemplateRealization( + variants=(TemplateVariant(template="V", condition=is_verb, priority=10),), + fallback="Custom fallback?", + ) + assert ctr.realize(anchor, ctx) == "Custom fallback?" + + def test_fallback_to_canonical_when_no_explicit_fallback(self) -> None: + anchor = _build_anchor("Canonical [[situation]]?") + ctx = ProtocolContext(target_upos="NOUN") + ctr = ContextualTemplateRealization( + variants=( + TemplateVariant( + template="V", + condition=lambda c: c.target_upos == "VERB", + priority=10, + ), + ), + ) + assert ctr.realize(anchor, ctx) == "Canonical [[situation]]?" + + +class _StubLMClient: + """Stub LM client recording every call.""" + + def __init__(self, response: str) -> None: + self.response = response + self.calls: list[tuple[str, float, int]] = [] + + def complete( + self, + prompt: str, + *, + temperature: float, + max_tokens: int, + ) -> str: + self.calls.append((prompt, temperature, max_tokens)) + return self.response + + +def _memory_cache() -> ModelOutputCache: + """Return an in-memory ModelOutputCache for use in tests.""" + return ModelOutputCache(backend="memory") + + +class TestLMRealization: + """Tests for :class:`LMRealization`.""" + + def test_realize_appends_question_mark(self) -> None: + client = _StubLMClient("What does the situation do") + lm = LMRealization(client, model_name="stub") + prompt = lm.realize(_build_anchor(), ProtocolContext()) + assert prompt.endswith("?") + + def test_realize_strips_quotes_and_whitespace(self) -> None: + client = _StubLMClient(' "Did the event end?" ') + lm = LMRealization(client, model_name="stub") + prompt = lm.realize(_build_anchor(), ProtocolContext()) + assert prompt == "Did the event end?" + + def test_caching(self) -> None: + client = _StubLMClient("Did it end?") + lm = LMRealization(client, model_name="stub", cache=_memory_cache()) + anchor = _build_anchor() + ctx = ProtocolContext(sentence="Mary ran.") + out1 = lm.realize(anchor, ctx) + out2 = lm.realize(anchor, ctx) + assert out1 == out2 + assert len(client.calls) == 1 + + def test_caching_disabled(self) -> None: + client = _StubLMClient("Did it end?") + lm = LMRealization(client, model_name="stub") + anchor = _build_anchor() + ctx = ProtocolContext() + lm.realize(anchor, ctx) + lm.realize(anchor, ctx) + assert len(client.calls) == 2 + + def test_lm_failure_wraps_runtime_error(self) -> None: + class FailingClient: + def complete( + self, + prompt: str, + *, + temperature: float, + max_tokens: int, + ) -> str: + del prompt, temperature, max_tokens + raise ConnectionError("network down") + + lm = LMRealization(FailingClient(), model_name="stub") + with pytest.raises(RuntimeError, match="LM realization failed"): + lm.realize(_build_anchor(), ProtocolContext()) + + def test_empty_response_raises(self) -> None: + client = _StubLMClient(" ") + lm = LMRealization(client, model_name="stub") + with pytest.raises(RuntimeError, match="empty response"): + lm.realize(_build_anchor(), ProtocolContext()) + + def test_quoted_empty_response_raises(self) -> None: + client = _StubLMClient(' "" ') + lm = LMRealization(client, model_name="stub") + with pytest.raises(RuntimeError, match="empty response"): + lm.realize(_build_anchor(), ProtocolContext()) + + def test_calls_pass_kwargs(self) -> None: + client = _StubLMClient("Did it end?") + lm = LMRealization(client, model_name="stub", temperature=0.5, max_tokens=128) + lm.realize(_build_anchor(), ProtocolContext()) + assert len(client.calls) == 1 + _, temperature, max_tokens = client.calls[0] + assert temperature == pytest.approx(0.5) + assert max_tokens == 128 + + def test_cache_isolated_by_model_name(self) -> None: + """Two realizations sharing a cache but different model_names + do not collide.""" + cache = _memory_cache() + client_a = _StubLMClient("Answer A?") + client_b = _StubLMClient("Answer B?") + lm_a = LMRealization(client_a, model_name="model-a", cache=cache) + lm_b = LMRealization(client_b, model_name="model-b", cache=cache) + anchor = _build_anchor() + ctx = ProtocolContext() + assert lm_a.realize(anchor, ctx) == "Answer A?" + assert lm_b.realize(anchor, ctx) == "Answer B?" diff --git a/tests/test_labels.py b/tests/test_labels.py new file mode 100644 index 0000000..e2daec9 --- /dev/null +++ b/tests/test_labels.py @@ -0,0 +1,104 @@ +"""Tests for :mod:`bead.labels`.""" + +from __future__ import annotations + +import re + +from bead.labels import ( + LABEL_PATTERN, + LabelRef, + find_label_names, + parse_label_refs, + replace_label_refs, +) + + +class TestParseLabelRefs: + """Tests for :func:`parse_label_refs`.""" + + def test_no_references(self) -> None: + assert parse_label_refs("Plain text with no refs.") == () + + def test_bare_label(self) -> None: + refs = parse_label_refs("Did [[agent]] act?") + assert len(refs) == 1 + assert refs[0].label == "agent" + assert refs[0].display_text is None + assert refs[0].transforms == () + + def test_explicit_display_text(self) -> None: + refs = parse_label_refs("Did [[event:the breaking]] happen?") + assert refs[0].label == "event" + assert refs[0].display_text == "the breaking" + assert refs[0].transforms == () + + def test_single_transform(self) -> None: + refs = parse_label_refs("Did [[situation|gerund]] happen?") + assert refs[0].label == "situation" + assert refs[0].transforms == ("gerund",) + + def test_chained_transforms(self) -> None: + refs = parse_label_refs("Did [[situation|gerund|lower]] happen?") + assert refs[0].transforms == ("gerund", "lower") + + def test_explicit_text_and_transforms(self) -> None: + refs = parse_label_refs("[[event:the running|upper]]") + assert refs[0].label == "event" + assert refs[0].display_text == "the running" + assert refs[0].transforms == ("upper",) + + def test_offsets_are_correct(self) -> None: + prompt = "x [[a]] y" + refs = parse_label_refs(prompt) + assert refs[0].start_offset == 2 + assert refs[0].end_offset == 7 + assert prompt[refs[0].start_offset : refs[0].end_offset] == "[[a]]" + + def test_multiple_refs_in_order(self) -> None: + refs = parse_label_refs("[[a]] then [[b:bee]] then [[c|x]]") + assert [r.label for r in refs] == ["a", "b", "c"] + + +class TestFindLabelNames: + """Tests for :func:`find_label_names`.""" + + def test_distinct_labels(self) -> None: + names = find_label_names("[[a]] and [[b:bee]] and [[a|gerund]] and [[c]]") + assert names == frozenset({"a", "b", "c"}) + + def test_empty_prompt(self) -> None: + assert find_label_names("") == frozenset() + + +class TestReplaceLabelRefs: + """Tests for :func:`replace_label_refs`.""" + + def test_no_refs_returns_input(self) -> None: + assert replace_label_refs("plain", lambda r: "X") == "plain" + + def test_replaces_in_order(self) -> None: + out = replace_label_refs("[[a]] [[b]]", lambda r: f"<{r.label}>") + assert out == " " + + def test_replacement_uses_explicit_text(self) -> None: + out = replace_label_refs( + "Did [[event:the running]] happen?", + lambda r: r.display_text or r.label, + ) + assert out == "Did the running happen?" + + +class TestLabelRef: + """Tests for the :class:`LabelRef` BeadBaseModel.""" + + def test_round_trip_through_with(self) -> None: + ref = LabelRef(label="x", start_offset=0, end_offset=5) + ref2 = ref.with_(label="y") + assert ref.label == "x" + assert ref2.label == "y" + assert ref.id == ref2.id + + +def test_pattern_is_compiled() -> None: + """The exported regex is a compiled pattern.""" + assert isinstance(LABEL_PATTERN, re.Pattern) diff --git a/tests/transforms/test_prompt_integration.py b/tests/transforms/test_prompt_integration.py index 13461f4..31eb0fe 100644 --- a/tests/transforms/test_prompt_integration.py +++ b/tests/transforms/test_prompt_integration.py @@ -11,78 +11,78 @@ SpanColorMap, _assign_span_colors, _build_transform_context, - _parse_prompt_references, _resolve_prompt_references, ) from bead.items.item import Item from bead.items.spans import Span, SpanLabel, SpanSegment +from bead.labels import parse_label_refs from bead.transforms.base import TransformRegistry class TestParsePromptReferencesWithTransforms: - """Tests for _parse_prompt_references() transform syntax.""" + """Tests for parse_label_refs() transform syntax.""" def test_no_transforms(self) -> None: """Plain label has empty transforms list.""" - refs = _parse_prompt_references("[[agent]]") + refs = parse_label_refs("[[agent]]") - assert refs[0].transforms == [] + assert refs[0].transforms == () def test_single_transform(self) -> None: """Single transform after pipe is captured.""" - refs = _parse_prompt_references("[[situation|gerund]]") + refs = parse_label_refs("[[situation|gerund]]") assert refs[0].label == "situation" assert refs[0].display_text is None - assert refs[0].transforms == ["gerund"] + assert refs[0].transforms == ("gerund",) def test_multiple_transforms(self) -> None: """Chained transforms are split on pipe.""" - refs = _parse_prompt_references("[[situation|gerund|lower]]") + refs = parse_label_refs("[[situation|gerund|lower]]") assert refs[0].label == "situation" - assert refs[0].transforms == ["gerund", "lower"] + assert refs[0].transforms == ("gerund", "lower") def test_explicit_text_with_transform(self) -> None: """Display text and transforms can coexist.""" - refs = _parse_prompt_references("[[event:the running|upper]]") + refs = parse_label_refs("[[event:the running|upper]]") assert refs[0].label == "event" assert refs[0].display_text == "the running" - assert refs[0].transforms == ["upper"] + assert refs[0].transforms == ("upper",) def test_backward_compatible_colon_syntax(self) -> None: """Existing [[label:text]] syntax still works.""" - refs = _parse_prompt_references("[[event:the breaking]]") + refs = parse_label_refs("[[event:the breaking]]") assert refs[0].label == "event" assert refs[0].display_text == "the breaking" - assert refs[0].transforms == [] + assert refs[0].transforms == () def test_backward_compatible_plain_label(self) -> None: """Existing [[label]] syntax still works.""" - refs = _parse_prompt_references("[[agent]]") + refs = parse_label_refs("[[agent]]") assert refs[0].label == "agent" assert refs[0].display_text is None - assert refs[0].transforms == [] + assert refs[0].transforms == () def test_mixed_references(self) -> None: """Various syntax forms in one prompt are parsed correctly.""" prompt = "Did [[agent]] do [[event|gerund]] to [[patient:the vase|upper]]?" - refs = _parse_prompt_references(prompt) + refs = parse_label_refs(prompt) assert len(refs) == 3 assert refs[0].label == "agent" - assert refs[0].transforms == [] + assert refs[0].transforms == () assert refs[1].label == "event" - assert refs[1].transforms == ["gerund"] + assert refs[1].transforms == ("gerund",) assert refs[2].label == "patient" assert refs[2].display_text == "the vase" - assert refs[2].transforms == ["upper"] + assert refs[2].transforms == ("upper",) class TestResolvePromptReferencesWithTransforms: diff --git a/uv.lock b/uv.lock index a68375b..0098559 100644 --- a/uv.lock +++ b/uv.lock @@ -157,7 +157,7 @@ wheels = [ [[package]] name = "bead" -version = "0.2.1" +version = "0.4.0" source = { editable = "." } dependencies = [ { name = "accelerate" }, @@ -208,6 +208,8 @@ dev = [ { name = "pytest-examples" }, { name = "pytest-mock" }, { name = "ruff" }, + { name = "spacy" }, + { name = "stanza" }, ] stats = [ { name = "statsmodels" }, @@ -230,7 +232,7 @@ requires-dist = [ { name = "anthropic", marker = "extra == 'api'", specifier = ">=0.8.0" }, { name = "click", specifier = ">=8.0.0" }, { name = "datasets", specifier = ">=2.14.0" }, - { name = "didactic", specifier = ">=0.4.3" }, + { name = "didactic", specifier = ">=0.6.2" }, { name = "evaluate", specifier = ">=0.4.0" }, { name = "glazing", specifier = ">=0.2.0" }, { name = "google-generativeai", marker = "extra == 'api'", specifier = ">=0.3.0" }, @@ -261,7 +263,9 @@ requires-dist = [ { name = "scipy", specifier = ">=1.11.0" }, { name = "sentence-transformers", specifier = ">=2.0.0" }, { name = "slopit", marker = "extra == 'behavioral-analysis'", specifier = ">=0.1.0" }, + { name = "spacy", marker = "extra == 'dev'", specifier = ">=3.7" }, { name = "spacy", marker = "extra == 'tokenization'", specifier = ">=3.7" }, + { name = "stanza", marker = "extra == 'dev'", specifier = ">=1.8" }, { name = "stanza", marker = "extra == 'tokenization'", specifier = ">=1.8" }, { name = "statsmodels", specifier = ">=0.14.6" }, { name = "statsmodels", marker = "extra == 'stats'", specifier = ">=0.14.0" }, @@ -507,15 +511,15 @@ wheels = [ [[package]] name = "didactic" -version = "0.4.3" +version = "0.6.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-types" }, { name = "panproto" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2b/04/aad0b04f350050f847fe9609e422456e73f48fbf7a10429d2090079f01d6/didactic-0.4.3.tar.gz", hash = "sha256:71c0988b5da6dd2965e46d279c1774cd2abf64373ff4c45c32f595f0699eb434", size = 100240, upload-time = "2026-05-05T18:55:41.239Z" } +sdist = { url = "https://files.pythonhosted.org/packages/82/49/f20c2d920359c35a3196af220bd97e87e81d4fe4a93b1c604d5a14f4ae88/didactic-0.6.2.tar.gz", hash = "sha256:e782eeae17b03b027f6119dafcaeef7224c23468e255a7b0a487f9b437b92cb4", size = 108463, upload-time = "2026-05-06T20:03:47.514Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/17/a5/3ac3da8c62e8181592e6baeb8a09b9add914a6013a19b6c6f9b5929080b1/didactic-0.4.3-py3-none-any.whl", hash = "sha256:729943219fc6675433b3462c4585bcd8fe7f39b454534f14afd2ce52e6a26a41", size = 125863, upload-time = "2026-05-05T18:55:39.832Z" }, + { url = "https://files.pythonhosted.org/packages/77/95/3f1e20bb65e78fea6d936ac94c79907bf36c28bf5c332b7e60b88546e124/didactic-0.6.2-py3-none-any.whl", hash = "sha256:34ef2e4df0b938ee7fbd4b352903b85dd3a959b34c2ee3e2b987773426dd2dfb", size = 134154, upload-time = "2026-05-06T20:03:45.934Z" }, ] [[package]]