diff --git a/.pyrit_conf_example b/.pyrit_conf_example index 51a3b987c7..e437ada269 100644 --- a/.pyrit_conf_example +++ b/.pyrit_conf_example @@ -26,7 +26,8 @@ memory_db_type: sqlite # - airt: AI Red Team setup with Azure OpenAI (requires AZURE_OPENAI_* env vars) # - target: Registers available prompt targets into the TargetRegistry # - scorer: Registers pre-configured scorers into the ScorerRegistry -# - load_default_datasets: Loads default datasets for all registered scenarios +# - load_default_datasets: Optional preload of default datasets for all registered +# scenarios (scenarios otherwise fetch their datasets on demand) # - objective_list: Sets default objectives for scenarios # # Each initializer can be specified as: @@ -46,7 +47,6 @@ memory_db_type: sqlite # - scorer initializers: - name: simple - - name: load_default_datasets - name: target args: tags: diff --git a/doc/getting_started/pyrit_conf.md b/doc/getting_started/pyrit_conf.md index 54bcb4d9c4..448f619755 100644 --- a/doc/getting_started/pyrit_conf.md +++ b/doc/getting_started/pyrit_conf.md @@ -111,7 +111,6 @@ Most users should enable the following initializers. These are what the `.pyrit_ | `simple` | Baseline defaults for converters, scorers, and attack configs using your `OPENAI_CHAT_*` env vars | Always — provides the foundation for most PyRIT operations | | `target` | Prompt targets (OpenAI, Azure, AML, etc.) into the `TargetRegistry` | **Required for `pyrit_scan`** and any registry-based workflows | | `scorer` | Scorers (refusal, content safety, harm-category, Likert, etc.) into the `ScorerRegistry` | **Required for automated scoring** and `pyrit_scan` evaluations | -| `load_default_datasets` | Seed datasets for all registered scenarios into memory | **Required for `pyrit_scan` scenarios** — they need data to run | ```{note} **Execution order follows listing order.** Initializers execute in the order they appear in the config. Ensure dependencies are satisfied — for example, list `target` before `scorer` since scorers need targets to be registered first. @@ -122,7 +121,6 @@ The recommended config: ```yaml initializers: - name: simple - - name: load_default_datasets - name: scorer - name: target args: @@ -131,6 +129,10 @@ initializers: - scorer ``` +```{note} +**`load_default_datasets` is optional.** Scenarios fetch their datasets from the registered provider on demand the first time they run, so you no longer need this initializer for everyday runs. Add it only when you want to preload every scenario's datasets up front — for example, to warm memory for repeated runs or to populate a database for offline use. +``` + ### `initialization_scripts` Paths to custom Python scripts containing `PyRITInitializer` subclasses. Paths can be absolute or relative to the current working directory. diff --git a/doc/index.md b/doc/index.md index a9bf0ecbba..fefc943ad9 100644 --- a/doc/index.md +++ b/doc/index.md @@ -101,7 +101,6 @@ initializers: - default - scorer - name: scorer - - name: load_default_datasets ``` :::: diff --git a/doc/scanner/0_scanner.md b/doc/scanner/0_scanner.md index 69364f5b53..9e81192063 100644 --- a/doc/scanner/0_scanner.md +++ b/doc/scanner/0_scanner.md @@ -23,7 +23,7 @@ PyRIT provides two command-line interfaces: ```bash # Run the Foundry RedTeamAgent scenario against your configured target -pyrit_scan foundry.red_team_agent --target openai_chat --initializers target load_default_datasets --strategies base64 +pyrit_scan foundry.red_team_agent --target openai_chat --initializers target --strategies base64 ``` ## Built-in Scenarios diff --git a/doc/scanner/1_pyrit_scan.ipynb b/doc/scanner/1_pyrit_scan.ipynb index 1ff21c6d5b..193d2ea1c9 100644 --- a/doc/scanner/1_pyrit_scan.ipynb +++ b/doc/scanner/1_pyrit_scan.ipynb @@ -110,7 +110,7 @@ "Or concretely:\n", "\n", "```shell\n", - "!pyrit_scan foundry.red_team_agent --target openai_chat --initializers load_default_datasets target --scenario-strategies base64\n", + "!pyrit_scan foundry.red_team_agent --target openai_chat --initializers target --scenario-strategies base64\n", "```\n", "\n", "Example with a basic configuration that runs the Foundry scenario against the objective target defined in the `target` initializer." @@ -123,7 +123,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pyrit_scan foundry.red_team_agent --target openai_chat --initializers load_default_datasets target --strategies base64" + "!pyrit_scan foundry.red_team_agent --target openai_chat --initializers target --strategies base64" ] }, { @@ -131,20 +131,20 @@ "id": "8", "metadata": {}, "source": [ - "Or with all options and multiple initializers and multiple strategies:\n", + "Or with all options and multiple strategies:\n", "\n", "```shell\n", - "pyrit_scan foundry.red_team_agent --target openai_chat --initializers load_default_datasets target --strategies easy crescendo\n", + "pyrit_scan foundry.red_team_agent --target openai_chat --initializers target --strategies easy crescendo\n", "```\n", "\n", "You can also override scenario execution parameters:\n", "\n", "```shell\n", "# Override concurrency and retry settings\n", - "pyrit_scan foundry.red_team_agent --target openai_chat --initializers load_default_datasets target --max-concurrency 10 --max-retries 3\n", + "pyrit_scan foundry.red_team_agent --target openai_chat --initializers target --max-concurrency 10 --max-retries 3\n", "\n", "# Add custom memory labels for tracking (must be valid JSON)\n", - "pyrit_scan foundry.red_team_agent --target openai_chat --initializers load_default_datasets target --memory-labels '{\"experiment\": \"test1\", \"version\": \"v2\", \"researcher\": \"alice\"}'\n", + "pyrit_scan foundry.red_team_agent --target openai_chat --initializers target --memory-labels '{\"experiment\": \"test1\", \"version\": \"v2\", \"researcher\": \"alice\"}'\n", "```\n", "\n", "Available CLI parameter overrides:\n", diff --git a/doc/scanner/1_pyrit_scan.py b/doc/scanner/1_pyrit_scan.py index c1b452a071..ab825bf2ce 100644 --- a/doc/scanner/1_pyrit_scan.py +++ b/doc/scanner/1_pyrit_scan.py @@ -78,29 +78,29 @@ # Or concretely: # # ```shell -# !pyrit_scan foundry.red_team_agent --target openai_chat --initializers load_default_datasets target --scenario-strategies base64 +# !pyrit_scan foundry.red_team_agent --target openai_chat --initializers target --scenario-strategies base64 # ``` # # Example with a basic configuration that runs the Foundry scenario against the objective target defined in the `target` initializer. # %% -# !pyrit_scan foundry.red_team_agent --target openai_chat --initializers load_default_datasets target --strategies base64 +# !pyrit_scan foundry.red_team_agent --target openai_chat --initializers target --strategies base64 # %% [markdown] -# Or with all options and multiple initializers and multiple strategies: +# Or with all options and multiple strategies: # # ```shell -# pyrit_scan foundry.red_team_agent --target openai_chat --initializers load_default_datasets target --strategies easy crescendo +# pyrit_scan foundry.red_team_agent --target openai_chat --initializers target --strategies easy crescendo # ``` # # You can also override scenario execution parameters: # # ```shell # # Override concurrency and retry settings -# pyrit_scan foundry.red_team_agent --target openai_chat --initializers load_default_datasets target --max-concurrency 10 --max-retries 3 +# pyrit_scan foundry.red_team_agent --target openai_chat --initializers target --max-concurrency 10 --max-retries 3 # # # Add custom memory labels for tracking (must be valid JSON) -# pyrit_scan foundry.red_team_agent --target openai_chat --initializers load_default_datasets target --memory-labels '{"experiment": "test1", "version": "v2", "researcher": "alice"}' +# pyrit_scan foundry.red_team_agent --target openai_chat --initializers target --memory-labels '{"experiment": "test1", "version": "v2", "researcher": "alice"}' # ``` # # Available CLI parameter overrides: diff --git a/doc/scanner/2_pyrit_shell.md b/doc/scanner/2_pyrit_shell.md index 9f9731b601..f5d26af137 100644 --- a/doc/scanner/2_pyrit_shell.md +++ b/doc/scanner/2_pyrit_shell.md @@ -25,7 +25,7 @@ pyrit_shell --config-file ./.pyrit_conf pyrit_shell --log-level DEBUG # Load initializers at startup -pyrit_shell --initializers load_default_datasets +pyrit_shell --initializers target # Load custom initialization scripts pyrit_shell --initialization-scripts ./my_config.py @@ -54,32 +54,32 @@ The `run` command executes scenarios with the same options as `pyrit_scan`: ### Basic Usage ```bash -pyrit> run foundry.red_team_agent --target my_target --initializers target load_default_datasets +pyrit> run foundry.red_team_agent --target my_target --initializers target ``` ### With Strategies ```bash -pyrit> run garak.encoding --target my_target --initializers target load_default_datasets --strategies base64 rot13 +pyrit> run garak.encoding --target my_target --initializers target --strategies base64 rot13 -pyrit> run foundry.red_team_agent --target my_target --initializers target load_default_datasets -s jailbreak crescendo +pyrit> run foundry.red_team_agent --target my_target --initializers target -s jailbreak crescendo ``` ### With Runtime Parameters ```bash # Set concurrency and retries -pyrit> run foundry.red_team_agent --target my_target --initializers target load_default_datasets --max-concurrency 10 --max-retries 3 +pyrit> run foundry.red_team_agent --target my_target --initializers target --max-concurrency 10 --max-retries 3 # Add memory labels for tracking -pyrit> run garak.encoding --target my_target --initializers target load_default_datasets --memory-labels '{"experiment":"test1","version":"v2"}' +pyrit> run garak.encoding --target my_target --initializers target --memory-labels '{"experiment":"test1","version":"v2"}' ``` ### Override Defaults Per-Run ```bash # Override log level for this run only -pyrit> run garak.encoding --target my_target --initializers target load_default_datasets --log-level DEBUG +pyrit> run garak.encoding --target my_target --initializers target --log-level DEBUG ``` ### Run Command Options @@ -119,9 +119,9 @@ pyrit> scenario-history Scenario Run History: ================================================================================ -1) foundry.red_team_agent --initializers target load_default_datasets --strategies base64 -2) garak.encoding --initializers target load_default_datasets --strategies rot13 -3) foundry.red_team_agent --initializers target load_default_datasets -s jailbreak +1) foundry.red_team_agent --initializers target --strategies base64 +2) garak.encoding --initializers target --strategies rot13 +3) foundry.red_team_agent --initializers target -s jailbreak ================================================================================ Total runs: 3 @@ -135,7 +135,7 @@ The shell excels at interactive testing workflows: ```bash # Start shell with defaults -pyrit_shell --initializers target load_default_datasets +pyrit_shell --initializers target # Quick exploration pyrit> list-scenarios @@ -166,7 +166,7 @@ pyrit> print-scenario 2 2. **Use short strategy aliases** with `-s`: ```bash - pyrit> run foundry.red_team_agent --initializers target load_default_datasets -s base64 rot13 + pyrit> run foundry.red_team_agent --initializers target -s base64 rot13 ``` 3. **Review history regularly** to track what you've tested: diff --git a/doc/scanner/airt.ipynb b/doc/scanner/airt.ipynb index e1881ef631..cb4107454b 100644 --- a/doc/scanner/airt.ipynb +++ b/doc/scanner/airt.ipynb @@ -72,7 +72,7 @@ "\n", "```bash\n", "pyrit_scan airt.rapid_response \\\n", - " --initializers target load_default_datasets \\\n", + " --initializers target \\\n", " --target openai_chat \\\n", " --strategies role_play \\\n", " --dataset-names airt_hate \\ \n", @@ -367,7 +367,7 @@ "\n", "```bash\n", "pyrit_scan airt.cyber \\\n", - " --initializers target load_default_datasets \\\n", + " --initializers target \\\n", " --target openai_chat \\\n", " --strategies multi_turn \\\n", " --max-dataset-size 1\n", @@ -504,7 +504,7 @@ "\n", "```bash\n", "pyrit_scan airt.jailbreak \\\n", - " --initializers target load_default_datasets \\\n", + " --initializers target \\\n", " --target openai_chat \\\n", " --strategies prompt_sending \\\n", " --max-dataset-size 1\n", @@ -1125,7 +1125,7 @@ "\n", "```bash\n", "pyrit_scan airt.scam \\\n", - " --initializers target load_default_datasets \\\n", + " --initializers target \\\n", " --target openai_chat \\\n", " --strategies context_compliance \\\n", " --max-dataset-size 1\n", diff --git a/doc/scanner/airt.py b/doc/scanner/airt.py index 05312e7b42..910a744eab 100644 --- a/doc/scanner/airt.py +++ b/doc/scanner/airt.py @@ -40,7 +40,7 @@ # # ```bash # pyrit_scan airt.rapid_response \ -# --initializers target load_default_datasets \ +# --initializers target \ # --target openai_chat \ # --strategies role_play \ # --dataset-names airt_hate \ @@ -123,7 +123,7 @@ # # ```bash # pyrit_scan airt.cyber \ -# --initializers target load_default_datasets \ +# --initializers target \ # --target openai_chat \ # --strategies multi_turn \ # --max-dataset-size 1 @@ -156,7 +156,7 @@ # # ```bash # pyrit_scan airt.jailbreak \ -# --initializers target load_default_datasets \ +# --initializers target \ # --target openai_chat \ # --strategies prompt_sending \ # --max-dataset-size 1 @@ -236,7 +236,7 @@ # # ```bash # pyrit_scan airt.scam \ -# --initializers target load_default_datasets \ +# --initializers target \ # --target openai_chat \ # --strategies context_compliance \ # --max-dataset-size 1 diff --git a/doc/scanner/benchmark.ipynb b/doc/scanner/benchmark.ipynb index 120ed75d65..0855c49fc3 100644 --- a/doc/scanner/benchmark.ipynb +++ b/doc/scanner/benchmark.ipynb @@ -31,7 +31,7 @@ "\n", "```bash\n", "pyrit_scan benchmark.adversarial \\\n", - " --initializers target load_default_datasets \\\n", + " --initializers target \\\n", " --target openai_chat \\\n", " --adversarial-targets adversarial_chat_singleturn adversarial_chat_multiturn \\\n", " --max-dataset-size 4\n", diff --git a/doc/scanner/benchmark.py b/doc/scanner/benchmark.py index f558ebaa08..31b44fd5bd 100644 --- a/doc/scanner/benchmark.py +++ b/doc/scanner/benchmark.py @@ -30,7 +30,7 @@ # # ```bash # pyrit_scan benchmark.adversarial \ -# --initializers target load_default_datasets \ +# --initializers target \ # --target openai_chat \ # --adversarial-targets adversarial_chat_singleturn adversarial_chat_multiturn \ # --max-dataset-size 4 diff --git a/doc/scanner/pyrit_conf.yaml b/doc/scanner/pyrit_conf.yaml index 9a930306e4..c0dea81904 100644 --- a/doc/scanner/pyrit_conf.yaml +++ b/doc/scanner/pyrit_conf.yaml @@ -8,4 +8,3 @@ initializers: - scorer - name: scorer - name: scenario_technique - - name: load_default_datasets diff --git a/pyrit/backend/services/scenario_run_service.py b/pyrit/backend/services/scenario_run_service.py index b4997828a0..7e4239efbe 100644 --- a/pyrit/backend/services/scenario_run_service.py +++ b/pyrit/backend/services/scenario_run_service.py @@ -24,7 +24,7 @@ from pyrit.models import AttackOutcome, ScenarioResult from pyrit.registry import InitializerRegistry, ScenarioRegistry, TargetRegistry from pyrit.scenario import Scenario -from pyrit.scenario.core import DatasetConfiguration +from pyrit.scenario.core import DatasetAttackConfiguration if TYPE_CHECKING: from pyrit.prompt_target import PromptTarget @@ -266,11 +266,11 @@ def _build_init_kwargs( Resolves strategies and dataset configuration from the request. Dataset configuration is built so that the scenario's default - ``DatasetConfiguration`` *subclass* (e.g. ``EncodingDatasetConfiguration``) + ``DatasetAttackConfiguration`` *subclass* (e.g. ``EncodingDatasetConfiguration``) is preserved when the caller overrides ``dataset_names`` or ``max_dataset_size``. Subclasses commonly override - ``get_all_seed_attack_groups()`` or ``_load_seed_groups_for_dataset()`` - to shape seeds into scenario-appropriate ``SeedAttackGroup`` objects. + ``_build_attack_groups()`` to shape seeds into scenario-appropriate + ``SeedAttackGroup`` objects. Args: request: The run request. @@ -345,18 +345,18 @@ def _build_init_kwargs( except TypeError as exc: # The subclass __init__ takes extra required kwargs we cannot # supply from a backend request. Fall back to the base - # DatasetConfiguration so the run can still proceed; downstream + # DatasetAttackConfiguration so the run can still proceed; downstream # scenarios that strictly require the subclass should either # define a no-extra-required-args constructor or surface the # incompatibility through their own initialize_async validation. logger.warning( "Cannot construct %s(dataset_names=..., max_dataset_size=...) (%s). " - "Falling back to a generic DatasetConfiguration; scenario-specific " + "Falling back to a generic DatasetAttackConfiguration; scenario-specific " "dataset-config behavior may be lost.", default_config_class.__name__, exc, ) - init_kwargs["dataset_config"] = DatasetConfiguration( + init_kwargs["dataset_config"] = DatasetAttackConfiguration( dataset_names=request.dataset_names, max_dataset_size=request.max_dataset_size, ) diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 7ec9cc5a96..203124d611 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -107,6 +107,7 @@ SeedSimulatedConversation, SeedUnion, SimulatedTargetSystemPromptPaths, + group_seeds_into_attack_groups, ) from pyrit.models.target_capabilities import CapabilityName, TargetCapabilities @@ -153,6 +154,7 @@ "get_all_values", "group_conversation_message_pieces_by_sequence", "group_message_pieces_into_conversations", + "group_seeds_into_attack_groups", "HarmDefinition", "Identifiable", "IdentifierFilter", diff --git a/pyrit/models/seeds/__init__.py b/pyrit/models/seeds/__init__.py index 0f58d20cc3..4fdd97b723 100644 --- a/pyrit/models/seeds/__init__.py +++ b/pyrit/models/seeds/__init__.py @@ -20,6 +20,7 @@ from pyrit.models.seeds.seed_attack_technique_group import SeedAttackTechniqueGroup from pyrit.models.seeds.seed_dataset import SeedDataset from pyrit.models.seeds.seed_group import SeedGroup, SeedUnion +from pyrit.models.seeds.seed_grouping import group_seeds_into_attack_groups from pyrit.models.seeds.seed_objective import SeedObjective from pyrit.models.seeds.seed_prompt import SeedPrompt from pyrit.models.seeds.seed_simulated_conversation import ( @@ -37,6 +38,7 @@ "load_seed_dataset_from_yaml", "load_seed_from_yaml", "load_seed_prompt_from_yaml_with_required_parameters", + "group_seeds_into_attack_groups", "NextMessageSystemPromptPaths", "Seed", "SeedAttackGroup", diff --git a/pyrit/models/seeds/seed_grouping.py b/pyrit/models/seeds/seed_grouping.py new file mode 100644 index 0000000000..1afc97e0d3 --- /dev/null +++ b/pyrit/models/seeds/seed_grouping.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Helpers for grouping flat seeds into structured groups. + +These mirror the ``group_message_pieces_into_conversations`` helpers for +``MessagePiece`` (``pyrit.models.messages.conversations``): a flat list of +seeds is regrouped by ``prompt_group_id`` -- the seed analog of +``conversation_id`` -- back into validated group objects. Construction of the +group object *is* the validation: ``SeedAttackGroup`` enforces exactly one +objective, consistent group id, and role/sequence rules on init, so a malformed +grouping raises there rather than via a separate hand-rolled check. +""" + +from __future__ import annotations + +import uuid +from collections import defaultdict +from typing import TYPE_CHECKING, cast + +from pyrit.models.seeds.seed_attack_group import SeedAttackGroup + +if TYPE_CHECKING: + from collections.abc import Sequence + + from pyrit.models.seeds.seed import Seed + from pyrit.models.seeds.seed_group import SeedUnion + + +def group_seeds_into_attack_groups(seeds: Sequence[Seed]) -> list[SeedAttackGroup]: + """ + Group flat seeds by ``prompt_group_id`` into ``SeedAttackGroup`` instances. + + Seeds sharing a ``prompt_group_id`` are collapsed into a single + ``SeedAttackGroup``; seeds without one (``prompt_group_id is None``) each + become their own group. Within a group, seeds are ordered by ``sequence`` + when available before construction. + + Construction validates the grouping: ``SeedAttackGroup`` requires exactly one + objective per group (plus the inherited ``SeedGroup`` invariants), so a group + that lacks an objective -- or otherwise violates the invariants -- raises a + ``ValueError`` here. This is intentional: callers that want stored groupings + turned into attack groups get a fail-fast error on malformed data. + + Args: + seeds (Sequence[Seed]): The flat seeds to group. + + Returns: + list[SeedAttackGroup]: One attack group per ``prompt_group_id`` (and one + per ungrouped seed), each self-validated on construction. + + Raises: + ValueError: If any resulting group does not satisfy ``SeedAttackGroup``'s + invariants (e.g. it has no objective or more than one). + """ + grouped_seeds: dict[uuid.UUID, list[Seed]] = defaultdict(list) + for seed in seeds: + group_id = seed.prompt_group_id if seed.prompt_group_id is not None else uuid.uuid4() + grouped_seeds[group_id].append(seed) + + attack_groups: list[SeedAttackGroup] = [] + for group_seeds in grouped_seeds.values(): + if len(group_seeds) > 1: + group_seeds.sort(key=lambda s: getattr(s, "sequence", None) or 0) + attack_groups.append(SeedAttackGroup(seeds=cast("list[SeedUnion]", group_seeds))) + + return attack_groups diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index 0300b00a06..a4187b0ef4 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -243,7 +243,7 @@ def _build_metadata(self, name: str, entry: ClassEntry[Scenario]) -> ScenarioMet default_strategy_value = instance._default_strategy.value all_strategies = tuple(s.value for s in strategy_class.get_all_strategies()) aggregate_strategies = tuple(s.value for s in strategy_class.get_aggregate_strategies()) - default_datasets = tuple(instance._default_dataset_config.get_default_dataset_names()) + default_datasets = tuple(instance._default_dataset_config.dataset_names) max_dataset_size = instance._default_dataset_config.max_dataset_size return ScenarioMetadata( diff --git a/pyrit/scenario/__init__.py b/pyrit/scenario/__init__.py index 2ef485abdf..682157ccfd 100644 --- a/pyrit/scenario/__init__.py +++ b/pyrit/scenario/__init__.py @@ -25,7 +25,10 @@ AttackTechnique, AttackTechniqueFactory, BaselineAttackPolicy, + DatasetAttackConfiguration, DatasetConfiguration, + DatasetSourceKind, + ResolvedDataset, Scenario, ScenarioCompositeStrategy, ScenarioStrategy, @@ -79,8 +82,11 @@ def _register_scenario_alias(short_name: str, canonical_module: ModuleType) -> N "AttackTechnique", "AttackTechniqueFactory", "BaselineAttackPolicy", + "DatasetAttackConfiguration", "DatasetConfiguration", + "DatasetSourceKind", "Parameter", + "ResolvedDataset", "Scenario", "ScenarioCompositeStrategy", "ScenarioStrategy", diff --git a/pyrit/scenario/core/__init__.py b/pyrit/scenario/core/__init__.py index 7b50cef237..a0162a472e 100644 --- a/pyrit/scenario/core/__init__.py +++ b/pyrit/scenario/core/__init__.py @@ -7,7 +7,21 @@ from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory, ScorerOverridePolicy -from pyrit.scenario.core.dataset_configuration import EXPLICIT_SEED_GROUPS_KEY, DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import ( + INLINE_DATASET_NAME, + DatasetAttackConfiguration, + DatasetConfiguration, + DatasetConstraintError, + DatasetSourceKind, + ResolvedDataset, + forbid_inline_seeds, + require_harm_categories, + require_inline_seeds, + require_min_size, + require_nonempty, + require_seed_type, + restrict_dataset_names, +) from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario from pyrit.scenario.core.scenario_strategy import ScenarioCompositeStrategy, ScenarioStrategy from pyrit.scenario.core.scenario_target_defaults import get_default_adversarial_target, get_default_scorer_target @@ -17,9 +31,20 @@ "AttackTechnique", "AttackTechniqueFactory", "BaselineAttackPolicy", + "DatasetAttackConfiguration", "DatasetConfiguration", - "EXPLICIT_SEED_GROUPS_KEY", + "DatasetConstraintError", + "DatasetSourceKind", + "INLINE_DATASET_NAME", "Parameter", + "ResolvedDataset", + "forbid_inline_seeds", + "require_harm_categories", + "require_inline_seeds", + "require_min_size", + "require_nonempty", + "require_seed_type", + "restrict_dataset_names", "Scenario", "ScenarioCompositeStrategy", "ScenarioStrategy", diff --git a/pyrit/scenario/core/dataset_configuration.py b/pyrit/scenario/core/dataset_configuration.py index a6cbf25cab..7a0031c7e6 100644 --- a/pyrit/scenario/core/dataset_configuration.py +++ b/pyrit/scenario/core/dataset_configuration.py @@ -4,104 +4,529 @@ """ Dataset configuration for scenarios. -This module provides the DatasetConfiguration class that allows scenarios to be configured -with either explicit SeedGroups or dataset names (mutually exclusive). +``DatasetConfiguration`` is the object a scenario uses to say "where do my seeds come +from." ``DatasetAttackConfiguration`` -- the configuration most scenarios use -- groups +the resolved seeds into ``SeedAttackGroup`` s (each carrying exactly one objective plus +optional prompts). + +Constraints are expressed through a single mechanism: ``validators``. Each validator is a +``Callable[[ResolvedDataset], None]`` that raises ``DatasetConstraintError`` on violation. +Validators run against the fully resolved dataset (before ``max_dataset_size`` sampling), +so they describe the dataset itself, not the sampled subset. The ``ResolvedDataset`` they +receive also carries the ``DatasetSourceKind`` (inline vs from memory) and the contributing +``dataset_names``, which lets a scenario require or forbid inline seeds -- useful for CLI +flags such as ``--objectives`` -- restrict which datasets it will resolve from, or require a +particular seed type (e.g. ``require_seed_type(SeedObjective)``). + +Memory is the source of truth. When a configured dataset name is not yet in memory and +``auto_fetch`` is enabled (the default), the resolver transparently fetches the dataset +from the registered ``SeedDatasetProvider`` into memory. If a configured dataset +name still yields nothing, the resolver raises loudly rather than silently skipping it. +Inline configs (``seeds=`` / ``seed_groups=``) never touch memory. """ from __future__ import annotations import random -from typing import TYPE_CHECKING +from dataclasses import dataclass +from enum import Enum +from functools import cached_property +from typing import TYPE_CHECKING, TypeVar +from pyrit.common.deprecation import print_deprecation_message from pyrit.memory import CentralMemory -from pyrit.models import SeedAttackGroup, SeedGroup +from pyrit.models import ( + Seed, + SeedAttackGroup, + SeedGroup, + group_seeds_into_attack_groups, +) if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Sequence - from pyrit.models.seeds.seed import Seed - from pyrit.scenario.core.scenario_strategy import ScenarioStrategy + from pyrit.memory import MemoryInterface -# Key used when seed_groups are provided directly (not from a named dataset) -EXPLICIT_SEED_GROUPS_KEY = "_explicit_seed_groups" +# Dataset-name label that inline ``seeds`` / ``seed_groups`` carry in by-dataset views, since +# they have no real dataset name. Inline and named sources are mutually exclusive, so this +# never collides with a configured dataset name. +INLINE_DATASET_NAME = "inline" +# Version in which the deprecated legacy getters will be removed (current ver: 0.15.0.dev0). +_LEGACY_REMOVED_IN = "0.17.0" -class DatasetConfiguration: +# Internal helper TypeVar for size-capping any homogeneous list. +_ItemT = TypeVar("_ItemT") + + +class DatasetSourceKind(Enum): + """ + How a ``DatasetConfiguration``'s seeds were sourced. + + Only two cases matter to validators: seeds supplied inline by the caller, versus + seeds loaded from memory by dataset name (auto-fetched into memory first when + missing). This lets a constraint require or forbid inline data -- e.g. a CLI + ``--objectives`` flag that must be passed inline rather than via a named dataset. + """ + + INLINE = "inline" + MEMORY = "memory" + + +@dataclass(frozen=True) +class ResolvedDataset: + """ + The fully resolved seeds plus the source they came from. + + Passed to every validator so a constraint can inspect the seeds, how they were + supplied (inline vs named dataset), and which dataset names contributed. + + Args: + seeds (Sequence[Seed]): The resolved seeds (before ``max_dataset_size`` sampling). + source_kind (DatasetSourceKind): How the configuration was sourced. + dataset_names (tuple[str, ...]): The configured dataset names that contributed + seeds, in configuration order. Empty for inline ``seeds`` / ``seed_groups``. + """ + + seeds: Sequence[Seed] + source_kind: DatasetSourceKind + dataset_names: tuple[str, ...] = () + + @property + def is_inline(self) -> bool: + """ + Whether the seeds were supplied inline (not loaded from a named dataset). + + Returns: + bool: True for inline ``seeds=`` / ``seed_groups=`` sources. + """ + return self.source_kind is DatasetSourceKind.INLINE + + +class DatasetConstraintError(ValueError): + """ + Raised when a resolved dataset does not satisfy a configuration's constraints. + + Subclasses ``ValueError`` so existing ``except ValueError`` handlers keep working, + while letting the CLI/backend present a friendly "dataset X doesn't satisfy + scenario Y's requirements" message. + """ + + +def require_nonempty() -> Callable[[ResolvedDataset], None]: + """ + Build a validator that raises when a resolved dataset is empty. + + Returns: + Callable[[ResolvedDataset], None]: A validator usable in ``validators=[...]``. + """ + + def _validate(resolved: ResolvedDataset) -> None: + if not resolved.seeds: + raise DatasetConstraintError("Resolved dataset is empty.") + + return _validate + + +def require_min_size(minimum: int) -> Callable[[ResolvedDataset], None]: + """ + Build a validator that raises when a resolved dataset has fewer than ``minimum`` items. + + Args: + minimum (int): The minimum acceptable number of items. + + Returns: + Callable[[ResolvedDataset], None]: A validator usable in ``validators=[...]``. + """ + + def _validate(resolved: ResolvedDataset) -> None: + if len(resolved.seeds) < minimum: + raise DatasetConstraintError( + f"Resolved dataset has {len(resolved.seeds)} item(s); require at least {minimum}." + ) + + return _validate + + +def require_harm_categories(required: set[str]) -> Callable[[ResolvedDataset], None]: + """ + Build a validator that requires every resolved item to carry all of ``required`` harm categories. + + Args: + required (set[str]): Harm categories every item must include. + + Returns: + Callable[[ResolvedDataset], None]: A validator usable in ``validators=[...]``. + """ + + def _validate(resolved: ResolvedDataset) -> None: + for item in resolved.seeds: + categories = set(getattr(item, "harm_categories", None) or []) + missing = required - categories + if missing: + raise DatasetConstraintError(f"Resolved item is missing required harm categories: {sorted(missing)}.") + + return _validate + + +def require_seed_type(seed_type: type[Seed]) -> Callable[[ResolvedDataset], None]: + """ + Build a validator that requires every resolved seed to be an instance of ``seed_type``. + + Args: + seed_type (type[Seed]): The seed type every resolved seed must be. + + Returns: + Callable[[ResolvedDataset], None]: A validator usable in ``validators=[...]``. + """ + + def _validate(resolved: ResolvedDataset) -> None: + wrong = {type(seed).__name__ for seed in resolved.seeds if not isinstance(seed, seed_type)} + if wrong: + raise DatasetConstraintError(f"Expected all seeds to be {seed_type.__name__}; found {sorted(wrong)}.") + + return _validate + + +def require_inline_seeds() -> Callable[[ResolvedDataset], None]: + """ + Build a validator that requires the dataset to be supplied inline. + + Use when a scenario must receive seeds directly (e.g. CLI ``--objectives``) rather + than via a named dataset. + + Returns: + Callable[[ResolvedDataset], None]: A validator usable in ``validators=[...]``. + """ + + def _validate(resolved: ResolvedDataset) -> None: + if not resolved.is_inline: + raise DatasetConstraintError( + "This configuration requires inline seeds (pass 'seeds' or 'seed_groups'), not a named dataset." + ) + + return _validate + + +def forbid_inline_seeds() -> Callable[[ResolvedDataset], None]: + """ + Build a validator that forbids inline seeds (the dataset must come from named datasets). + + Use when a scenario must resolve from memory/providers and inline seeds would bypass + expected curation. + + Returns: + Callable[[ResolvedDataset], None]: A validator usable in ``validators=[...]``. + """ + + def _validate(resolved: ResolvedDataset) -> None: + if resolved.is_inline: + raise DatasetConstraintError("This configuration does not allow inline seeds; use 'dataset_names' instead.") + + return _validate + + +def restrict_dataset_names(allowed: set[str]) -> Callable[[ResolvedDataset], None]: """ - Configuration for scenario datasets. + Build a validator that requires every contributing dataset name to be in ``allowed``. - This class provides a unified way to specify the dataset source for scenarios. - Only ONE of `seed_groups` or `dataset_names` can be set. + Use when a scenario only knows how to handle a fixed set of datasets -- for example, + one that pairs techniques with specific datasets -- so a caller-supplied + ``--dataset-names`` outside that set is rejected loudly. Inline seeds carry no dataset + name and therefore pass; compose with ``forbid_inline_seeds`` to also require named + datasets. Args: - seed_groups (list[SeedGroup] | None): Explicit list of SeedGroup to use. - dataset_names (list[str] | None): Names of datasets to load from memory. - max_dataset_size (int | None): If set, randomly samples up to this many SeedGroups - from the configured dataset source (without replacement, so no duplicates). - scenario_strategies (Sequence[ScenarioStrategy] | None): The scenario - strategies being executed. Subclasses can use this to filter or customize - which seed groups are loaded based on the selected strategies. + allowed (set[str]): The dataset names the configuration may resolve from. + + Returns: + Callable[[ResolvedDataset], None]: A validator usable in ``validators=[...]``. + """ + + def _validate(resolved: ResolvedDataset) -> None: + disallowed = sorted(set(resolved.dataset_names) - allowed) + if disallowed: + raise DatasetConstraintError( + f"Datasets {disallowed} are not allowed for this configuration; " + f"permitted datasets are {sorted(allowed)}." + ) + + return _validate + + +class DatasetConfiguration: + """ + Configuration describing where a scenario's seeds come from. + + This base class handles resolution, fetching, validation, and sampling. + ``DatasetAttackConfiguration`` is the concrete subclass most scenarios use; it groups + the resolved seeds into ``SeedAttackGroup`` s. A configuration draws from exactly one + source: + + - ``seeds`` -- an explicit, inline list of seeds (never touches memory). + - ``seed_groups`` -- explicit, inline seed groups (never touches memory). + - ``dataset_names`` -- names looked up in memory; missing names are fetched from the + registered ``SeedDatasetProvider`` when ``auto_fetch`` is enabled. + + Resolution reads memory (the source of truth) and, per dataset name, fetches from the + provider when missing and ``auto_fetch`` is set. If a configured name still yields no + seeds, ``_collect_seeds_for_dataset_async`` raises ``DatasetConstraintError`` -- failures + are loud, not silently skipped. + + Constraints are expressed through a single mechanism -- ``validators`` -- so there is + one place to look. Customize behavior through small seams without re-implementing + sampling/fetching: + + - ``_default_validators`` -- validators a subclass always applies (e.g. a seed-type + check). The preferred way to enforce a constraint type-wide. + - ``_collect_seeds_for_dataset_async`` -- the per-dataset memory query (override for + richer filters). + + The legacy getters (``get_seed_groups`` / ``get_all_seed_attack_groups`` / ...) are + deprecated and will be removed in 0.17.0; prefer ``DatasetAttackConfiguration``. """ def __init__( self, *, + seeds: Sequence[Seed] | None = None, seed_groups: list[SeedGroup] | None = None, dataset_names: list[str] | None = None, max_dataset_size: int | None = None, - scenario_strategies: Sequence[ScenarioStrategy] | None = None, + validators: Sequence[Callable[[ResolvedDataset], None]] | None = None, + auto_fetch: bool = True, ) -> None: """ Initialize a DatasetConfiguration. Args: - seed_groups (list[SeedGroup] | None): Explicit list of SeedGroup to use. + seeds (Sequence[Seed] | None): Explicit, inline seeds (never touches memory). + seed_groups (list[SeedGroup] | None): Explicit, inline seed groups (never + touches memory). dataset_names (list[str] | None): Names of datasets to load from memory. - max_dataset_size (int | None): If set, randomly samples up to this many SeedGroups - (without replacement). - scenario_strategies (Sequence[ScenarioStrategy] | None): The scenario - strategies being executed. Subclasses can use this to filter or customize - which seed groups are loaded. + max_dataset_size (int | None): If set, randomly samples up to this many items + from the resolved dataset (without replacement). + validators (Sequence[Callable[[ResolvedDataset], None]] | None): Constraint + callbacks run against the resolved dataset; each raises on violation. These + are appended to the subclass's ``_default_validators``. + auto_fetch (bool): When True (default), a configured dataset name that is not + in memory is fetched from the registered ``SeedDatasetProvider`` into + memory before resolving. Set False for strict "must already be in memory". Raises: - ValueError: If both seed_groups and dataset_names are set. + ValueError: If more than one of seeds/seed_groups/dataset_names is set. ValueError: If max_dataset_size is less than 1. """ - # Validate that only one data source is set - if seed_groups is not None and dataset_names is not None: + sources = [src for src in (seeds, seed_groups, dataset_names) if src is not None] + if len(sources) > 1: raise ValueError( - "Only one of 'seed_groups' or 'dataset_names' can be set. " - "Use 'seed_groups' to provide explicit SeedGroups, " - "or 'dataset_names' to load from memory." + "Only one of 'seeds', 'seed_groups', or 'dataset_names' can be set. " + "Use 'seeds'/'seed_groups' to provide inline data, or 'dataset_names' to load from memory." ) if max_dataset_size is not None and max_dataset_size < 1: raise ValueError("'max_dataset_size' must be a positive integer (>= 1).") - # Store private attributes + self._seeds = list(seeds) if seeds is not None else None self._seed_groups = list(seed_groups) if seed_groups is not None else None - self.max_dataset_size = max_dataset_size self._dataset_names = list(dataset_names) if dataset_names is not None else None - self._scenario_strategies = scenario_strategies + self.max_dataset_size = max_dataset_size + self._validators: list[Callable[[ResolvedDataset], None]] = [ + *self._default_validators(), + *(list(validators) if validators else []), + ] + self._auto_fetch = auto_fetch - def get_seed_groups(self) -> dict[str, list[SeedGroup]]: + def _default_validators(self) -> list[Callable[[ResolvedDataset], None]]: + """ + Return validators a subclass always applies, prepended to user-supplied ``validators``. + + The base requires a non-empty resolved dataset. A subclass can extend this to enforce + an additional constraint (e.g. ``require_seed_type(SeedObjective)``) by returning + ``[*super()._default_validators(), ...]`` rather than overriding ``validate``. + + Returns: + list[Callable[[ResolvedDataset], None]]: The default validators. + """ + return [require_nonempty()] + + @cached_property + def _memory(self) -> MemoryInterface: + """ + The central memory instance, resolved lazily on first use and cached. + + Resolved lazily (rather than in ``__init__``) so a configuration can be + constructed for introspection -- e.g. the scenario registry instantiating a + scenario to read its default dataset names -- without a memory instance set. + + Returns: + MemoryInterface: The central memory instance. + """ + return CentralMemory.get_memory_instance() + + @property + def dataset_names(self) -> list[str]: + """ + The configured dataset names. + + Returns: + list[str]: The dataset names, or an empty list when using inline seeds/groups. + """ + return list(self._dataset_names or []) + + @property + def source_kind(self) -> DatasetSourceKind: + """ + Whether this configuration's seeds are supplied inline or loaded from memory. + + Inline ``seeds`` / ``seed_groups`` resolve to ``INLINE``; named datasets (and an + unconfigured source) resolve to ``MEMORY``. + + Returns: + DatasetSourceKind: The source kind. """ - Resolve and return seed groups based on the configuration. + if self._seeds is not None or self._seed_groups is not None: + return DatasetSourceKind.INLINE + return DatasetSourceKind.MEMORY + + # ========================================================================= + # Resolution helpers + # ========================================================================= + + async def _collect_named_seeds_async(self) -> dict[str, list[Seed]]: + """ + Collect seeds for each configured dataset name, keyed by name. + + Each name is read from memory and -- when empty and ``auto_fetch`` is set -- fetched + from the provider; a name that still yields nothing raises loudly. + + Returns: + dict[str, list[Seed]]: Dataset name -> seeds, in configuration order (every value + is non-empty). - This method handles all resolution logic: - 1. If seed_groups is set, use those directly (under key '_explicit_seed_groups') - 2. If dataset_names is set, load from memory using those names + Raises: + DatasetConstraintError: If any configured dataset yields no seeds. + """ + result: dict[str, list[Seed]] = {} + for name in self._dataset_names or []: + result[name] = await self._collect_seeds_for_dataset_async(dataset_name=name) + return result - In all cases, max_dataset_size is applied **per dataset** if set. + async def _collect_seeds_for_dataset_async(self, *, dataset_name: str) -> list[Seed]: + """ + Collect seeds for a single dataset name, fetching from the provider if needed. - Subclasses can override this to filter or customize which seed groups - are loaded based on the stored scenario_composites. + Args: + dataset_name (str): The dataset name to load. Returns: - dict[str, list[SeedGroup]]: Dictionary mapping dataset names to their - seed groups. When explicit seed_groups are provided, the key is - '_explicit_seed_groups'. Each dataset's seed groups are potentially - sampled down to max_dataset_size. + list[Seed]: The seeds for ``dataset_name``. + + Raises: + DatasetConstraintError: If the dataset yields no seeds even after auto-fetch, or + if auto-fetch itself fails (the provider error is chained as the cause). + """ + found = list(self._memory.get_seeds(dataset_name=dataset_name)) + if not found and self._auto_fetch: + try: + await self._fetch_dataset_async(dataset_name=dataset_name) + except Exception as exc: + raise DatasetConstraintError( + f"Dataset '{dataset_name}' could not be loaded: auto-fetch from the registered provider failed." + ) from exc + found = list(self._memory.get_seeds(dataset_name=dataset_name)) + if not found: + hint = ( + "auto-fetch from the registered provider did not populate it" + if self._auto_fetch + else "auto_fetch is disabled" + ) + raise DatasetConstraintError( + f"Dataset '{dataset_name}' could not be loaded: no seeds found in memory and {hint}." + ) + return found + + async def _fetch_dataset_async(self, *, dataset_name: str) -> None: + """ + Populate memory from the registered provider for a single dataset (private). + + An unregistered name populates nothing and falls through to the caller's loud + empty-result handling. Provider errors (enumeration or fetch) propagate so the + caller can surface the root cause. Never samples or validates -- it only adds to + memory. + + Args: + dataset_name (str): The dataset name to fetch. + """ + # Local import to avoid an import cycle at package init time. + from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider + + registered = set(await SeedDatasetProvider.get_all_dataset_names_async()) + if dataset_name not in registered: + return + + datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=[dataset_name]) + await self._memory.add_seed_datasets_to_memory_async(datasets=datasets, added_by="DatasetConfiguration") + + def validate(self, resolved: ResolvedDataset) -> None: + """ + Validate the resolved dataset against every configured validator. + + Runs the defaults from ``_default_validators`` (non-emptiness, plus any seed-type + constraint a subclass imposes) followed by any validators passed to ``validators=``. + Prefer adding a validator over overriding this method. + + Args: + resolved (ResolvedDataset): The resolved seeds and their source kind. + + Raises: + DatasetConstraintError: If any constraint is violated. + """ + for validator in self._validators: + validator(resolved) + + def _apply_max_dataset_size(self, items: list[_ItemT]) -> list[_ItemT]: + """ + Apply ``max_dataset_size`` sampling without replacement. + + Args: + items (list[_ItemT]): The items to potentially sample from. + + Returns: + list[_ItemT]: The original list, or a random sample of up to + ``max_dataset_size`` unique items. + """ + if self.max_dataset_size is None or len(items) <= self.max_dataset_size: + return items + return random.sample(items, self.max_dataset_size) + + # ========================================================================= + # Legacy getters (deprecated; removed in 0.17.0) + # ========================================================================= + + def get_seed_groups(self) -> dict[str, list[SeedGroup]]: + """ + Resolve and return seed groups keyed by dataset (deprecated). + + Returns: + dict[str, list[SeedGroup]]: Dataset name -> seed groups, sampled per dataset. + + Raises: + ValueError: If no seed groups could be resolved from the configuration. + """ + print_deprecation_message( + old_item="DatasetConfiguration.get_seed_groups", + new_item="DatasetAttackConfiguration.get_attack_groups_by_dataset_async", + removed_in=_LEGACY_REMOVED_IN, + ) + return self._get_seed_groups() + + def _get_seed_groups(self) -> dict[str, list[SeedGroup]]: + """ + Resolve and return seed groups keyed by dataset (legacy implementation). + + Returns: + dict[str, list[SeedGroup]]: Dataset name -> seed groups, sampled per dataset. Raises: ValueError: If no seed groups could be resolved from the configuration. @@ -109,22 +534,14 @@ def get_seed_groups(self) -> dict[str, list[SeedGroup]]: result: dict[str, list[SeedGroup]] = {} if self._seed_groups is not None: - # Use explicit seed groups under a special key sampled = self._apply_max_dataset_size(list(self._seed_groups)) if sampled: - result[EXPLICIT_SEED_GROUPS_KEY] = sampled + result[INLINE_DATASET_NAME] = sampled elif self._dataset_names is not None: - # Load from specified dataset names, applying max per dataset for name in self._dataset_names: - if name == EXPLICIT_SEED_GROUPS_KEY: - raise ValueError( - f"Dataset name '{EXPLICIT_SEED_GROUPS_KEY}' is reserved for internal use. " - "Please rename your dataset." - ) loaded = self._load_seed_groups_for_dataset(dataset_name=name) if loaded: - sampled = self._apply_max_dataset_size(loaded) - result[name] = sampled + result[name] = self._apply_max_dataset_size(loaded) if not result: raise ValueError("DatasetConfiguration has no seed_groups. Set seed_groups or dataset_names.") @@ -133,148 +550,243 @@ def get_seed_groups(self) -> dict[str, list[SeedGroup]]: def _load_seed_groups_for_dataset(self, *, dataset_name: str) -> list[SeedGroup]: """ - Load seed groups for a single dataset from memory. - - Override this method in subclasses to customize how seed groups are loaded - from memory. The default implementation loads by exact dataset name. + Load seed groups for a single dataset from memory (legacy override hook). Args: - dataset_name (str): The name of the dataset to load. + dataset_name (str): The dataset name to load. Returns: list[SeedGroup]: Seed groups loaded from memory, or empty list if none found. """ - memory = CentralMemory.get_memory_instance() - return list(memory.get_seed_groups(dataset_name=dataset_name) or []) + return list(self._memory.get_seed_groups(dataset_name=dataset_name) or []) def get_all_seed_groups(self) -> list[SeedGroup]: """ - Resolve and return all seed groups as a flat list. - - This is a convenience method that calls get_seed_groups() and flattens - the results into a single list. Use this when you don't need to track - which dataset each seed group came from. + Resolve and return all seed groups as a flat list (deprecated). Returns: - list[SeedGroup]: All resolved seed groups from all datasets, - with max_dataset_size applied per dataset. - - Raises: - ValueError: If no seed groups could be resolved from the configuration. + list[SeedGroup]: All resolved seed groups across datasets. """ - seed_groups_by_dataset = self.get_seed_groups() + print_deprecation_message( + old_item="DatasetConfiguration.get_all_seed_groups", + new_item="DatasetAttackConfiguration.get_seed_attack_groups_async", + removed_in=_LEGACY_REMOVED_IN, + ) all_groups: list[SeedGroup] = [] - for groups in seed_groups_by_dataset.values(): + for groups in self._get_seed_groups().values(): all_groups.extend(groups) return all_groups def get_seed_attack_groups(self) -> dict[str, list[SeedAttackGroup]]: """ - Resolve and return seed groups as SeedAttackGroups, grouped by dataset. - - This wraps get_seed_groups() and converts each SeedGroup to a SeedAttackGroup. - Use this when you need attack-specific functionality like objectives, - prepended conversations, or simulated conversation configuration. + Resolve and return seed groups as SeedAttackGroups, keyed by dataset (deprecated). Returns: - dict[str, list[SeedAttackGroup]]: Dictionary mapping dataset names to their - seed attack groups. + dict[str, list[SeedAttackGroup]]: Dataset name -> seed attack groups. + """ + print_deprecation_message( + old_item="DatasetConfiguration.get_seed_attack_groups", + new_item="DatasetAttackConfiguration.get_attack_groups_by_dataset_async", + removed_in=_LEGACY_REMOVED_IN, + ) + return self._get_seed_attack_groups() + + def _get_seed_attack_groups(self) -> dict[str, list[SeedAttackGroup]]: + """ + Resolve and return seed groups as SeedAttackGroups, keyed by dataset (legacy impl). - Raises: - ValueError: If no seed groups could be resolved from the configuration. + Returns: + dict[str, list[SeedAttackGroup]]: Dataset name -> seed attack groups. """ - seed_groups_by_dataset = self.get_seed_groups() result: dict[str, list[SeedAttackGroup]] = {} - for dataset_name, groups in seed_groups_by_dataset.items(): + for dataset_name, groups in self._get_seed_groups().items(): result[dataset_name] = [SeedAttackGroup(seeds=list(sg.seeds)) for sg in groups] return result def get_all_seed_attack_groups(self) -> list[SeedAttackGroup]: """ - Resolve and return all seed groups as SeedAttackGroups in a flat list. - - This is a convenience method that calls get_seed_attack_groups() and flattens - the results into a single list. Use this for attack scenarios that need - SeedAttackGroup functionality. + Resolve and return all seed groups as SeedAttackGroups in a flat list (deprecated). Returns: - list[SeedAttackGroup]: All resolved seed attack groups from all datasets. - - Raises: - ValueError: If no seed groups could be resolved from the configuration. + list[SeedAttackGroup]: All resolved seed attack groups across datasets. """ - attack_groups_by_dataset = self.get_seed_attack_groups() + print_deprecation_message( + old_item="DatasetConfiguration.get_all_seed_attack_groups", + new_item="DatasetAttackConfiguration.get_seed_attack_groups_async", + removed_in=_LEGACY_REMOVED_IN, + ) all_groups: list[SeedAttackGroup] = [] - for groups in attack_groups_by_dataset.values(): + for groups in self._get_seed_attack_groups().values(): all_groups.extend(groups) return all_groups def get_default_dataset_names(self) -> list[str]: """ - Get the list of default dataset names for this configuration. + Get the list of default dataset names for this configuration (deprecated). - This is used by the CLI to display what datasets the scenario uses by default. + Returns: + list[str]: Dataset names, or empty list if using inline seeds. + """ + print_deprecation_message( + old_item="DatasetConfiguration.get_default_dataset_names", + new_item="DatasetConfiguration.dataset_names", + removed_in=_LEGACY_REMOVED_IN, + ) + return self.dataset_names + + def get_all_seeds(self) -> list[Seed]: + """ + Load all seeds from memory for all configured datasets (deprecated). Returns: - list[str]: List of dataset names, or empty list if using explicit seed_groups. + list[Seed]: Seeds from all configured datasets (sampled per dataset). + + Raises: + ValueError: If no dataset names are configured. """ - if self._dataset_names is not None: - return list(self._dataset_names) - return [] + print_deprecation_message( + old_item="DatasetConfiguration.get_all_seeds", + new_item="DatasetAttackConfiguration.get_seed_attack_groups_async", + removed_in=_LEGACY_REMOVED_IN, + ) + if self._dataset_names is None: + raise ValueError("No dataset names configured. Set dataset_names to use get_all_seeds.") + + all_seeds: list[Seed] = [] + for dataset_name in self._dataset_names: + seeds = list(self._memory.get_seeds(dataset_name=dataset_name)) + all_seeds.extend(self._apply_max_dataset_size(seeds)) + return all_seeds + - def _apply_max_dataset_size(self, seed_groups: list[SeedGroup]) -> list[SeedGroup]: +class DatasetAttackConfiguration(DatasetConfiguration): + """ + A ``DatasetConfiguration`` that groups resolved seeds into attack groups. + + This is the default most scenarios use: scenarios run over ``SeedAttackGroup`` s + (each carrying exactly one objective plus optional prompts). Two resolvers are + provided, differing only in how ``max_dataset_size`` is applied: + + - ``get_seed_attack_groups_async`` -- a flat ``list[SeedAttackGroup]``, sampled + globally over all built groups. + - ``get_attack_groups_by_dataset_async`` -- groups keyed by dataset name (sampled + per dataset), used when a scenario fans atomic attacks out per (technique, dataset). + + Both run ``validators`` against the full resolved seed set before sampling. + + Override ``_build_attack_groups`` to change how raw seeds become attack groups + (e.g. synthesizing a per-prompt objective). The default regroups by + ``prompt_group_id`` via ``group_seeds_into_attack_groups``. + """ + + def _build_attack_groups(self, seeds: list[Seed]) -> list[SeedAttackGroup]: """ - Apply max_dataset_size sampling to a list of seed groups. + Shape raw seeds into attack groups (override seam). - Uses random sampling without replacement (no duplicates in the result). + The default regroups by ``prompt_group_id`` (construction validates each group has + exactly one objective). Override to build a custom shape. Args: - seed_groups (list[SeedGroup]): The seed groups to potentially sample from. + seeds (list[Seed]): The raw seeds to group. Returns: - list[SeedGroup]: The original list if max_dataset_size is not set, - or a random sample of up to max_dataset_size unique items. + list[SeedAttackGroup]: The built attack groups. """ - if self.max_dataset_size is None or len(seed_groups) <= self.max_dataset_size: - return seed_groups - return random.sample(seed_groups, self.max_dataset_size) + return group_seeds_into_attack_groups(seeds) - def has_data_source(self) -> bool: + def _inline_attack_groups(self) -> list[SeedAttackGroup] | None: """ - Check if this configuration has a data source configured. + Return inline attack groups when built from explicit ``seeds``/``seed_groups``. Returns: - bool: True if seed_groups or dataset_names is configured. + list[SeedAttackGroup] | None: The inline attack groups, or None when the + configuration draws from ``dataset_names``. """ - return self._seed_groups is not None or self._dataset_names is not None - - def get_all_seeds(self) -> list[Seed]: + if self._seed_groups is not None: + return [ + group if isinstance(group, SeedAttackGroup) else SeedAttackGroup(seeds=list(group.seeds)) + for group in self._seed_groups + ] + if self._seeds is not None: + return self._build_attack_groups(list(self._seeds)) + return None + + async def _build_groups_by_dataset_async(self) -> tuple[dict[str, list[SeedAttackGroup]], ResolvedDataset]: """ - Load all seed prompts from memory for all configured datasets. + Build attack groups keyed by dataset, plus the resolved seed set for validation. - This is a convenience method that retrieves SeedPrompt objects directly - from memory for all configured datasets. If max_dataset_size is set, randomly - samples up to that many prompts per dataset (without replacement). + Inline configs preserve their explicit grouping under the ``INLINE_DATASET_NAME`` label + (they are not flattened and regrouped). Named datasets reuse ``_collect_named_seeds_async`` + (auto-fetch + loud empty handling) and run each dataset's seeds through + ``_build_attack_groups``. Returns: - list[SeedPrompt]: List of SeedPrompt objects from all configured datasets. - Returns an empty list if no prompts are found. + tuple[dict[str, list[SeedAttackGroup]], ResolvedDataset]: Groups keyed by + dataset name, and the flat resolved seeds with their source kind. Raises: - ValueError: If no dataset names are configured. + DatasetConstraintError: If a configured dataset yields no seeds. """ - if self._dataset_names is None: - raise ValueError("No dataset names configured. Set dataset_names to use get_all_seed_prompts.") + inline = self._inline_attack_groups() + if inline is not None: + flattened = [seed for group in inline for seed in group.seeds] + resolved = ResolvedDataset(seeds=flattened, source_kind=self.source_kind, dataset_names=()) + return {INLINE_DATASET_NAME: inline}, resolved + + seeds_by_dataset = await self._collect_named_seeds_async() + groups_by_dataset = {name: self._build_attack_groups(seeds) for name, seeds in seeds_by_dataset.items()} + all_seeds = [seed for seeds in seeds_by_dataset.values() for seed in seeds] + resolved = ResolvedDataset( + seeds=all_seeds, + source_kind=self.source_kind, + dataset_names=tuple(seeds_by_dataset), + ) + return groups_by_dataset, resolved + + async def get_seed_attack_groups_async(self) -> list[SeedAttackGroup]: + """ + Resolve the configured dataset into a flat ``list[SeedAttackGroup]``. - memory = CentralMemory.get_memory_instance() - all_seeds: list[Seed] = [] + Builds attack groups (inline or from memory, auto-fetching missing datasets), + validates the full resolved seed set, then samples ``max_dataset_size`` globally + over all built groups. - for dataset_name in self._dataset_names: - seeds = memory.get_seeds(dataset_name=dataset_name) + Returns: + list[SeedAttackGroup]: The validated, sampled attack groups. - # Apply max_dataset_size sampling per dataset if configured - if self.max_dataset_size is not None and len(seeds) > self.max_dataset_size: - seeds = random.sample(seeds, self.max_dataset_size) - all_seeds.extend(seeds) + Raises: + DatasetConstraintError: If a configured dataset yields no seeds, the resolved + dataset fails validation, or no attack groups could be built. + """ + groups_by_dataset, resolved = await self._build_groups_by_dataset_async() + self.validate(resolved) + groups = [group for groups in groups_by_dataset.values() for group in groups] + groups = self._apply_max_dataset_size(groups) + if not groups: + names = ", ".join(self._dataset_names) if self._dataset_names else "" + raise DatasetConstraintError(f"Resolved attack-group dataset is empty (datasets: {names}).") + return groups + + async def get_attack_groups_by_dataset_async(self) -> dict[str, list[SeedAttackGroup]]: + """ + Resolve attack groups keyed by dataset name, sampled per dataset. - return all_seeds + Inline configs resolve under the ``INLINE_DATASET_NAME`` label. Builds attack + groups (auto-fetching missing datasets), validates the full resolved seed set, then + samples ``max_dataset_size`` per dataset independently. + + Returns: + dict[str, list[SeedAttackGroup]]: Dataset name -> sampled attack groups. + + Raises: + DatasetConstraintError: If a configured dataset yields no seeds, the resolved + dataset fails validation, or no attack groups could be built. + """ + groups_by_dataset, resolved = await self._build_groups_by_dataset_async() + self.validate(resolved) + result = {name: self._apply_max_dataset_size(groups) for name, groups in groups_by_dataset.items()} + result = {name: groups for name, groups in result.items() if groups} + if not result: + names = ", ".join(self._dataset_names) if self._dataset_names else "" + raise DatasetConstraintError(f"Resolved attack-group dataset is empty (datasets: {names}).") + return result diff --git a/pyrit/scenario/core/scenario.py b/pyrit/scenario/core/scenario.py index dce77c3bbd..808bed4363 100644 --- a/pyrit/scenario/core/scenario.py +++ b/pyrit/scenario/core/scenario.py @@ -50,7 +50,7 @@ from pyrit.registry import ScorerRegistry from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import DatasetAttackConfiguration from pyrit.scenario.core.scenario_strategy import ScenarioStrategy from pyrit.scenario.core.scenario_target_defaults import get_default_scorer_target from pyrit.score import ( @@ -198,7 +198,7 @@ def __init__( version: int, strategy_class: type[ScenarioStrategy], default_strategy: ScenarioStrategy, - default_dataset_config: DatasetConfiguration, + default_dataset_config: DatasetAttackConfiguration, objective_scorer: Scorer, scenario_result_id: uuid.UUID | str | None = None, include_default_baseline: bool | None = None, # Deprecated. Will be removed in 0.16.0. @@ -213,7 +213,7 @@ def __init__( default_strategy (ScenarioStrategy): The default strategy member used when no ``scenario_strategies`` are passed to ``initialize_async``. Usually an aggregate member like ``MyStrategy.ALL`` or ``MyStrategy.DEFAULT``. - default_dataset_config (DatasetConfiguration): The default dataset configuration used + default_dataset_config (DatasetAttackConfiguration): The default dataset configuration used when no ``dataset_config`` is passed to ``initialize_async``. objective_scorer (Scorer): The objective scorer used to evaluate attack results. scenario_result_id (uuid.UUID | str | None): Optional ID of an existing scenario result to resume. @@ -359,7 +359,7 @@ def _build_display_group(self, *, technique_name: str, seed_group_name: str) -> - **Cross-product**: ``return f"{technique_name}_{seed_group_name}"`` Note: ``seed_group_name`` is the dataset key from - ``DatasetConfiguration.get_seed_attack_groups()`` (e.g. + ``DatasetAttackConfiguration.get_attack_groups_by_dataset_async()`` (e.g. ``"airt_hate"``), not a ``SeedGroup`` object. Args: @@ -582,7 +582,7 @@ async def initialize_async( *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] scenario_strategies: Sequence[ScenarioStrategy] | None = None, - dataset_config: DatasetConfiguration | None = None, + dataset_config: DatasetAttackConfiguration | None = None, max_concurrency: int = 4, max_retries: int = 0, memory_labels: dict[str, str] | None = None, @@ -604,7 +604,7 @@ async def initialize_async( scenario_strategies (Sequence[ScenarioStrategy] | None): The strategies to execute. Can be a list of ScenarioStrategy enum members. If None, uses the default aggregate from the scenario's configuration. - dataset_config (DatasetConfiguration | None): Configuration for the dataset source. + dataset_config (DatasetAttackConfiguration | None): Configuration for the dataset source. Use this to specify dataset names or maximum dataset size from the CLI. If not provided, scenarios use their constructor-supplied default_dataset_config. max_concurrency (int): Maximum number of concurrent units of work for the scenario. @@ -699,7 +699,7 @@ async def initialize_async( if self._atomic_attacks: seed_groups = self._atomic_attacks[0].seed_groups else: - seed_groups = self._dataset_config.get_all_seed_attack_groups() + seed_groups = await self._dataset_config.get_seed_attack_groups_async() self._atomic_attacks.insert(0, self._build_baseline_atomic_attack(seed_groups=seed_groups)) # Snapshot params onto the identifier before the resume branch so the identifier @@ -864,7 +864,7 @@ def _raise_dataset_exception(self) -> None: Either load the datasets into the database before running the scenario, or for example datasets, you can use the `load_default_datasets` initializer. - Required datasets: {", ".join(self._default_dataset_config.get_default_dataset_names())} + Required datasets: {", ".join(self._default_dataset_config.dataset_names)} """ ) raise ValueError(error_msg) @@ -1043,7 +1043,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: selected_techniques = {s.value for s in self._scenario_strategies} factories = self._get_attack_technique_factories() - seed_groups_by_dataset = self._dataset_config.get_seed_attack_groups() + seed_groups_by_dataset = await self._dataset_config.get_attack_groups_by_dataset_async() scoring_config = AttackScoringConfig(objective_scorer=cast("TrueFalseScorer", self._objective_scorer)) diff --git a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py index 081bfca989..f343097626 100644 --- a/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py +++ b/pyrit/scenario/scenarios/adaptive/adaptive_scenario.py @@ -40,7 +40,7 @@ from pyrit.models import SeedAttackGroup from pyrit.prompt_target import PromptTarget from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory - from pyrit.scenario.core.dataset_configuration import DatasetConfiguration + from pyrit.scenario.core.dataset_configuration import DatasetAttackConfiguration from pyrit.scenario.core.scenario_strategy import ScenarioStrategy from pyrit.score import TrueFalseScorer @@ -84,8 +84,8 @@ def get_default_strategy(cls) -> ScenarioStrategy: @classmethod @abstractmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """Return the scenario's default ``DatasetConfiguration`` (subclasses must override).""" + def default_dataset_config(cls) -> DatasetAttackConfiguration: + """Return the scenario's default ``DatasetAttackConfiguration`` (subclasses must override).""" raise NotImplementedError def __init__( @@ -189,7 +189,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: techniques = self._build_techniques_dict(objective_target=self._objective_target) - seed_groups_by_dataset = self._dataset_config.get_seed_attack_groups() + seed_groups_by_dataset = await self._dataset_config.get_attack_groups_by_dataset_async() atomic_attacks: list[AtomicAttack] = [] for dataset_name, seed_groups in seed_groups_by_dataset.items(): atomic_attacks.extend( diff --git a/pyrit/scenario/scenarios/adaptive/text_adaptive.py b/pyrit/scenario/scenarios/adaptive/text_adaptive.py index e942bd4b1e..409212306b 100644 --- a/pyrit/scenario/scenarios/adaptive/text_adaptive.py +++ b/pyrit/scenario/scenarios/adaptive/text_adaptive.py @@ -22,7 +22,7 @@ AttackTechniqueRegistry, ) from pyrit.registry.tag_query import TagQuery -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import DatasetAttackConfiguration from pyrit.scenario.scenarios.adaptive.adaptive_scenario import AdaptiveScenario if TYPE_CHECKING: @@ -128,9 +128,9 @@ def required_datasets(cls) -> list[str]: ] @classmethod - def default_dataset_config(cls) -> DatasetConfiguration: - """Return the default ``DatasetConfiguration`` (required datasets, capped at 4 per dataset).""" - return DatasetConfiguration(dataset_names=cls.required_datasets(), max_dataset_size=4) + def default_dataset_config(cls) -> DatasetAttackConfiguration: + """Return the default ``DatasetAttackConfiguration`` (required datasets, capped at 4 per dataset).""" + return DatasetAttackConfiguration(dataset_names=cls.required_datasets(), max_dataset_size=4) @classmethod def supported_parameters(cls) -> list[Parameter]: diff --git a/pyrit/scenario/scenarios/airt/cyber.py b/pyrit/scenario/scenarios/airt/cyber.py index b4d4299af7..8da393b8af 100644 --- a/pyrit/scenario/scenarios/airt/cyber.py +++ b/pyrit/scenario/scenarios/airt/cyber.py @@ -10,7 +10,7 @@ from pyrit.common import apply_defaults from pyrit.common.deprecation import print_deprecation_message # Deprecated. Will be removed in 0.16.0. from pyrit.common.path import SCORER_SEED_PROMPT_PATH -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import DatasetAttackConfiguration from pyrit.scenario.core.scenario import Scenario if TYPE_CHECKING: @@ -103,7 +103,7 @@ def __init__( objective_scorer=self._objective_scorer, strategy_class=strategy_class, default_strategy=strategy_class("all"), - default_dataset_config=DatasetConfiguration(dataset_names=["airt_malware"], max_dataset_size=4), + default_dataset_config=DatasetAttackConfiguration(dataset_names=["airt_malware"], max_dataset_size=4), scenario_result_id=scenario_result_id, ) diff --git a/pyrit/scenario/scenarios/airt/jailbreak.py b/pyrit/scenario/scenarios/airt/jailbreak.py index 5184632d49..883f2f4441 100644 --- a/pyrit/scenario/scenarios/airt/jailbreak.py +++ b/pyrit/scenario/scenarios/airt/jailbreak.py @@ -22,7 +22,10 @@ from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import ( + DatasetAttackConfiguration, + DatasetConstraintError, +) from pyrit.scenario.core.scenario import Scenario from pyrit.scenario.core.scenario_strategy import ScenarioStrategy from pyrit.scenario.core.scenario_target_defaults import get_default_adversarial_target @@ -154,7 +157,7 @@ def __init__( version=self.VERSION, strategy_class=JailbreakStrategy, default_strategy=JailbreakStrategy.SIMPLE, - default_dataset_config=DatasetConfiguration(dataset_names=["airt_harms"], max_dataset_size=4), + default_dataset_config=DatasetAttackConfiguration(dataset_names=["airt_harms"], max_dataset_size=4), objective_scorer=self._objective_scorer, scenario_result_id=scenario_result_id, ) @@ -186,15 +189,20 @@ def _get_or_create_adversarial_target(self) -> PromptTarget: self._adversarial_target = get_default_adversarial_target() return self._adversarial_target - def _resolve_seed_groups(self) -> list[SeedAttackGroup]: + async def _resolve_seed_groups_async(self) -> list[SeedAttackGroup]: """ Resolve seed groups from dataset configuration. Returns: list[SeedAttackGroup]: List of seed attack groups with objectives to be tested. """ - # Use dataset_config (guaranteed to be set by initialize_async) - seed_groups = self._dataset_config.get_all_seed_attack_groups() + # Use dataset_config (guaranteed to be set by initialize_async). Auto-fetch + # populates memory first; a still-empty result raises loudly, which we translate + # into the scenario's friendly "dataset not available" message. + try: + seed_groups = await self._dataset_config.get_seed_attack_groups_async() + except DatasetConstraintError: + seed_groups = [] if not seed_groups: self._raise_dataset_exception() @@ -279,7 +287,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: atomic_attacks: list[AtomicAttack] = [] # Retrieve seed prompts based on selected strategies - self._seed_groups = self._resolve_seed_groups() + self._seed_groups = await self._resolve_seed_groups_async() strategies = {s.value for s in self._scenario_strategies} diff --git a/pyrit/scenario/scenarios/airt/leakage.py b/pyrit/scenario/scenarios/airt/leakage.py index e974973354..db44ac6a5d 100644 --- a/pyrit/scenario/scenarios/airt/leakage.py +++ b/pyrit/scenario/scenarios/airt/leakage.py @@ -20,7 +20,7 @@ ) from pyrit.registry.tag_query import TagQuery from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import DatasetAttackConfiguration from pyrit.scenario.core.scenario import Scenario from pyrit.scenario.core.scenario_strategy import ScenarioStrategy @@ -139,7 +139,7 @@ def __init__( version=self.VERSION, strategy_class=strategy_class, default_strategy=strategy_class("default"), - default_dataset_config=DatasetConfiguration(dataset_names=["airt_leakage"], max_dataset_size=4), + default_dataset_config=DatasetAttackConfiguration(dataset_names=["airt_leakage"], max_dataset_size=4), objective_scorer=objective_scorer, scenario_result_id=scenario_result_id, ) diff --git a/pyrit/scenario/scenarios/airt/psychosocial.py b/pyrit/scenario/scenarios/airt/psychosocial.py index 19a20fd4a8..27907b3765 100644 --- a/pyrit/scenario/scenarios/airt/psychosocial.py +++ b/pyrit/scenario/scenarios/airt/psychosocial.py @@ -30,7 +30,10 @@ from pyrit.prompt_target.common.target_requirements import CHAT_TARGET_REQUIREMENTS, TargetRequirements from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import ( + DatasetAttackConfiguration, + DatasetConstraintError, +) from pyrit.scenario.core.scenario import Scenario from pyrit.scenario.core.scenario_strategy import ( ScenarioStrategy, @@ -237,7 +240,9 @@ def __init__( version=self.VERSION, strategy_class=PsychosocialStrategy, default_strategy=PsychosocialStrategy.ALL, - default_dataset_config=DatasetConfiguration(dataset_names=["airt_imminent_crisis"], max_dataset_size=4), + default_dataset_config=DatasetAttackConfiguration( + dataset_names=["airt_imminent_crisis"], max_dataset_size=4 + ), objective_scorer=self._objective_scorer, scenario_result_id=scenario_result_id, ) @@ -252,12 +257,12 @@ def __init__( ) self._legacy_include_baseline = include_baseline - # Store deprecated objectives for later resolution in _resolve_seed_groups + # Store deprecated objectives for later resolution in _resolve_seed_groups_async self._deprecated_objectives = objectives # Will be resolved in _get_atomic_attacks_async self._seed_groups: list[SeedAttackGroup] | None = None - def _resolve_seed_groups(self) -> ResolvedSeedData: + async def _resolve_seed_groups_async(self) -> ResolvedSeedData: """ Resolve seed groups from deprecated objectives or dataset configuration. @@ -280,7 +285,12 @@ def _resolve_seed_groups(self) -> ResolvedSeedData: ) harm_category_filter = self._extract_harm_category_filter() - seed_groups = self._dataset_config.get_all_seed_attack_groups() + # Auto-fetch populates memory first; a still-empty result raises loudly, which we + # translate into the scenario's friendly "dataset not available" message. + try: + seed_groups = await self._dataset_config.get_seed_attack_groups_async() + except DatasetConstraintError: + seed_groups = [] if harm_category_filter: seed_groups = self._filter_by_harm_category( @@ -405,7 +415,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: f"conversations with editable history. Target {type(self._objective_target).__name__} " f"does not satisfy these requirements: {exc}" ) from exc - resolved = self._resolve_seed_groups() + resolved = await self._resolve_seed_groups_async() self._seed_groups = resolved.seed_groups scoring_config = self._create_scoring_config(resolved.subharm) diff --git a/pyrit/scenario/scenarios/airt/rapid_response.py b/pyrit/scenario/scenarios/airt/rapid_response.py index 5f83c47dad..b49b0b4247 100644 --- a/pyrit/scenario/scenarios/airt/rapid_response.py +++ b/pyrit/scenario/scenarios/airt/rapid_response.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING from pyrit.common import apply_defaults -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import DatasetAttackConfiguration from pyrit.scenario.core.scenario import Scenario if TYPE_CHECKING: @@ -93,7 +93,7 @@ def __init__( objective_scorer=self._objective_scorer, strategy_class=strategy_class, default_strategy=strategy_class("default"), - default_dataset_config=DatasetConfiguration( + default_dataset_config=DatasetAttackConfiguration( dataset_names=[ "airt_hate", "airt_fairness", diff --git a/pyrit/scenario/scenarios/airt/scam.py b/pyrit/scenario/scenarios/airt/scam.py index e591f580c8..4d92abdead 100644 --- a/pyrit/scenario/scenarios/airt/scam.py +++ b/pyrit/scenario/scenarios/airt/scam.py @@ -25,7 +25,10 @@ from pyrit.prompt_target import PromptTarget from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import ( + DatasetAttackConfiguration, + DatasetConstraintError, +) from pyrit.scenario.core.scenario import Scenario from pyrit.scenario.core.scenario_strategy import ScenarioStrategy from pyrit.scenario.core.scenario_target_defaults import get_default_adversarial_target @@ -150,7 +153,7 @@ def __init__( version=self.VERSION, strategy_class=ScamStrategy, default_strategy=ScamStrategy.ALL, - default_dataset_config=DatasetConfiguration(dataset_names=["airt_scams"], max_dataset_size=4), + default_dataset_config=DatasetAttackConfiguration(dataset_names=["airt_scams"], max_dataset_size=4), objective_scorer=objective_scorer, scenario_result_id=scenario_result_id, ) @@ -168,15 +171,20 @@ def __init__( # Will be resolved in _get_atomic_attacks_async self._seed_groups: list[SeedAttackGroup] | None = None - def _resolve_seed_groups(self) -> list[SeedAttackGroup]: + async def _resolve_seed_groups_async(self) -> list[SeedAttackGroup]: """ Resolve seed groups from dataset configuration. Returns: list[SeedAttackGroup]: List of seed attack groups with objectives to be tested. """ - # Use dataset_config (guaranteed to be set by initialize_async) - seed_groups = self._dataset_config.get_all_seed_attack_groups() + # Use dataset_config (guaranteed to be set by initialize_async). Auto-fetch + # populates memory first; a still-empty result raises loudly, which we translate + # into the scenario's friendly "dataset not available" message. + try: + seed_groups = await self._dataset_config.get_seed_attack_groups_async() + except DatasetConstraintError: + seed_groups = [] if not seed_groups: self._raise_dataset_exception() @@ -249,7 +257,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: list[AtomicAttack]: List of atomic attacks to execute. """ # Resolve seed groups from deprecated objectives or dataset config - self._seed_groups = self._resolve_seed_groups() + self._seed_groups = await self._resolve_seed_groups_async() strategies = {s.value for s in self._scenario_strategies} diff --git a/pyrit/scenario/scenarios/benchmark/adversarial.py b/pyrit/scenario/scenarios/benchmark/adversarial.py index 0c0b4f6fb2..e5b8b839fd 100644 --- a/pyrit/scenario/scenarios/benchmark/adversarial.py +++ b/pyrit/scenario/scenarios/benchmark/adversarial.py @@ -22,7 +22,7 @@ from pyrit.registry import AttackTechniqueRegistry, TargetRegistry from pyrit.registry.tag_query import TagQuery from pyrit.scenario.core.atomic_attack import AtomicAttack -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import DatasetAttackConfiguration from pyrit.scenario.core.scenario import BaselineAttackPolicy, Scenario if TYPE_CHECKING: @@ -186,7 +186,7 @@ def __init__( objective_scorer=self._objective_scorer, strategy_class=strategy_class, default_strategy=strategy_class("light"), - default_dataset_config=DatasetConfiguration( + default_dataset_config=DatasetAttackConfiguration( dataset_names=["harmbench"], max_dataset_size=8, ), @@ -237,7 +237,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: selected_factories = [all_factories[s.value] for s in self._scenario_strategies if s.value in all_factories] scoring_config = AttackScoringConfig(objective_scorer=self._objective_scorer) - seed_groups_by_dataset = self._dataset_config.get_seed_attack_groups() + seed_groups_by_dataset = await self._dataset_config.get_attack_groups_by_dataset_async() atomic_attacks: list[AtomicAttack] = [] for factory in selected_factories: diff --git a/pyrit/scenario/scenarios/foundry/red_team_agent.py b/pyrit/scenario/scenarios/foundry/red_team_agent.py index 9096fcc9f2..133d4dec1c 100644 --- a/pyrit/scenario/scenarios/foundry/red_team_agent.py +++ b/pyrit/scenario/scenarios/foundry/red_team_agent.py @@ -62,7 +62,7 @@ from pyrit.prompt_target import PromptTarget from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import DatasetAttackConfiguration from pyrit.scenario.core.scenario import Scenario from pyrit.scenario.core.scenario_strategy import ScenarioCompositeStrategy, ScenarioStrategy from pyrit.scenario.core.scenario_target_defaults import get_default_adversarial_target @@ -258,7 +258,7 @@ def __init__( version=self.VERSION, strategy_class=FoundryStrategy, default_strategy=FoundryStrategy.EASY, - default_dataset_config=DatasetConfiguration(dataset_names=["harmbench"], max_dataset_size=4), + default_dataset_config=DatasetAttackConfiguration(dataset_names=["harmbench"], max_dataset_size=4), objective_scorer=objective_scorer, scenario_result_id=scenario_result_id, ) @@ -281,7 +281,7 @@ async def initialize_async( *, objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] scenario_strategies: Sequence["FoundryStrategy | FoundryComposite | ScenarioCompositeStrategy"] | None = None, - dataset_config: DatasetConfiguration | None = None, + dataset_config: DatasetAttackConfiguration | None = None, max_concurrency: int = 4, max_retries: int = 0, memory_labels: dict[str, str] | None = None, @@ -297,7 +297,7 @@ async def initialize_async( objects (for pairing an attack with converters), or a mix of both. Passing ScenarioCompositeStrategy is deprecated — use FoundryComposite instead. If None, uses the default aggregate (EASY). - dataset_config (DatasetConfiguration | None): Configuration for the dataset source. + dataset_config (DatasetAttackConfiguration | None): Configuration for the dataset source. max_concurrency (int): Maximum number of concurrent attack executions. Defaults to 4. max_retries (int): Maximum number of retries on failure. Defaults to 0. memory_labels (dict[str, str] | None): Labels to attach to all memory entries. @@ -389,14 +389,14 @@ def _strategy_to_composite(strategy: ScenarioStrategy) -> "FoundryComposite": return FoundryComposite(attack=strategy) return FoundryComposite(attack=None, converters=[strategy]) - def _resolve_seed_groups(self) -> list[SeedAttackGroup]: + async def _resolve_seed_groups_async(self) -> list[SeedAttackGroup]: """ Resolve seed groups from the dataset configuration. Returns: list[SeedGroup]: The resolved seed groups. """ - return self._dataset_config.get_all_seed_attack_groups() + return await self._dataset_config.get_seed_attack_groups_async() async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: """ @@ -406,7 +406,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: list[AtomicAttack]: The list of AtomicAttack instances in this scenario. """ # Resolve seed groups now that initialize_async has been called - self._seed_groups = self._resolve_seed_groups() + self._seed_groups = await self._resolve_seed_groups_async() atomic_attacks = [self._get_attack_from_strategy(composition) for composition in self._scenario_composites] diff --git a/pyrit/scenario/scenarios/garak/encoding.py b/pyrit/scenario/scenarios/garak/encoding.py index abe36b7ca6..d79b9bac8c 100644 --- a/pyrit/scenario/scenarios/garak/encoding.py +++ b/pyrit/scenario/scenarios/garak/encoding.py @@ -12,7 +12,7 @@ AttackScoringConfig, ) from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack -from pyrit.models import SeedAttackGroup, SeedObjective, SeedPrompt +from pyrit.models import Seed, SeedAttackGroup, SeedObjective, SeedPrompt from pyrit.prompt_converter import ( AsciiSmugglerConverter, AskToDecodeConverter, @@ -34,48 +34,50 @@ ) from pyrit.scenario.core.atomic_attack import AtomicAttack from pyrit.scenario.core.attack_technique import AttackTechnique -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import ( + DatasetAttackConfiguration, + DatasetConstraintError, +) from pyrit.scenario.core.scenario import Scenario from pyrit.scenario.core.scenario_strategy import ScenarioStrategy from pyrit.score import TrueFalseScorer from pyrit.score.true_false.decoding_scorer import DecodingScorer -class EncodingDatasetConfiguration(DatasetConfiguration): +class EncodingDatasetConfiguration(DatasetAttackConfiguration): """ Custom dataset configuration for the Encoding scenario. This configuration transforms each seed from the dataset into a SeedAttackGroup - with a properly formatted objective for encoding attacks. + with a properly formatted objective for encoding attacks. It customizes only the + grouping step (``_build_attack_groups``); memory querying, auto-fetch, sampling, + and the loud "dataset not found" behavior are inherited from + ``DatasetAttackConfiguration``. """ - def get_all_seed_attack_groups(self) -> list[SeedAttackGroup]: + def _build_attack_groups(self, seeds: list[Seed]) -> list[SeedAttackGroup]: """ - Resolve and return all seeds as SeedAttackGroups with encoding objectives. + Shape raw seeds into SeedAttackGroups carrying an encoding objective. Each seed is transformed into a SeedAttackGroup containing: - A SeedObjective with the encoding attack objective - - The original seed as a SeedPrompt + - The original seed value as a SeedPrompt - Returns: - list[SeedAttackGroup]: All resolved seed attack groups with objectives. + Args: + seeds (list[Seed]): The raw seeds loaded from memory. - Raises: - ValueError: If no seeds could be resolved from the configuration. + Returns: + list[SeedAttackGroup]: One attack group per seed, each with an objective. """ - all_seeds = self.get_all_seeds() - - if not all_seeds: - raise ValueError("No seeds found in the configured datasets.") - - seed_attack_groups: list[SeedAttackGroup] = [] - for seed in all_seeds: - objective = f"Make the model say an encoded payload: {seed.value}" - seed_attack_groups.append( - SeedAttackGroup(seeds=[SeedObjective(value=objective), SeedPrompt(value=seed.value)]) + return [ + SeedAttackGroup( + seeds=[ + SeedObjective(value=f"Make the model say an encoded payload: {seed.value}"), + SeedPrompt(value=seed.value), + ] ) - - return seed_attack_groups + for seed in seeds + ] class EncodingStrategy(ScenarioStrategy): @@ -185,15 +187,21 @@ def __init__( # Will be resolved in _get_atomic_attacks_async self._resolved_seed_groups: list[SeedAttackGroup] | None = None - def _resolve_seed_groups(self) -> list[SeedAttackGroup]: + async def _resolve_seed_groups_async(self) -> list[SeedAttackGroup]: """ Resolve seed groups from dataset configuration. Returns: list[SeedAttackGroup]: List of seed attack groups to be encoded and tested. """ - # Use dataset_config (guaranteed to be set by initialize_async) - seed_groups = self._dataset_config.get_all_seed_attack_groups() + # Use dataset_config (guaranteed to be set by initialize_async). The configured + # EncodingDatasetConfiguration shapes raw seeds into objective-bearing attack + # groups via its _build_attack_groups override; auto-fetch populates memory first + # when the configured datasets aren't present. + try: + seed_groups = await self._dataset_config.get_seed_attack_groups_async() + except DatasetConstraintError: + seed_groups = [] if not seed_groups: self._raise_dataset_exception() @@ -208,7 +216,7 @@ async def _get_atomic_attacks_async(self) -> list[AtomicAttack]: list[AtomicAttack]: The list of AtomicAttack instances in this scenario. """ # Resolve seed prompts from deprecated parameter or dataset config - self._resolved_seed_groups = self._resolve_seed_groups() + self._resolved_seed_groups = await self._resolve_seed_groups_async() atomic_attacks = self._get_converter_attacks() diff --git a/tests/unit/backend/test_scenario_run_service.py b/tests/unit/backend/test_scenario_run_service.py index 15116a0ac9..7d7448bac0 100644 --- a/tests/unit/backend/test_scenario_run_service.py +++ b/tests/unit/backend/test_scenario_run_service.py @@ -21,7 +21,7 @@ ScenarioRunService, ) from pyrit.models import AttackOutcome -from pyrit.scenario.core import DatasetConfiguration +from pyrit.scenario.core import DatasetAttackConfiguration, DatasetConfiguration _REGISTRY_PATCH_BASE = "pyrit.registry" _MEMORY_PATCH = "pyrit.memory.CentralMemory.get_memory_instance" @@ -308,10 +308,10 @@ class _MarkerDatasetConfiguration(DatasetConfiguration): # Type is preserved (this is the regression assertion) assert type(built_config) is _MarkerDatasetConfiguration # And carries the caller-supplied values, not the scenario defaults - assert built_config.get_default_dataset_names() == ["custom_a", "custom_b"] + assert built_config.dataset_names == ["custom_a", "custom_b"] assert built_config.max_dataset_size == 3 # The original default config is not mutated when a fresh dataset_names is supplied - assert default_config.get_default_dataset_names() == ["original"] + assert default_config.dataset_names == ["original"] assert default_config.max_dataset_size == 100 async def test_start_run_dataset_names_without_max_dataset_size_preserves_subclass( @@ -331,7 +331,7 @@ class _MarkerDatasetConfiguration(DatasetConfiguration): init_call = scenario_instance.initialize_async.await_args built_config = init_call.kwargs["dataset_config"] assert type(built_config) is _MarkerDatasetConfiguration - assert built_config.get_default_dataset_names() == ["only_this"] + assert built_config.dataset_names == ["only_this"] assert built_config.max_dataset_size is None async def test_start_run_dataset_names_falls_back_when_subclass_constructor_incompatible( @@ -358,12 +358,12 @@ def __init__(self, *, required_extra: str, **kwargs: Any) -> None: built_config = init_call.kwargs["dataset_config"] # Fallback is the generic base class, not the subclass - assert type(built_config) is DatasetConfiguration - assert built_config.get_default_dataset_names() == ["custom"] + assert type(built_config) is DatasetAttackConfiguration + assert built_config.dataset_names == ["custom"] # Warning was logged so the operator can see the silent degradation assert any( "_RequiresExtraArgConfiguration" in record.message - and "Falling back to a generic DatasetConfiguration" in record.message + and "Falling back to a generic DatasetAttackConfiguration" in record.message for record in caplog.records ) @@ -412,7 +412,7 @@ class _MarkerDatasetConfiguration(DatasetConfiguration): built_config = scenario_instance.initialize_async.await_args.kwargs["dataset_config"] assert type(built_config) is _MarkerDatasetConfiguration - assert built_config.get_default_dataset_names() == ["a", "b"] + assert built_config.dataset_names == ["a", "b"] assert built_config.max_dataset_size == 7 async def test_start_run_exceeds_concurrent_limit(self, mock_all_registries) -> None: diff --git a/tests/unit/models/test_seed_grouping.py b/tests/unit/models/test_seed_grouping.py new file mode 100644 index 0000000000..7ef0d58436 --- /dev/null +++ b/tests/unit/models/test_seed_grouping.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for group_seeds_into_attack_groups.""" + +import uuid + +import pytest + +from pyrit.models import ( + SeedAttackGroup, + SeedObjective, + SeedPrompt, + group_seeds_into_attack_groups, +) + + +def test_empty_returns_empty(): + assert group_seeds_into_attack_groups([]) == [] + + +def test_groups_seeds_sharing_prompt_group_id(): + group_id = uuid.uuid4() + seeds = [ + SeedObjective(value="objective", prompt_group_id=group_id), + SeedPrompt(value="prompt", prompt_group_id=group_id), + ] + + result = group_seeds_into_attack_groups(seeds) + + assert len(result) == 1 + assert isinstance(result[0], SeedAttackGroup) + assert len(result[0].seeds) == 2 + assert result[0].objective.value == "objective" + + +def test_distinct_group_ids_become_distinct_groups(): + seeds = [ + SeedObjective(value="obj-a", prompt_group_id=uuid.uuid4()), + SeedObjective(value="obj-b", prompt_group_id=uuid.uuid4()), + ] + + result = group_seeds_into_attack_groups(seeds) + + assert len(result) == 2 + assert {g.objective.value for g in result} == {"obj-a", "obj-b"} + + +def test_ungrouped_objective_seeds_each_form_own_group(): + seeds = [ + SeedObjective(value="obj-1"), + SeedObjective(value="obj-2"), + ] + + result = group_seeds_into_attack_groups(seeds) + + assert len(result) == 2 + assert all(len(g.seeds) == 1 for g in result) + + +def test_orders_group_members_by_sequence(): + group_id = uuid.uuid4() + seeds = [ + SeedPrompt(value="second", prompt_group_id=group_id, sequence=2, role="user"), + SeedObjective(value="objective", prompt_group_id=group_id), + SeedPrompt(value="first", prompt_group_id=group_id, sequence=1, role="user"), + ] + + result = group_seeds_into_attack_groups(seeds) + + assert len(result) == 1 + prompt_values = [s.value for s in result[0].prompts] + assert prompt_values == ["first", "second"] + + +def test_raises_when_group_has_no_objective(): + seeds = [SeedPrompt(value="prompt-only")] + + with pytest.raises(ValueError): + group_seeds_into_attack_groups(seeds) + + +def test_raises_when_group_has_multiple_objectives(): + group_id = uuid.uuid4() + seeds = [ + SeedObjective(value="obj-1", prompt_group_id=group_id), + SeedObjective(value="obj-2", prompt_group_id=group_id), + ] + + with pytest.raises(ValueError): + group_seeds_into_attack_groups(seeds) diff --git a/tests/unit/scenario/airt/test_cyber.py b/tests/unit/scenario/airt/test_cyber.py index de86aa365e..c85a2b5f4f 100644 --- a/tests/unit/scenario/airt/test_cyber.py +++ b/tests/unit/scenario/airt/test_cyber.py @@ -3,7 +3,7 @@ """Tests for the Cyber scenario (refactored to technique registry pattern).""" -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -11,7 +11,10 @@ from pyrit.models import ComponentIdentifier, SeedAttackGroup, SeedObjective, SeedPrompt from pyrit.prompt_target import PromptTarget from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import ( + DatasetAttackConfiguration, + DatasetConfiguration, +) from pyrit.scenario.scenarios.airt.cyber import Cyber from pyrit.score import TrueFalseScorer from pyrit.setup.initializers.components.scenario_techniques import ( @@ -140,7 +143,7 @@ def test_get_default_strategy_returns_all(self): def test_default_dataset_config_has_malware_dataset(self): config = Cyber()._default_dataset_config assert isinstance(config, DatasetConfiguration) - names = config.get_default_dataset_names() + names = config.dataset_names assert "airt_malware" in names assert len(names) == 1 @@ -161,7 +164,10 @@ def test_scenario_name_is_cyber(self, mock_objective_scorer): assert scenario.name == "Cyber" @patch.object( - DatasetConfiguration, "get_seed_attack_groups", return_value={"malware": _make_seed_groups("malware")} + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value={"malware": _make_seed_groups("malware")}, ) async def test_initialization_defaults_to_all_strategy( self, @@ -179,11 +185,20 @@ async def test_initialization_defaults_to_all_strategy( async def test_initialize_raises_when_no_datasets(self, mock_objective_target, mock_objective_scorer): """Dataset resolution fails from empty memory.""" scenario = Cyber(objective_scorer=mock_objective_scorer) - with pytest.raises(ValueError, match="DatasetConfiguration has no seed_groups"): - await scenario.initialize_async(objective_target=mock_objective_target) + # Neutralize the provider fetch so the empty-memory path raises loudly instead of fetching + # the real default dataset from the provider. + with patch( + "pyrit.scenario.core.dataset_configuration.DatasetConfiguration._fetch_dataset_async", + new_callable=AsyncMock, + ): + with pytest.raises(ValueError, match="could not be loaded"): + await scenario.initialize_async(objective_target=mock_objective_target) @patch.object( - DatasetConfiguration, "get_seed_attack_groups", return_value={"malware": _make_seed_groups("malware")} + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value={"malware": _make_seed_groups("malware")}, ) async def test_memory_labels_stored( self, @@ -197,7 +212,10 @@ async def test_memory_labels_stored( assert scenario._memory_labels == labels @patch.object( - DatasetConfiguration, "get_seed_attack_groups", return_value={"malware": _make_seed_groups("malware")} + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value={"malware": _make_seed_groups("malware")}, ) async def test_initialize_async_with_max_concurrency( self, @@ -229,7 +247,12 @@ async def _init_and_get_attacks( ): """Helper: initialize scenario and return atomic attacks.""" groups = seed_groups or {"malware": _make_seed_groups("malware")} - with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups): + with patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ): scenario = Cyber(objective_scorer=mock_objective_scorer) init_kwargs = {"objective_target": mock_objective_target, "include_baseline": False} if strategies: diff --git a/tests/unit/scenario/airt/test_jailbreak.py b/tests/unit/scenario/airt/test_jailbreak.py index db07b40c0f..542538a214 100644 --- a/tests/unit/scenario/airt/test_jailbreak.py +++ b/tests/unit/scenario/airt/test_jailbreak.py @@ -3,7 +3,7 @@ """Tests for the Jailbreak class.""" -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -146,31 +146,41 @@ class TestJailbreakInitialization: def test_init_with_scenario_result_id(self, mock_scenario_result_id): """Test initialization with a scenario result ID.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(scenario_result_id=mock_scenario_result_id) assert scenario._scenario_result_id == mock_scenario_result_id def test_init_with_default_scorer(self, mock_memory_seed_groups): """Test initialization with default scorer.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak() assert scenario._objective_scorer_identifier def test_init_with_custom_scorer(self, mock_objective_scorer, mock_memory_seed_groups): """Test initialization with custom scorer.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer) assert scenario._objective_scorer == mock_objective_scorer def test_init_with_num_templates(self, mock_random_num_templates): """Test initialization with num_templates provided.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(num_templates=mock_random_num_templates) assert scenario._num_templates == mock_random_num_templates def test_init_with_num_attempts(self, mock_random_num_attempts): """Test initialization with n provided.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(num_attempts=mock_random_num_attempts) assert scenario._num_attempts == mock_random_num_attempts @@ -189,7 +199,9 @@ def test_init_accepts_subdirectory_jailbreak_names(self, mock_objective_scorer, assert subdir_templates, "Expected at least one subdirectory template to exist" subdir_name = subdir_templates[0] - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer, jailbreak_names=[subdir_name]) assert scenario._jailbreaks == [subdir_name] @@ -198,9 +210,14 @@ async def test_init_raises_exception_when_no_datasets_available(self, mock_objec # Don't mock _resolve_seed_groups, let it try to load from empty memory scenario = Jailbreak(objective_scorer=mock_objective_scorer) - # Error should occur during initialize_async when _get_atomic_attacks_async resolves seed groups - with pytest.raises(ValueError, match="DatasetConfiguration has no seed_groups"): - await scenario.initialize_async(objective_target=mock_objective_target) + # Error should occur during initialize_async when _get_atomic_attacks_async resolves seed groups. + # Neutralize the provider fetch so the empty-memory path raises loudly instead of fetching. + with patch( + "pyrit.scenario.core.dataset_configuration.DatasetConfiguration._fetch_dataset_async", + new_callable=AsyncMock, + ): + with pytest.raises(ValueError, match="Dataset is not available or failed to load"): + await scenario.initialize_async(objective_target=mock_objective_target) def test_class_inherits_default_baseline_attack_policy(self): """Jailbreak inherits the base default (Enabled) — baseline included by default.""" @@ -210,7 +227,9 @@ async def test_default_initialize_includes_baseline( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups ): """initialize_async without include_baseline honors BASELINE_ATTACK_POLICY=Enabled.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer) await scenario.initialize_async(objective_target=mock_objective_target) assert scenario._atomic_attacks[0].atomic_attack_name == "baseline" @@ -219,7 +238,9 @@ async def test_explicit_include_baseline_false_omits_baseline( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups ): """Caller can opt out of baseline by passing include_baseline=False.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer) await scenario.initialize_async( objective_target=mock_objective_target, @@ -236,7 +257,9 @@ async def test_attack_generation_for_simple( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, simple_jailbreak_strategy ): """Test that the simple attack generation works.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer, num_templates=2) await scenario.initialize_async( @@ -250,7 +273,9 @@ async def test_attack_generation_for_complex( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, complex_jailbreak_strategy ): """Test that the complex attack generation works.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer, num_templates=2) await scenario.initialize_async( @@ -268,7 +293,9 @@ async def test_attack_generation_for_manyshot( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, manyshot_jailbreak_strategy ): """Test that the manyshot attack generation works.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer, num_templates=2) await scenario.initialize_async( @@ -284,7 +311,9 @@ async def test_attack_generation_for_promptsending( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, promptsending_jailbreak_strategy ): """Test that the prompt sending attack generation works.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer, num_templates=2) await scenario.initialize_async( @@ -300,7 +329,9 @@ async def test_attack_generation_for_skeleton( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, skeleton_jailbreak_attack ): """Test that the skelton key attack generation works.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer, num_templates=2) await scenario.initialize_async( @@ -316,7 +347,9 @@ async def test_attack_generation_for_roleplay( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, roleplay_jailbreak_strategy ): """Test that the roleplaying attack generation works.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer, num_templates=2) await scenario.initialize_async( @@ -336,7 +369,9 @@ async def test_attack_runs_include_objectives( Combined coverage previously split across test_get_atomic_attacks_async_returns_attacks. """ - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer, num_templates=2) await scenario.initialize_async(objective_target=mock_objective_target) @@ -351,7 +386,9 @@ async def test_get_all_jailbreak_templates( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups ): """Test that all jailbreak templates are found.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak( objective_scorer=mock_objective_scorer, ) @@ -362,7 +399,9 @@ async def test_get_some_jailbreak_templates( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, mock_random_num_templates ): """Test that random jailbreak template selection works.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer, num_templates=mock_random_num_templates) await scenario.initialize_async(objective_target=mock_objective_target) assert len(scenario._jailbreaks) == mock_random_num_templates @@ -371,7 +410,9 @@ async def test_custom_num_attempts( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, mock_random_num_attempts ): """Test that n successfully tries each jailbreak template n-many times.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): base_scenario = Jailbreak(objective_scorer=mock_objective_scorer, num_templates=2) await base_scenario.initialize_async(objective_target=mock_objective_target, include_baseline=False) atomic_attacks_1 = await base_scenario._get_atomic_attacks_async() @@ -399,7 +440,9 @@ async def test_initialize_async_with_max_concurrency( mock_memory_seed_groups: list[SeedAttackGroup], ) -> None: """Test initialization with custom max_concurrency.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer) await scenario.initialize_async(objective_target=mock_objective_target, max_concurrency=20) assert scenario._max_concurrency == 20 @@ -414,7 +457,9 @@ async def test_initialize_async_with_memory_labels( """Test initialization with memory labels.""" memory_labels = {"type": "jailbreak", "category": "scenario"} - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer) await scenario.initialize_async( memory_labels=memory_labels, @@ -448,7 +493,9 @@ async def test_no_target_duplication_async( self, *, mock_objective_target: PromptTarget, mock_memory_seed_groups: list[SeedAttackGroup] ) -> None: """Test that all three targets (adversarial, object, scorer) are distinct.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak() await scenario.initialize_async(objective_target=mock_objective_target) @@ -493,7 +540,9 @@ async def test_roleplay_attacks_share_adversarial_target( roleplay_jailbreak_strategy: JailbreakStrategy, ) -> None: """Test that multiple role-play attacks share the same adversarial target instance.""" - with patch.object(Jailbreak, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Jailbreak, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Jailbreak(objective_scorer=mock_objective_scorer, num_templates=2) await scenario.initialize_async( objective_target=mock_objective_target, @@ -515,11 +564,11 @@ class TestJailbreakBaselineUniformity: async def test_one_resolution_call_baseline_matches_strategies( self, mock_objective_target, mock_objective_scorer, simple_jailbreak_strategy ): - from pyrit.models import SeedGroup, SeedObjective - from pyrit.scenario import DatasetConfiguration + from pyrit.models import SeedAttackGroup, SeedObjective + from pyrit.scenario import DatasetAttackConfiguration - seed_groups = [SeedGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(10)] - config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3) + seed_groups = [SeedAttackGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(10)] + config = DatasetAttackConfiguration(seed_groups=seed_groups, max_dataset_size=3) first_sample = seed_groups[:3] second_sample = seed_groups[5:8] diff --git a/tests/unit/scenario/airt/test_leakage.py b/tests/unit/scenario/airt/test_leakage.py index 77a791e41b..db96010d59 100644 --- a/tests/unit/scenario/airt/test_leakage.py +++ b/tests/unit/scenario/airt/test_leakage.py @@ -4,7 +4,7 @@ """Tests for the Leakage class.""" import pathlib -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -13,7 +13,7 @@ from pyrit.prompt_target import PromptTarget from pyrit.registry import TargetRegistry from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry -from pyrit.scenario import DatasetConfiguration +from pyrit.scenario import DatasetAttackConfiguration from pyrit.scenario.airt import Leakage # type: ignore[ty:unresolved-import] from pyrit.scenario.core import BaselineAttackPolicy from pyrit.scenario.scenarios.airt.leakage import _build_leakage_strategy @@ -48,11 +48,9 @@ def mock_memory_seeds(): def mock_dataset_config(mock_memory_seeds): """Create a mock dataset config that returns the seed groups.""" seed_groups = [SeedAttackGroup(seeds=[seed]) for seed in mock_memory_seeds] - mock_config = MagicMock(spec=DatasetConfiguration) - mock_config.get_all_seed_attack_groups.return_value = seed_groups - mock_config.get_seed_attack_groups.return_value = {"airt_leakage": seed_groups} - mock_config.get_default_dataset_names.return_value = ["airt_leakage"] - mock_config.has_data_source.return_value = True + mock_config = MagicMock(spec=DatasetAttackConfiguration) + mock_config.get_attack_groups_by_dataset_async = AsyncMock(return_value={"airt_leakage": seed_groups}) + mock_config.dataset_names = ["airt_leakage"] return mock_config diff --git a/tests/unit/scenario/airt/test_psychosocial.py b/tests/unit/scenario/airt/test_psychosocial.py index 35400f88e9..c5206d2b83 100644 --- a/tests/unit/scenario/airt/test_psychosocial.py +++ b/tests/unit/scenario/airt/test_psychosocial.py @@ -3,7 +3,7 @@ """Tests for the Psychosocial class.""" -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -36,12 +36,11 @@ def mock_resolved_seed_data(mock_memory_seed_groups) -> ResolvedSeedData: @pytest.fixture def mock_dataset_config(mock_memory_seed_groups): """Create a mock dataset config that returns the seed groups.""" - from pyrit.scenario import DatasetConfiguration + from pyrit.scenario import DatasetAttackConfiguration - mock_config = MagicMock(spec=DatasetConfiguration) - mock_config.get_all_seed_attack_groups.return_value = mock_memory_seed_groups - mock_config.get_default_dataset_names.return_value = ["airt_psychosocial"] - mock_config.has_data_source.return_value = True + mock_config = MagicMock(spec=DatasetAttackConfiguration) + mock_config.get_seed_attack_groups_async = AsyncMock(return_value=mock_memory_seed_groups) + mock_config.dataset_names = ["airt_psychosocial"] return mock_config @@ -164,9 +163,14 @@ async def test_init_raises_exception_when_no_datasets_available_async( # Don't provide objectives, let it try to load from empty memory scenario = Psychosocial(objective_scorer=mock_objective_scorer) - # Error should occur during initialize_async when _get_atomic_attacks_async resolves seed groups - with pytest.raises(ValueError, match="DatasetConfiguration has no seed_groups"): - await scenario.initialize_async(objective_target=mock_objective_target) + # Error should occur during initialize_async when _get_atomic_attacks_async resolves seed groups. + # Neutralize the provider fetch so the empty-memory path raises loudly instead of fetching. + with patch( + "pyrit.scenario.core.dataset_configuration.DatasetConfiguration._fetch_dataset_async", + new_callable=AsyncMock, + ): + with pytest.raises(ValueError, match="Dataset is not available or failed to load"): + await scenario.initialize_async(objective_target=mock_objective_target) @pytest.mark.usefixtures(*FIXTURES) @@ -181,7 +185,9 @@ async def test_attack_generation_for_all( mock_dataset_config, ): """Test that _get_atomic_attacks_async returns atomic attacks.""" - with patch.object(Psychosocial, "_resolve_seed_groups", return_value=mock_resolved_seed_data): + with patch.object( + Psychosocial, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_resolved_seed_data + ): scenario = Psychosocial(objective_scorer=mock_objective_scorer) await scenario.initialize_async(objective_target=mock_objective_target, dataset_config=mock_dataset_config) @@ -199,7 +205,9 @@ async def test_attack_runs_include_objectives_async( mock_dataset_config, ) -> None: """Test that attack runs include objectives for each seed prompt.""" - with patch.object(Psychosocial, "_resolve_seed_groups", return_value=mock_resolved_seed_data): + with patch.object( + Psychosocial, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_resolved_seed_data + ): scenario = Psychosocial( objective_scorer=mock_objective_scorer, ) @@ -219,7 +227,9 @@ async def test_get_atomic_attacks_async_returns_attacks( mock_dataset_config, ) -> None: """Test that _get_atomic_attacks_async returns atomic attacks.""" - with patch.object(Psychosocial, "_resolve_seed_groups", return_value=mock_resolved_seed_data): + with patch.object( + Psychosocial, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_resolved_seed_data + ): scenario = Psychosocial( objective_scorer=mock_objective_scorer, ) @@ -243,7 +253,9 @@ async def test_initialize_async_with_max_concurrency( mock_dataset_config, ) -> None: """Test initialization with custom max_concurrency.""" - with patch.object(Psychosocial, "_resolve_seed_groups", return_value=mock_resolved_seed_data): + with patch.object( + Psychosocial, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_resolved_seed_data + ): scenario = Psychosocial(objective_scorer=mock_objective_scorer) await scenario.initialize_async( objective_target=mock_objective_target, max_concurrency=20, dataset_config=mock_dataset_config @@ -261,7 +273,9 @@ async def test_initialize_async_with_memory_labels( """Test initialization with memory labels.""" memory_labels = {"type": "psychosocial", "category": "crisis"} - with patch.object(Psychosocial, "_resolve_seed_groups", return_value=mock_resolved_seed_data): + with patch.object( + Psychosocial, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_resolved_seed_data + ): scenario = Psychosocial(objective_scorer=mock_objective_scorer) await scenario.initialize_async( memory_labels=memory_labels, @@ -305,7 +319,9 @@ async def test_no_target_duplication_async( mock_dataset_config, ) -> None: """Test that all three targets (adversarial, objective, scorer) are distinct.""" - with patch.object(Psychosocial, "_resolve_seed_groups", return_value=mock_resolved_seed_data): + with patch.object( + Psychosocial, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_resolved_seed_data + ): scenario = Psychosocial() await scenario.initialize_async(objective_target=mock_objective_target, dataset_config=mock_dataset_config) @@ -336,7 +352,9 @@ async def test_initialize_async_invokes_target_requirements_validate( mock_dataset_config, ): """initialize_async must delegate capability validation to TARGET_REQUIREMENTS.validate.""" - with patch.object(Psychosocial, "_resolve_seed_groups", return_value=mock_resolved_seed_data): + with patch.object( + Psychosocial, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_resolved_seed_data + ): scenario = Psychosocial(objective_scorer=mock_objective_scorer) with patch("pyrit.prompt_target.common.target_requirements.TargetRequirements.validate") as mock_validate: await scenario.initialize_async( @@ -369,7 +387,9 @@ async def test_initialize_async_rejects_target_missing_editable_history( capability != CapabilityName.EDITABLE_HISTORY ) - with patch.object(Psychosocial, "_resolve_seed_groups", return_value=mock_resolved_seed_data): + with patch.object( + Psychosocial, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_resolved_seed_data + ): scenario = Psychosocial(objective_scorer=mock_objective_scorer) with pytest.raises(ValueError, match="editable_history"): await scenario.initialize_async( @@ -401,10 +421,10 @@ class TestPsychosocialBaselineUniformity: """ADO 9012 regression: baseline shares objectives with strategies under max_dataset_size.""" async def test_one_resolution_call_baseline_matches_strategies(self, mock_objective_target, mock_objective_scorer): - from pyrit.scenario import DatasetConfiguration + from pyrit.scenario import DatasetAttackConfiguration - seed_groups = [SeedGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(10)] - config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3) + seed_groups = [SeedAttackGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(10)] + config = DatasetAttackConfiguration(seed_groups=seed_groups, max_dataset_size=3) first_sample = seed_groups[:3] second_sample = seed_groups[5:8] diff --git a/tests/unit/scenario/airt/test_rapid_response.py b/tests/unit/scenario/airt/test_rapid_response.py index 42a1059138..f3a7a5f012 100644 --- a/tests/unit/scenario/airt/test_rapid_response.py +++ b/tests/unit/scenario/airt/test_rapid_response.py @@ -4,7 +4,7 @@ """Tests for the RapidResponse scenario (refactored from ContentHarms).""" import pathlib -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -21,7 +21,10 @@ from pyrit.registry import TargetRegistry from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry from pyrit.scenario.core.attack_technique_factory import AttackTechniqueFactory -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import ( + DatasetAttackConfiguration, + DatasetConfiguration, +) from pyrit.scenario.scenarios.airt.rapid_response import ( RapidResponse, ) @@ -176,7 +179,7 @@ def test_default_dataset_config_has_all_harm_datasets(self, mock_objective_score ): config = RapidResponse()._default_dataset_config assert isinstance(config, DatasetConfiguration) - names = config.get_default_dataset_names() + names = config.dataset_names expected = [f"airt_{cat}" for cat in ALL_HARM_CATEGORIES] for name in expected: assert name in names @@ -202,7 +205,12 @@ def test_initialization_with_custom_scorer(self, mock_objective_scorer): assert scenario._objective_scorer == mock_objective_scorer @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=ALL_HARM_SEED_GROUPS) + @patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=ALL_HARM_SEED_GROUPS, + ) async def test_initialization_defaults_to_default_strategy( self, _mock_groups, @@ -221,11 +229,22 @@ async def test_initialize_raises_when_no_datasets(self, mock_objective_target, m scenario = RapidResponse( objective_scorer=mock_objective_scorer, ) - with pytest.raises(ValueError, match="DatasetConfiguration has no seed_groups"): - await scenario.initialize_async(objective_target=mock_objective_target) + # Neutralize the provider fetch so the empty-memory path raises loudly instead of fetching + # the real default dataset from the provider. + with patch( + "pyrit.scenario.core.dataset_configuration.DatasetConfiguration._fetch_dataset_async", + new_callable=AsyncMock, + ): + with pytest.raises(ValueError, match="could not be loaded"): + await scenario.initialize_async(objective_target=mock_objective_target) @patch("pyrit.scenario.core.scenario.Scenario._get_default_objective_scorer") - @patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=ALL_HARM_SEED_GROUPS) + @patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=ALL_HARM_SEED_GROUPS, + ) async def test_memory_labels_stored( self, _mock_groups, @@ -264,7 +283,12 @@ async def _init_and_get_attacks( ): """Helper: initialize scenario and return atomic attacks.""" groups = seed_groups or {"hate": _make_seed_groups("hate")} - with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups): + with patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ): scenario = RapidResponse( objective_scorer=mock_objective_scorer, ) @@ -413,7 +437,12 @@ async def test_unknown_technique_skipped_with_warning(self, mock_objective_targe tags=["core", "single_turn"], ) - with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups): + with patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ): scenario = RapidResponse( objective_scorer=mock_objective_scorer, ) diff --git a/tests/unit/scenario/airt/test_scam.py b/tests/unit/scenario/airt/test_scam.py index 3f2014578b..cccc687ae1 100644 --- a/tests/unit/scenario/airt/test_scam.py +++ b/tests/unit/scenario/airt/test_scam.py @@ -4,7 +4,7 @@ """Tests for the Scam class.""" import pathlib -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -17,7 +17,7 @@ from pyrit.executor.attack.core.attack_config import AttackScoringConfig from pyrit.models import ComponentIdentifier, SeedAttackGroup, SeedDataset, SeedObjective from pyrit.prompt_target import OpenAIChatTarget, PromptTarget -from pyrit.scenario import DatasetConfiguration +from pyrit.scenario import DatasetAttackConfiguration, DatasetConfiguration from pyrit.scenario.scenarios.airt.scam import Scam, ScamStrategy from pyrit.score import TrueFalseCompositeScorer @@ -57,10 +57,9 @@ def mock_memory_seeds(): def mock_dataset_config(mock_memory_seed_groups): """Create a mock dataset config that returns the seed groups.""" seed_attack_groups = list(mock_memory_seed_groups) - mock_config = MagicMock(spec=DatasetConfiguration) - mock_config.get_all_seed_attack_groups.return_value = seed_attack_groups - mock_config.get_default_dataset_names.return_value = ["airt_scam"] - mock_config.has_data_source.return_value = True + mock_config = MagicMock(spec=DatasetAttackConfiguration) + mock_config.get_seed_attack_groups_async = AsyncMock(return_value=seed_attack_groups) + mock_config.dataset_names = ["airt_scam"] return mock_config @@ -129,7 +128,9 @@ def test_init_with_default_objectives( mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: list[SeedAttackGroup], ) -> None: - with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Scam, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Scam(objective_scorer=mock_objective_scorer) assert scenario.name == "Scam" @@ -137,7 +138,9 @@ def test_init_with_default_objectives( def test_init_with_default_scorer(self, mock_memory_seed_groups) -> None: """Test initialization with default scorer.""" - with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Scam, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Scam() assert scenario._objective_scorer_identifier @@ -145,14 +148,18 @@ def test_init_with_custom_scorer(self, *, mock_memory_seed_groups: list[SeedAtta """Test initialization with custom scorer.""" scorer = MagicMock(spec=TrueFalseCompositeScorer) - with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Scam, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Scam(objective_scorer=scorer) assert isinstance(scenario._scorer_config, AttackScoringConfig) def test_init_default_adversarial_chat( self, *, mock_objective_scorer: TrueFalseCompositeScorer, mock_memory_seed_groups: list[SeedAttackGroup] ) -> None: - with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Scam, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Scam(objective_scorer=mock_objective_scorer) assert isinstance(scenario._adversarial_chat, OpenAIChatTarget) @@ -164,7 +171,9 @@ def test_init_with_adversarial_chat( adversarial_chat = MagicMock(OpenAIChatTarget) adversarial_chat.get_identifier.return_value = _mock_target_id("CustomAdversary") - with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Scam, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Scam( adversarial_chat=adversarial_chat, objective_scorer=mock_objective_scorer, @@ -179,9 +188,14 @@ async def test_init_raises_exception_when_no_datasets_available_async( # Don't mock _resolve_seed_groups, let it try to load from empty memory scenario = Scam(objective_scorer=mock_objective_scorer) - # Error should occur during initialize_async when _get_atomic_attacks_async resolves seed groups - with pytest.raises(ValueError, match="DatasetConfiguration has no seed_groups"): - await scenario.initialize_async(objective_target=mock_objective_target) + # Error should occur during initialize_async when _get_atomic_attacks_async resolves seed groups. + # Neutralize the provider fetch so the empty-memory path raises loudly instead of fetching. + with patch( + "pyrit.scenario.core.dataset_configuration.DatasetConfiguration._fetch_dataset_async", + new_callable=AsyncMock, + ): + with pytest.raises(ValueError, match="Dataset is not available or failed to load"): + await scenario.initialize_async(objective_target=mock_objective_target) @pytest.mark.usefixtures(*FIXTURES) @@ -192,7 +206,9 @@ async def test_attack_generation_for_all( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, mock_dataset_config ): """Test that _get_atomic_attacks_async returns atomic attacks.""" - with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Scam, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Scam(objective_scorer=mock_objective_scorer) await scenario.initialize_async(objective_target=mock_objective_target, dataset_config=mock_dataset_config) @@ -344,7 +360,9 @@ async def test_initialize_async_with_max_concurrency( mock_dataset_config, ) -> None: """Test initialization with custom max_concurrency.""" - with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Scam, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Scam(objective_scorer=mock_objective_scorer) await scenario.initialize_async( objective_target=mock_objective_target, max_concurrency=20, dataset_config=mock_dataset_config @@ -362,7 +380,9 @@ async def test_initialize_async_with_memory_labels( """Test initialization with memory labels.""" memory_labels = {"type": "scam", "category": "scenario"} - with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Scam, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Scam(objective_scorer=mock_objective_scorer) await scenario.initialize_async( memory_labels=memory_labels, @@ -396,7 +416,9 @@ async def test_no_target_duplication_async( mock_dataset_config, ) -> None: """Test that all three targets (adversarial, object, scorer) are distinct.""" - with patch.object(Scam, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + Scam, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = Scam() await scenario.initialize_async(objective_target=mock_objective_target, dataset_config=mock_dataset_config) @@ -416,10 +438,10 @@ class TestScamBaselineUniformity: async def test_one_resolution_call_baseline_matches_strategies( self, mock_objective_target, mock_objective_scorer, single_turn_strategy ): - from pyrit.models import SeedGroup, SeedObjective + from pyrit.models import SeedAttackGroup, SeedObjective - seed_groups = [SeedGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(10)] - config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3) + seed_groups = [SeedAttackGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(10)] + config = DatasetAttackConfiguration(seed_groups=seed_groups, max_dataset_size=3) first_sample = seed_groups[:3] second_sample = seed_groups[5:8] diff --git a/tests/unit/scenario/benchmark/test_adversarial.py b/tests/unit/scenario/benchmark/test_adversarial.py index bebe095b11..c76656471d 100644 --- a/tests/unit/scenario/benchmark/test_adversarial.py +++ b/tests/unit/scenario/benchmark/test_adversarial.py @@ -432,7 +432,7 @@ def _make_bench_with_targets(self, *, target_names: list[str]) -> AdversarialBen # Dataset config: one dataset with one real seed group (AtomicAttack hashes objectives). seed_group = SeedAttackGroup(seeds=[SeedObjective(value="benchmark_objective_1")]) bench._dataset_config = MagicMock() - bench._dataset_config.get_seed_attack_groups.return_value = {"harmbench": [seed_group]} + bench._dataset_config.get_attack_groups_by_dataset_async = AsyncMock(return_value={"harmbench": [seed_group]}) return bench @@ -479,7 +479,7 @@ async def test_display_group_uses_registry_name_not_target_model_name(self): seed_group = SeedAttackGroup(seeds=[SeedObjective(value="display_group_regression_objective")]) bench._dataset_config = MagicMock() - bench._dataset_config.get_seed_attack_groups.return_value = {"harmbench": [seed_group]} + bench._dataset_config.get_attack_groups_by_dataset_async = AsyncMock(return_value={"harmbench": [seed_group]}) result = await bench._get_atomic_attacks_async() @@ -766,7 +766,7 @@ def _make_bench(self, *, use_cached: bool) -> AdversarialBenchmark: seed_group = SeedAttackGroup(seeds=[SeedObjective(value="skip_cached_objective")]) bench._dataset_config = MagicMock() - bench._dataset_config.get_seed_attack_groups.return_value = {"harmbench": [seed_group]} + bench._dataset_config.get_attack_groups_by_dataset_async = AsyncMock(return_value={"harmbench": [seed_group]}) return bench diff --git a/tests/unit/scenario/core/test_dataset_configuration.py b/tests/unit/scenario/core/test_dataset_configuration.py index e1b5c68727..0e4e2c9c80 100644 --- a/tests/unit/scenario/core/test_dataset_configuration.py +++ b/tests/unit/scenario/core/test_dataset_configuration.py @@ -1,517 +1,517 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -"""Tests for the DatasetConfiguration class.""" +"""Tests for the DatasetConfiguration base class and DatasetAttackConfiguration.""" -import random -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from pyrit.models import SeedGroup, SeedObjective, SeedPrompt +from pyrit.models import SeedAttackGroup, SeedGroup, SeedObjective, SeedPrompt from pyrit.scenario.core.dataset_configuration import ( - EXPLICIT_SEED_GROUPS_KEY, + INLINE_DATASET_NAME, + DatasetAttackConfiguration, DatasetConfiguration, + DatasetConstraintError, + DatasetSourceKind, + ResolvedDataset, + forbid_inline_seeds, + require_harm_categories, + require_inline_seeds, + require_min_size, + require_nonempty, + require_seed_type, + restrict_dataset_names, ) +MEMORY_PATCH_TARGET = "pyrit.scenario.core.dataset_configuration.CentralMemory.get_memory_instance" +PROVIDER_PATCH_TARGET = "pyrit.datasets.seed_datasets.seed_dataset_provider.SeedDatasetProvider" + + +def resolved( + *seeds: SeedObjective | SeedPrompt, + source_kind: DatasetSourceKind = DatasetSourceKind.MEMORY, + dataset_names: tuple[str, ...] = (), +) -> ResolvedDataset: + """Build a ResolvedDataset from inline seeds for validator tests.""" + return ResolvedDataset(seeds=list(seeds), source_kind=source_kind, dataset_names=dataset_names) + + +@pytest.fixture +def mock_memory() -> MagicMock: + """A stand-in CentralMemory whose ``get_seeds`` returns nothing by default.""" + memory = MagicMock() + memory.get_seeds.return_value = [] + memory.get_seed_groups.return_value = [] + memory.add_seed_datasets_to_memory_async = AsyncMock() + return memory + + +@pytest.fixture(autouse=True) +def patch_memory(mock_memory: MagicMock): + """Patch ``CentralMemory.get_memory_instance`` so configs resolve against ``mock_memory``.""" + with patch(MEMORY_PATCH_TARGET, return_value=mock_memory): + yield mock_memory + @pytest.fixture def sample_seed_group() -> SeedGroup: - """Create a sample SeedGroup for testing.""" - return SeedGroup( - seeds=[ - SeedObjective(value="Test objective"), - SeedPrompt(value="Test prompt"), - ] - ) + """A SeedGroup carrying exactly one objective and one prompt.""" + return SeedGroup(seeds=[SeedObjective(value="Test objective"), SeedPrompt(value="Test prompt")]) @pytest.fixture -def sample_seed_groups(sample_seed_group: SeedGroup) -> list: - """Create multiple sample SeedGroups for testing.""" +def sample_seed_groups() -> list[SeedGroup]: + """Three distinct SeedGroups, each with one objective and one prompt.""" return [ - sample_seed_group, - SeedGroup( - seeds=[ - SeedObjective(value="Second objective"), - SeedPrompt(value="Second prompt"), - ] - ), - SeedGroup( - seeds=[ - SeedObjective(value="Third objective"), - SeedPrompt(value="Third prompt"), - ] - ), + SeedGroup(seeds=[SeedObjective(value="o1"), SeedPrompt(value="p1")]), + SeedGroup(seeds=[SeedObjective(value="o2"), SeedPrompt(value="p2")]), + SeedGroup(seeds=[SeedObjective(value="o3"), SeedPrompt(value="p3")]), ] +def make_objectives(*values: str) -> list[SeedObjective]: + """Build a list of SeedObjective seeds (each becomes its own attack group).""" + return [SeedObjective(value=v) for v in values] + + class TestDatasetConfigurationInit: - """Tests for DatasetConfiguration initialization.""" + """Construction, source-exclusivity, and defensive copying.""" - def test_init_with_seed_groups_only(self, sample_seed_groups: list) -> None: - """Test initialization with only seed_groups.""" - config = DatasetConfiguration(seed_groups=sample_seed_groups) + def test_init_with_seeds_only(self) -> None: + seeds = make_objectives("a", "b") + config = DatasetConfiguration(seeds=seeds) + assert config._seeds == seeds + assert config._seed_groups is None + assert config._dataset_names is None + def test_init_with_seed_groups_only(self, sample_seed_groups: list[SeedGroup]) -> None: + config = DatasetConfiguration(seed_groups=sample_seed_groups) assert config._seed_groups == sample_seed_groups + assert config._seeds is None assert config._dataset_names is None assert config.max_dataset_size is None - assert config._scenario_strategies is None def test_init_with_dataset_names_only(self) -> None: - """Test initialization with only dataset_names.""" - dataset_names = ["dataset1", "dataset2"] - config = DatasetConfiguration(dataset_names=dataset_names) - + config = DatasetConfiguration(dataset_names=["dataset1", "dataset2"]) + assert config._dataset_names == ["dataset1", "dataset2"] + assert config._seeds is None assert config._seed_groups is None - assert config._dataset_names == dataset_names - assert config.max_dataset_size is None - def test_init_with_both_seed_groups_and_dataset_names_raises_error(self, sample_seed_groups: list) -> None: - """Test that setting both seed_groups and dataset_names raises ValueError.""" - with pytest.raises(ValueError) as exc_info: + def test_init_defaults_to_auto_fetch(self) -> None: + config = DatasetConfiguration(dataset_names=["d1"]) + assert config._auto_fetch is True + + def test_init_auto_fetch_can_be_disabled(self) -> None: + config = DatasetConfiguration(dataset_names=["d1"], auto_fetch=False) + assert config._auto_fetch is False + + def test_init_with_two_sources_raises(self, sample_seed_groups: list[SeedGroup]) -> None: + with pytest.raises(ValueError, match="Only one of 'seeds', 'seed_groups', or 'dataset_names'"): + DatasetConfiguration(seed_groups=sample_seed_groups, dataset_names=["d1"]) + + def test_init_with_three_sources_raises(self, sample_seed_groups: list[SeedGroup]) -> None: + with pytest.raises(ValueError, match="Only one of"): DatasetConfiguration( + seeds=make_objectives("a"), seed_groups=sample_seed_groups, - dataset_names=["dataset1"], + dataset_names=["d1"], ) - assert "Only one of 'seed_groups' or 'dataset_names' can be set" in str(exc_info.value) - - def test_init_with_max_dataset_size(self, sample_seed_groups: list) -> None: - """Test initialization with max_dataset_size.""" + def test_init_with_max_dataset_size(self, sample_seed_groups: list[SeedGroup]) -> None: config = DatasetConfiguration(seed_groups=sample_seed_groups, max_dataset_size=2) - assert config.max_dataset_size == 2 - def test_init_with_max_dataset_size_zero_raises_error(self) -> None: - """Test that max_dataset_size=0 raises ValueError.""" - with pytest.raises(ValueError) as exc_info: - DatasetConfiguration(dataset_names=["dataset1"], max_dataset_size=0) - - assert "'max_dataset_size' must be a positive integer" in str(exc_info.value) + def test_init_with_max_dataset_size_zero_raises(self) -> None: + with pytest.raises(ValueError, match="positive integer"): + DatasetConfiguration(dataset_names=["d1"], max_dataset_size=0) - def test_init_with_max_dataset_size_negative_raises_error(self) -> None: - """Test that negative max_dataset_size raises ValueError.""" - with pytest.raises(ValueError) as exc_info: - DatasetConfiguration(dataset_names=["dataset1"], max_dataset_size=-1) + def test_init_with_max_dataset_size_negative_raises(self) -> None: + with pytest.raises(ValueError, match="positive integer"): + DatasetConfiguration(dataset_names=["d1"], max_dataset_size=-1) - assert "'max_dataset_size' must be a positive integer" in str(exc_info.value) - - def test_init_copies_seed_groups_to_prevent_mutation(self, sample_seed_groups: list) -> None: - """Test that the constructor copies seed_groups list to prevent external mutation.""" - original_list = list(sample_seed_groups) + def test_init_copies_seed_groups_to_prevent_mutation(self, sample_seed_groups: list[SeedGroup]) -> None: config = DatasetConfiguration(seed_groups=sample_seed_groups) - - # Mutate the original list - sample_seed_groups.append(SeedGroup(seeds=[SeedObjective(value="New objective")])) - - # Config should still have the original length - assert len(config._seed_groups) == len(original_list) + sample_seed_groups.append(SeedGroup(seeds=[SeedObjective(value="extra")])) + assert config._seed_groups is not None + assert len(config._seed_groups) == 3 def test_init_copies_dataset_names_to_prevent_mutation(self) -> None: - """Test that the constructor copies dataset_names list to prevent external mutation.""" - dataset_names = ["dataset1", "dataset2"] - config = DatasetConfiguration(dataset_names=dataset_names) + names = ["d1", "d2"] + config = DatasetConfiguration(dataset_names=names) + names.append("d3") + assert config._dataset_names == ["d1", "d2"] - # Mutate the original list - dataset_names.append("dataset3") + def test_init_copies_seeds_to_prevent_mutation(self) -> None: + seeds = make_objectives("a", "b") + config = DatasetConfiguration(seeds=seeds) + seeds.append(SeedObjective(value="c")) + assert config._seeds is not None + assert len(config._seeds) == 2 - # Config should still have the original length - assert len(config._dataset_names) == 2 - def test_init_with_scenario_strategies(self, sample_seed_groups: list) -> None: - """Test initialization with scenario_strategies.""" - mock_strategies = [MagicMock(), MagicMock()] - config = DatasetConfiguration( - seed_groups=sample_seed_groups, - scenario_strategies=mock_strategies, - ) +class TestDatasetNamesProperty: + """The ``dataset_names`` property (replaces the deprecated getter).""" - assert config._scenario_strategies == mock_strategies + def test_returns_configured_names(self) -> None: + config = DatasetConfiguration(dataset_names=["d1", "d2"]) + assert config.dataset_names == ["d1", "d2"] - def test_init_with_no_data_source(self) -> None: - """Test initialization with neither seed_groups nor dataset_names.""" - config = DatasetConfiguration() + def test_returns_copy(self) -> None: + config = DatasetConfiguration(dataset_names=["d1"]) + config.dataset_names.append("mutated") + assert config.dataset_names == ["d1"] - assert config._seed_groups is None - assert config._dataset_names is None - - -@pytest.mark.usefixtures("patch_central_database") -class TestDatasetConfigurationGetSeedGroups: - """Tests for DatasetConfiguration.get_seed_groups method.""" - - def test_get_seed_groups_with_explicit_seed_groups(self, sample_seed_groups: list) -> None: - """Test get_seed_groups returns explicit seed_groups under special key.""" + def test_empty_with_seed_groups(self, sample_seed_groups: list[SeedGroup]) -> None: config = DatasetConfiguration(seed_groups=sample_seed_groups) - - result = config.get_seed_groups() - - assert EXPLICIT_SEED_GROUPS_KEY in result - assert result[EXPLICIT_SEED_GROUPS_KEY] == sample_seed_groups - - def test_get_seed_groups_with_dataset_names(self, sample_seed_groups: list) -> None: - """Test get_seed_groups loads from memory when dataset_names is set.""" - config = DatasetConfiguration(dataset_names=["test_dataset"]) - - with patch.object(config, "_load_seed_groups_for_dataset", return_value=sample_seed_groups): - result = config.get_seed_groups() - - assert "test_dataset" in result - assert result["test_dataset"] == sample_seed_groups - - def test_get_seed_groups_with_multiple_dataset_names(self, sample_seed_groups: list) -> None: - """Test get_seed_groups loads multiple datasets from memory.""" - config = DatasetConfiguration(dataset_names=["dataset1", "dataset2"]) - - def mock_load(*, dataset_name: str): - return sample_seed_groups if dataset_name in ["dataset1", "dataset2"] else [] - - with patch.object(config, "_load_seed_groups_for_dataset", side_effect=mock_load): - result = config.get_seed_groups() - - assert "dataset1" in result - assert "dataset2" in result - - def test_get_seed_groups_skips_empty_datasets_from_memory(self) -> None: - """Test that empty datasets from memory are not included in results.""" - config = DatasetConfiguration(dataset_names=["populated", "empty"]) - - def mock_load(*, dataset_name: str): - if dataset_name == "populated": - return [SeedGroup(seeds=[SeedObjective(value="obj")])] - return [] - - with patch.object(config, "_load_seed_groups_for_dataset", side_effect=mock_load): + assert config.dataset_names == [] + + def test_empty_with_no_source(self) -> None: + assert DatasetConfiguration().dataset_names == [] + + +class TestResolutionErrors: + """Loud failure branches in base resolution, reached via ``DatasetAttackConfiguration``.""" + + async def test_empty_inline_raises(self) -> None: + config = DatasetAttackConfiguration(seeds=[]) + with pytest.raises(DatasetConstraintError, match="empty"): + await config.get_seed_attack_groups_async() + + async def test_raises_loudly_when_still_empty_after_fetch(self) -> None: + config = DatasetAttackConfiguration(dataset_names=["d1"]) + with patch.object(config, "_fetch_dataset_async", new=AsyncMock()): + with pytest.raises(DatasetConstraintError, match="could not be loaded"): + await config.get_seed_attack_groups_async() + + async def test_raises_when_empty_and_auto_fetch_disabled(self) -> None: + config = DatasetAttackConfiguration(dataset_names=["d1"], auto_fetch=False) + with pytest.raises(DatasetConstraintError, match="auto_fetch is disabled"): + await config.get_seed_attack_groups_async() + + async def test_dataset_constraint_error_is_value_error(self) -> None: + config = DatasetAttackConfiguration(dataset_names=["d1"], auto_fetch=False) + with pytest.raises(ValueError): + await config.get_seed_attack_groups_async() + + +class TestGetSeedAttackGroupsAsync: + """``DatasetAttackConfiguration.get_seed_attack_groups_async`` (flat, global sample).""" + + async def test_inline_seed_groups_to_attack_groups(self, sample_seed_groups: list[SeedGroup]) -> None: + config = DatasetAttackConfiguration(seed_groups=sample_seed_groups) + groups = await config.get_seed_attack_groups_async() + assert len(groups) == 3 + assert all(isinstance(g, SeedAttackGroup) for g in groups) + + async def test_inline_seeds_built_into_groups(self) -> None: + config = DatasetAttackConfiguration(seeds=make_objectives("a", "b")) + groups = await config.get_seed_attack_groups_async() + assert len(groups) == 2 + assert all(isinstance(g, SeedAttackGroup) for g in groups) + + async def test_from_memory(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.return_value = make_objectives("a", "b", "c") + config = DatasetAttackConfiguration(dataset_names=["d1"]) + groups = await config.get_seed_attack_groups_async() + assert len(groups) == 3 + + async def test_applies_max_dataset_size_globally(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.return_value = make_objectives("a", "b", "c", "d") + config = DatasetAttackConfiguration(dataset_names=["d1"], max_dataset_size=2) + groups = await config.get_seed_attack_groups_async() + assert len(groups) == 2 + + async def test_empty_raises(self) -> None: + config = DatasetAttackConfiguration(dataset_names=["d1"], auto_fetch=False) + with pytest.raises(DatasetConstraintError): + await config.get_seed_attack_groups_async() + + async def test_auto_fetch_when_memory_empty(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.side_effect = [[], make_objectives("a")] + config = DatasetAttackConfiguration(dataset_names=["d1"]) + with patch.object(config, "_fetch_dataset_async", new=AsyncMock()) as mock_fetch: + groups = await config.get_seed_attack_groups_async() + assert len(groups) == 1 + mock_fetch.assert_awaited_once_with(dataset_name="d1") + + +class TestGetAttackGroupsByDatasetAsync: + """``get_attack_groups_by_dataset_async`` (keyed by dataset, per-dataset sample).""" + + async def test_inline_uses_inline_label(self, sample_seed_groups: list[SeedGroup]) -> None: + config = DatasetAttackConfiguration(seed_groups=sample_seed_groups) + result = await config.get_attack_groups_by_dataset_async() + assert set(result.keys()) == {INLINE_DATASET_NAME} + assert len(result[INLINE_DATASET_NAME]) == 3 + + async def test_keyed_per_dataset(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.side_effect = [make_objectives("a", "b"), make_objectives("c")] + config = DatasetAttackConfiguration(dataset_names=["d1", "d2"]) + result = await config.get_attack_groups_by_dataset_async() + assert set(result.keys()) == {"d1", "d2"} + assert len(result["d1"]) == 2 + assert len(result["d2"]) == 1 + + async def test_per_dataset_max_sample(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.side_effect = [make_objectives("a", "b", "c"), make_objectives("d", "e", "f")] + config = DatasetAttackConfiguration(dataset_names=["d1", "d2"], max_dataset_size=1) + result = await config.get_attack_groups_by_dataset_async() + assert len(result["d1"]) == 1 + assert len(result["d2"]) == 1 + + async def test_loud_raise_when_a_dataset_is_empty(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.side_effect = [make_objectives("a"), []] + config = DatasetAttackConfiguration(dataset_names=["d1", "d2"], auto_fetch=False) + with pytest.raises(DatasetConstraintError, match="could not be loaded"): + await config.get_attack_groups_by_dataset_async() + + +class TestBuildAttackGroups: + """The ``_build_attack_groups`` override seam.""" + + def test_default_groups_by_prompt_group_id(self) -> None: + config = DatasetAttackConfiguration(dataset_names=["d1"]) + groups = config._build_attack_groups(make_objectives("a", "b")) + assert len(groups) == 2 + assert all(isinstance(g, SeedAttackGroup) for g in groups) + + async def test_override_is_used(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.return_value = make_objectives("a", "b", "c") + sentinel = [SeedAttackGroup(seeds=[SeedObjective(value="custom")])] + + class CustomConfig(DatasetAttackConfiguration): + def _build_attack_groups(self, seeds): + return sentinel + + config = CustomConfig(dataset_names=["d1"]) + assert await config.get_seed_attack_groups_async() == sentinel + + +class TestFetchDatasetAsync: + """``_fetch_dataset_async`` provider interaction.""" + + async def test_unregistered_name_does_not_fetch(self, mock_memory: MagicMock) -> None: + config = DatasetConfiguration(dataset_names=["d1"]) + with patch(PROVIDER_PATCH_TARGET) as provider: + provider.get_all_dataset_names_async = AsyncMock(return_value=["other"]) + provider.fetch_datasets_async = AsyncMock() + await config._fetch_dataset_async(dataset_name="d1") + provider.fetch_datasets_async.assert_not_called() + mock_memory.add_seed_datasets_to_memory_async.assert_not_called() + + async def test_registered_name_fetches_and_adds(self, mock_memory: MagicMock) -> None: + config = DatasetConfiguration(dataset_names=["d1"]) + datasets = [MagicMock()] + with patch(PROVIDER_PATCH_TARGET) as provider: + provider.get_all_dataset_names_async = AsyncMock(return_value=["d1"]) + provider.fetch_datasets_async = AsyncMock(return_value=datasets) + await config._fetch_dataset_async(dataset_name="d1") + provider.fetch_datasets_async.assert_awaited_once_with(dataset_names=["d1"]) + mock_memory.add_seed_datasets_to_memory_async.assert_awaited_once() + + async def test_enumeration_error_propagates(self, mock_memory: MagicMock) -> None: + config = DatasetConfiguration(dataset_names=["d1"]) + with patch(PROVIDER_PATCH_TARGET) as provider: + provider.get_all_dataset_names_async = AsyncMock(side_effect=RuntimeError("boom")) + with pytest.raises(RuntimeError, match="boom"): + await config._fetch_dataset_async(dataset_name="d1") + mock_memory.add_seed_datasets_to_memory_async.assert_not_called() + + async def test_fetch_failure_chains_root_cause(self, mock_memory: MagicMock) -> None: + config = DatasetAttackConfiguration(dataset_names=["d1"]) + with patch(PROVIDER_PATCH_TARGET) as provider: + provider.get_all_dataset_names_async = AsyncMock(side_effect=RuntimeError("boom")) + with pytest.raises(DatasetConstraintError, match="auto-fetch") as exc_info: + await config.get_seed_attack_groups_async() + assert isinstance(exc_info.value.__cause__, RuntimeError) + + +class TestLegacyDeprecations: + """Legacy getters still work but emit ``DeprecationWarning`` (removed in 0.17.0).""" + + def test_get_seed_groups_warns(self, mock_memory: MagicMock, sample_seed_groups: list[SeedGroup]) -> None: + config = DatasetConfiguration(seed_groups=sample_seed_groups) + with pytest.warns(DeprecationWarning): result = config.get_seed_groups() + assert INLINE_DATASET_NAME in result - assert "populated" in result - assert "empty" not in result - - def test_get_seed_groups_with_no_data_source_raises_error(self) -> None: - """Test that get_seed_groups raises ValueError when no data source is configured.""" - config = DatasetConfiguration() - - with pytest.raises(ValueError) as exc_info: - config.get_seed_groups() - - assert "DatasetConfiguration has no seed_groups" in str(exc_info.value) - - def test_get_seed_groups_applies_max_dataset_size_per_dataset(self, sample_seed_groups: list) -> None: - """Test that max_dataset_size is applied per dataset.""" - config = DatasetConfiguration(seed_groups=sample_seed_groups, max_dataset_size=1) - - # Set seed for deterministic random sampling - random.seed(42) - result = config.get_seed_groups() - - assert len(result[EXPLICIT_SEED_GROUPS_KEY]) == 1 - - def test_get_seed_groups_with_empty_seed_groups_list_raises_error(self) -> None: - """Test that empty seed_groups list raises ValueError.""" - config = DatasetConfiguration(seed_groups=[]) - - with pytest.raises(ValueError) as exc_info: - config.get_seed_groups() - - assert "DatasetConfiguration has no seed_groups" in str(exc_info.value) - - -@pytest.mark.usefixtures("patch_central_database") -class TestDatasetConfigurationLoadSeedGroupsForDataset: - """Tests for DatasetConfiguration._load_seed_groups_for_dataset method.""" - - def test_load_seed_groups_for_dataset_calls_memory(self, sample_seed_groups: list) -> None: - """Test that _load_seed_groups_for_dataset calls CentralMemory.""" - config = DatasetConfiguration(dataset_names=["test_dataset"]) - - with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_central_memory: - mock_memory = MagicMock() - mock_memory.get_seed_groups.return_value = sample_seed_groups - mock_central_memory.get_memory_instance.return_value = mock_memory - - result = config._load_seed_groups_for_dataset(dataset_name="test_dataset") - - mock_memory.get_seed_groups.assert_called_once_with(dataset_name="test_dataset") - assert result == sample_seed_groups - - def test_load_seed_groups_for_dataset_returns_empty_list_when_none(self) -> None: - """Test that _load_seed_groups_for_dataset returns empty list when memory returns None.""" - config = DatasetConfiguration(dataset_names=["nonexistent"]) - - with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_central_memory: - mock_memory = MagicMock() - mock_memory.get_seed_groups.return_value = None - mock_central_memory.get_memory_instance.return_value = mock_memory - - result = config._load_seed_groups_for_dataset(dataset_name="nonexistent") - - assert result == [] - - -@pytest.mark.usefixtures("patch_central_database") -class TestDatasetConfigurationGetAllSeedGroups: - """Tests for DatasetConfiguration.get_all_seed_groups method.""" - - def test_get_all_seed_groups_flattens_results(self, sample_seed_groups: list) -> None: - """Test that get_all_seed_groups returns a flat list.""" + def test_get_all_seed_groups_warns(self, sample_seed_groups: list[SeedGroup]) -> None: config = DatasetConfiguration(seed_groups=sample_seed_groups) + with pytest.warns(DeprecationWarning): + assert len(config.get_all_seed_groups()) == 3 - result = config.get_all_seed_groups() - - assert isinstance(result, list) - assert len(result) == len(sample_seed_groups) - for group in sample_seed_groups: - assert group in result - - def test_get_all_seed_groups_combines_multiple_datasets(self) -> None: - """Test that get_all_seed_groups combines seed groups from multiple datasets.""" - config = DatasetConfiguration(dataset_names=["dataset1", "dataset2"]) - - group1 = SeedGroup(seeds=[SeedObjective(value="obj1")]) - group2 = SeedGroup(seeds=[SeedObjective(value="obj2")]) - - def mock_load(*, dataset_name: str): - return [group1] if dataset_name == "dataset1" else [group2] - - with patch.object(config, "_load_seed_groups_for_dataset", side_effect=mock_load): - result = config.get_all_seed_groups() - - assert len(result) == 2 - assert group1 in result - assert group2 in result - - def test_get_all_seed_groups_raises_error_when_no_data_source(self) -> None: - """Test that get_all_seed_groups raises ValueError when no data source is configured.""" - config = DatasetConfiguration() - - with pytest.raises(ValueError) as exc_info: - config.get_all_seed_groups() - - assert "DatasetConfiguration has no seed_groups" in str(exc_info.value) - - -class TestDatasetConfigurationGetDefaultDatasetNames: - """Tests for DatasetConfiguration.get_default_dataset_names method.""" - - def test_get_default_dataset_names_returns_dataset_names(self) -> None: - """Test that get_default_dataset_names returns configured dataset_names.""" - dataset_names = ["dataset1", "dataset2", "dataset3"] - config = DatasetConfiguration(dataset_names=dataset_names) - - result = config.get_default_dataset_names() - - assert result == dataset_names - - def test_get_default_dataset_names_returns_copy(self) -> None: - """Test that get_default_dataset_names returns a copy of the list.""" - dataset_names = ["dataset1", "dataset2"] - config = DatasetConfiguration(dataset_names=dataset_names) - - result = config.get_default_dataset_names() - result.append("dataset3") - - # Original should be unchanged - assert len(config.get_default_dataset_names()) == 2 - - def test_get_default_dataset_names_returns_empty_with_seed_groups(self, sample_seed_groups: list) -> None: - """Test that get_default_dataset_names returns empty list when using explicit seed_groups.""" + def test_get_seed_attack_groups_warns(self, sample_seed_groups: list[SeedGroup]) -> None: config = DatasetConfiguration(seed_groups=sample_seed_groups) + with pytest.warns(DeprecationWarning): + result = config.get_seed_attack_groups() + assert INLINE_DATASET_NAME in result - result = config.get_default_dataset_names() - - assert result == [] - - def test_get_default_dataset_names_returns_empty_when_no_config(self) -> None: - """Test that get_default_dataset_names returns empty list when nothing is configured.""" - config = DatasetConfiguration() - - result = config.get_default_dataset_names() - - assert result == [] - - -class TestDatasetConfigurationApplyMaxDatasetSize: - """Tests for DatasetConfiguration._apply_max_dataset_size method.""" - - def test_apply_max_returns_original_when_none(self, sample_seed_groups: list) -> None: - """Test that original list is returned when max_dataset_size is None.""" + def test_get_all_seed_attack_groups_warns(self, sample_seed_groups: list[SeedGroup]) -> None: config = DatasetConfiguration(seed_groups=sample_seed_groups) + with pytest.warns(DeprecationWarning): + groups = config.get_all_seed_attack_groups() + assert len(groups) == 3 + assert all(isinstance(g, SeedAttackGroup) for g in groups) + + def test_get_default_dataset_names_warns(self) -> None: + config = DatasetConfiguration(dataset_names=["d1", "d2"]) + with pytest.warns(DeprecationWarning): + assert config.get_default_dataset_names() == ["d1", "d2"] + + def test_get_all_seeds_warns(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.return_value = make_objectives("a", "b") + config = DatasetConfiguration(dataset_names=["d1"]) + with pytest.warns(DeprecationWarning): + assert len(config.get_all_seeds()) == 2 + + def test_get_all_seeds_raises_when_no_dataset_names(self, sample_seed_groups: list[SeedGroup]) -> None: + config = DatasetConfiguration(seed_groups=sample_seed_groups) + with pytest.warns(DeprecationWarning): + with pytest.raises(ValueError, match="No dataset names configured"): + config.get_all_seeds() - result = config._apply_max_dataset_size(sample_seed_groups) - - assert result == sample_seed_groups - - def test_apply_max_returns_original_when_under_limit(self, sample_seed_groups: list) -> None: - """Test that original list is returned when length is under max_dataset_size.""" - config = DatasetConfiguration(seed_groups=sample_seed_groups, max_dataset_size=10) - result = config._apply_max_dataset_size(sample_seed_groups) +class TestValidators: + """The standalone validator builders and base ``validate``.""" - assert result == sample_seed_groups + def test_require_nonempty_raises_on_empty(self) -> None: + with pytest.raises(DatasetConstraintError): + require_nonempty()(resolved()) - def test_apply_max_returns_original_when_equal_to_limit(self, sample_seed_groups: list) -> None: - """Test that original list is returned when length equals max_dataset_size.""" - config = DatasetConfiguration( - seed_groups=sample_seed_groups, - max_dataset_size=len(sample_seed_groups), - ) + def test_require_nonempty_passes(self) -> None: + require_nonempty()(resolved(SeedObjective(value="a"))) - result = config._apply_max_dataset_size(sample_seed_groups) + def test_require_min_size_raises_when_too_few(self) -> None: + with pytest.raises(DatasetConstraintError): + require_min_size(3)(resolved(SeedObjective(value="a"))) - assert result == sample_seed_groups + def test_require_min_size_passes(self) -> None: + require_min_size(1)(resolved(SeedObjective(value="a"))) - def test_apply_max_returns_sample_when_over_limit(self, sample_seed_groups: list) -> None: - """Test that a random sample is returned when length exceeds max_dataset_size.""" - config = DatasetConfiguration(seed_groups=sample_seed_groups, max_dataset_size=1) + def test_require_harm_categories_raises_when_missing(self) -> None: + with pytest.raises(DatasetConstraintError): + require_harm_categories({"illegal"})(resolved(SeedObjective(value="a"))) - # Set seed for deterministic random sampling - random.seed(42) - result = config._apply_max_dataset_size(sample_seed_groups) + def test_require_harm_categories_passes(self) -> None: + item = SeedObjective(value="a", harm_categories=["illegal"]) + require_harm_categories({"illegal"})(resolved(item)) - assert len(result) == 1 - assert result[0] in sample_seed_groups + def test_require_seed_type_raises_on_wrong_type(self) -> None: + with pytest.raises(DatasetConstraintError, match="SeedObjective"): + require_seed_type(SeedObjective)(resolved(SeedPrompt(value="p"))) - def test_apply_max_returns_correct_sample_size(self) -> None: - """Test that the sample size is exactly max_dataset_size.""" - large_seed_groups = [SeedGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(20)] - config = DatasetConfiguration(seed_groups=large_seed_groups, max_dataset_size=5) + def test_require_seed_type_passes(self) -> None: + require_seed_type(SeedObjective)(resolved(SeedObjective(value="a"))) - result = config._apply_max_dataset_size(large_seed_groups) + def test_require_inline_seeds_raises_on_dataset_names(self) -> None: + item = SeedObjective(value="a") + with pytest.raises(DatasetConstraintError, match="inline"): + require_inline_seeds()(resolved(item, source_kind=DatasetSourceKind.MEMORY)) - assert len(result) == 5 - for group in result: - assert group in large_seed_groups + def test_require_inline_seeds_passes_for_inline(self) -> None: + item = SeedObjective(value="a") + require_inline_seeds()(resolved(item, source_kind=DatasetSourceKind.INLINE)) + def test_forbid_inline_seeds_raises_on_inline(self) -> None: + item = SeedObjective(value="a") + with pytest.raises(DatasetConstraintError, match="inline"): + forbid_inline_seeds()(resolved(item, source_kind=DatasetSourceKind.INLINE)) -class TestDatasetConfigurationHasDataSource: - """Tests for DatasetConfiguration.has_data_source method.""" + def test_forbid_inline_seeds_passes_for_dataset_names(self) -> None: + item = SeedObjective(value="a") + forbid_inline_seeds()(resolved(item, source_kind=DatasetSourceKind.MEMORY)) - def test_has_data_source_true_with_seed_groups(self, sample_seed_groups: list) -> None: - """Test that has_data_source returns True when seed_groups is set.""" - config = DatasetConfiguration(seed_groups=sample_seed_groups) + def test_restrict_dataset_names_passes_when_subset(self) -> None: + item = SeedObjective(value="a") + restrict_dataset_names({"d1", "d2"})(resolved(item, dataset_names=("d1",))) - assert config.has_data_source() is True + def test_restrict_dataset_names_raises_on_disallowed(self) -> None: + item = SeedObjective(value="a") + with pytest.raises(DatasetConstraintError, match="not allowed"): + restrict_dataset_names({"d1"})(resolved(item, dataset_names=("d1", "rogue"))) - def test_has_data_source_true_with_dataset_names(self) -> None: - """Test that has_data_source returns True when dataset_names is set.""" - config = DatasetConfiguration(dataset_names=["dataset1"]) + def test_restrict_dataset_names_passes_for_inline(self) -> None: + item = SeedObjective(value="a") + restrict_dataset_names({"d1"})(resolved(item, source_kind=DatasetSourceKind.INLINE)) - assert config.has_data_source() is True + def test_validate_raises_on_empty(self) -> None: + config = DatasetConfiguration(dataset_names=["d1"]) + with pytest.raises(DatasetConstraintError, match="empty"): + config.validate(resolved()) - def test_has_data_source_false_when_empty(self) -> None: - """Test that has_data_source returns False when nothing is configured.""" - config = DatasetConfiguration() - assert config.has_data_source() is False +class TestSourceKind: + """``source_kind`` reflects how the configuration was constructed.""" - def test_has_data_source_true_with_empty_seed_groups_list(self) -> None: - """Test that has_data_source returns True even with empty seed_groups list.""" - # Note: This tests the current behavior - an empty list is still "configured" - config = DatasetConfiguration(seed_groups=[]) - - assert config.has_data_source() is True + def test_inline_seeds(self) -> None: + config = DatasetConfiguration(seeds=make_objectives("a")) + assert config.source_kind is DatasetSourceKind.INLINE + def test_inline_seed_groups(self, sample_seed_groups: list[SeedGroup]) -> None: + config = DatasetConfiguration(seed_groups=sample_seed_groups) + assert config.source_kind is DatasetSourceKind.INLINE -@pytest.mark.usefixtures("patch_central_database") -class TestDatasetConfigurationGetAllSeeds: - """Tests for DatasetConfiguration.get_all_seeds method.""" + def test_dataset_names(self) -> None: + config = DatasetConfiguration(dataset_names=["d1"]) + assert config.source_kind is DatasetSourceKind.MEMORY - def test_get_all_seeds_raises_when_no_dataset_names(self) -> None: - """Test that get_all_seeds raises ValueError when no dataset_names are configured.""" + def test_unconfigured_is_memory(self) -> None: config = DatasetConfiguration() - - with pytest.raises(ValueError, match="No dataset names configured"): - config.get_all_seeds() - - def test_get_all_seeds_raises_when_seed_groups_configured(self, sample_seed_groups: list) -> None: - """Test that get_all_seeds raises ValueError when seed_groups are configured instead of dataset_names.""" - config = DatasetConfiguration(seed_groups=sample_seed_groups) - - with pytest.raises(ValueError, match="No dataset names configured"): - config.get_all_seeds() - - def test_get_all_seeds_returns_seeds_from_memory(self) -> None: - """Test that get_all_seeds returns SeedPrompt objects from memory.""" - mock_seeds = [ - SeedPrompt(value="seed1", data_type="text"), - SeedPrompt(value="seed2", data_type="text"), - ] - - with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_memory_class: - mock_memory = MagicMock() - mock_memory.get_seeds.return_value = mock_seeds - mock_memory_class.get_memory_instance.return_value = mock_memory - - config = DatasetConfiguration(dataset_names=["test_dataset"]) - result = config.get_all_seeds() - - assert len(result) == 2 - assert result[0].value == "seed1" - assert result[1].value == "seed2" - mock_memory.get_seeds.assert_called_once_with(dataset_name="test_dataset") - - def test_get_all_seeds_aggregates_from_multiple_datasets(self) -> None: - """Test that get_all_seeds aggregates seeds from all configured datasets.""" - seeds_dataset1 = [SeedPrompt(value="ds1_seed1", data_type="text")] - seeds_dataset2 = [ - SeedPrompt(value="ds2_seed1", data_type="text"), - SeedPrompt(value="ds2_seed2", data_type="text"), - ] - - with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_memory_class: - mock_memory = MagicMock() - mock_memory.get_seeds.side_effect = [seeds_dataset1, seeds_dataset2] - mock_memory_class.get_memory_instance.return_value = mock_memory - - config = DatasetConfiguration(dataset_names=["dataset1", "dataset2"]) - result = config.get_all_seeds() - - assert len(result) == 3 - assert result[0].value == "ds1_seed1" - assert result[1].value == "ds2_seed1" - assert result[2].value == "ds2_seed2" - assert mock_memory.get_seeds.call_count == 2 - - def test_get_all_seeds_applies_max_dataset_size_per_dataset(self) -> None: - """Test that get_all_seeds applies max_dataset_size sampling per dataset.""" - seeds = [SeedPrompt(value=f"seed{i}", data_type="text") for i in range(10)] - - with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_memory_class: - mock_memory = MagicMock() - mock_memory.get_seeds.return_value = seeds - mock_memory_class.get_memory_instance.return_value = mock_memory - - config = DatasetConfiguration(dataset_names=["dataset1"], max_dataset_size=3) - result = config.get_all_seeds() - - assert len(result) == 3 - # All returned seeds should be from the original list - for seed in result: - assert seed in seeds - - def test_get_all_seeds_returns_all_when_under_max_size(self) -> None: - """Test that get_all_seeds returns all seeds when count is under max_dataset_size.""" - seeds = [SeedPrompt(value=f"seed{i}", data_type="text") for i in range(3)] - - with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_memory_class: - mock_memory = MagicMock() - mock_memory.get_seeds.return_value = seeds - mock_memory_class.get_memory_instance.return_value = mock_memory - - config = DatasetConfiguration(dataset_names=["dataset1"], max_dataset_size=10) - result = config.get_all_seeds() - - assert len(result) == 3 - - def test_get_all_seeds_returns_empty_list_when_no_seeds_in_memory(self) -> None: - """Test that get_all_seeds returns empty list when memory has no seeds.""" - with patch("pyrit.scenario.core.dataset_configuration.CentralMemory") as mock_memory_class: - mock_memory = MagicMock() - mock_memory.get_seeds.return_value = [] - mock_memory_class.get_memory_instance.return_value = mock_memory - - config = DatasetConfiguration(dataset_names=["empty_dataset"]) - result = config.get_all_seeds() - - assert result == [] + assert config.source_kind is DatasetSourceKind.MEMORY + + +class TestSourceValidatorsEndToEnd: + """Source-kind validators wired through ``DatasetAttackConfiguration`` resolution.""" + + async def test_require_inline_seeds_raises_for_dataset_names(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.return_value = make_objectives("a") + config = DatasetAttackConfiguration(dataset_names=["d1"], validators=[require_inline_seeds()]) + with pytest.raises(DatasetConstraintError, match="inline"): + await config.get_seed_attack_groups_async() + + async def test_require_inline_seeds_passes_for_inline(self) -> None: + seeds = make_objectives("a", "b") + config = DatasetAttackConfiguration(seeds=seeds, validators=[require_inline_seeds()]) + assert len(await config.get_seed_attack_groups_async()) == 2 + + async def test_forbid_inline_seeds_raises_for_inline(self) -> None: + config = DatasetAttackConfiguration(seeds=make_objectives("a"), validators=[forbid_inline_seeds()]) + with pytest.raises(DatasetConstraintError, match="inline"): + await config.get_seed_attack_groups_async() + + +class TestResolvedDatasetNames: + """``ResolvedDataset.dataset_names`` carries the contributing dataset names to validators.""" + + async def test_resolution_exposes_contributing_names(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.side_effect = [make_objectives("a"), make_objectives("b")] + seen: list[ResolvedDataset] = [] + config = DatasetAttackConfiguration(dataset_names=["d1", "d2"], validators=[seen.append]) + await config.get_seed_attack_groups_async() + assert seen[0].dataset_names == ("d1", "d2") + + async def test_inline_reports_no_dataset_names(self) -> None: + seen: list[ResolvedDataset] = [] + config = DatasetAttackConfiguration(seeds=make_objectives("a"), validators=[seen.append]) + await config.get_seed_attack_groups_async() + assert seen[0].dataset_names == () + + async def test_attack_groups_by_dataset_exposes_contributing_names(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.side_effect = [make_objectives("a"), make_objectives("b")] + seen: list[ResolvedDataset] = [] + config = DatasetAttackConfiguration(dataset_names=["d1", "d2"], validators=[seen.append]) + await config.get_attack_groups_by_dataset_async() + assert seen[0].dataset_names == ("d1", "d2") + + async def test_restrict_dataset_names_raises_for_rogue_dataset(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.return_value = make_objectives("a") + config = DatasetAttackConfiguration(dataset_names=["rogue"], validators=[restrict_dataset_names({"d1", "d2"})]) + with pytest.raises(DatasetConstraintError, match="not allowed"): + await config.get_seed_attack_groups_async() + + async def test_restrict_dataset_names_passes_for_allowed_dataset(self, mock_memory: MagicMock) -> None: + mock_memory.get_seeds.return_value = make_objectives("a") + config = DatasetAttackConfiguration(dataset_names=["d1"], validators=[restrict_dataset_names({"d1", "d2"})]) + groups = await config.get_seed_attack_groups_async() + assert [g.objective.value for g in groups] == ["a"] diff --git a/tests/unit/scenario/core/test_scenario.py b/tests/unit/scenario/core/test_scenario.py index c5e886944f..0ce39bdea2 100644 --- a/tests/unit/scenario/core/test_scenario.py +++ b/tests/unit/scenario/core/test_scenario.py @@ -17,7 +17,7 @@ from pyrit.executor.attack.core import AttackExecutorResult from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult, ComponentIdentifier -from pyrit.scenario import DatasetConfiguration, ScenarioIdentifier, ScenarioResult +from pyrit.scenario import DatasetAttackConfiguration, DatasetConfiguration, ScenarioIdentifier, ScenarioResult from pyrit.scenario.core import AtomicAttack, BaselineAttackPolicy, Scenario, ScenarioStrategy from pyrit.score import Scorer @@ -1019,10 +1019,10 @@ class TestBaselineEmissionDeprecationRescue: @staticmethod def _dataset_config(): - from pyrit.models import SeedGroup, SeedObjective + from pyrit.models import SeedAttackGroup, SeedObjective - return DatasetConfiguration( - seed_groups=[SeedGroup(seeds=[SeedObjective(value="x")])], + return DatasetAttackConfiguration( + seed_groups=[SeedAttackGroup(seeds=[SeedObjective(value="x")])], ) async def test_rescue_emits_warning_and_injects_baseline(self, mock_objective_target): diff --git a/tests/unit/scenario/foundry/test_red_team_agent.py b/tests/unit/scenario/foundry/test_red_team_agent.py index e1c939bc5f..1799f6946b 100644 --- a/tests/unit/scenario/foundry/test_red_team_agent.py +++ b/tests/unit/scenario/foundry/test_red_team_agent.py @@ -3,7 +3,7 @@ """Tests for the RedTeamAgent class.""" -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -13,7 +13,7 @@ from pyrit.models import ComponentIdentifier, SeedAttackGroup, SeedObjective from pyrit.prompt_converter import Base64Converter from pyrit.prompt_target import PromptTarget -from pyrit.scenario import AtomicAttack, DatasetConfiguration, ScenarioCompositeStrategy +from pyrit.scenario import AtomicAttack, DatasetAttackConfiguration, ScenarioCompositeStrategy from pyrit.scenario.foundry import FoundryComposite, FoundryStrategy, RedTeamAgent # type: ignore[ty:unresolved-import] from pyrit.score import FloatScaleThresholdScorer, TrueFalseScorer @@ -49,10 +49,9 @@ def mock_memory_seed_groups(): @pytest.fixture def mock_dataset_config(mock_memory_seed_groups): """Create a mock dataset config that returns the seed groups.""" - mock_config = MagicMock(spec=DatasetConfiguration) - mock_config.get_all_seed_attack_groups.return_value = mock_memory_seed_groups - mock_config.get_default_dataset_names.return_value = ["foundry_red_team"] - mock_config.has_data_source.return_value = True + mock_config = MagicMock(spec=DatasetAttackConfiguration) + mock_config.get_seed_attack_groups_async = AsyncMock(return_value=mock_memory_seed_groups) + mock_config.dataset_names = ["foundry_red_team"] return mock_config @@ -116,7 +115,9 @@ async def test_init_with_single_strategy( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, mock_dataset_config ): """Test initialization with a single attack strategy.""" - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -139,7 +140,9 @@ async def test_init_with_multiple_strategies( FoundryStrategy.Leetspeak, ] - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -176,7 +179,9 @@ async def test_init_with_memory_labels( """Test initialization with memory labels.""" memory_labels = {"test": "foundry", "category": "attack"} - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -199,7 +204,9 @@ def test_init_creates_default_scorer_when_not_provided( mock_scorer_instance = MagicMock(spec=TrueFalseScorer) mock_get_scorer.return_value = mock_scorer_instance - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent() # Verify default scorer was used @@ -214,9 +221,14 @@ async def test_init_raises_exception_when_no_datasets_available(self, mock_objec # Don't mock _resolve_seed_groups, let it try to load from empty memory scenario = RedTeamAgent(attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer)) - # Error should occur during initialize_async when _get_atomic_attacks_async resolves seed groups - with pytest.raises(ValueError, match="DatasetConfiguration has no seed_groups"): - await scenario.initialize_async(objective_target=mock_objective_target) + # Error should occur during initialize_async when _get_atomic_attacks_async resolves seed groups. + # Neutralize the provider fetch so the empty-memory path raises loudly instead of fetching. + with patch( + "pyrit.scenario.core.dataset_configuration.DatasetConfiguration._fetch_dataset_async", + new_callable=AsyncMock, + ): + with pytest.raises(ValueError, match="could not be loaded"): + await scenario.initialize_async(objective_target=mock_objective_target) @pytest.mark.usefixtures(*FIXTURES) @@ -227,7 +239,9 @@ async def test_normalize_easy_strategies( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, mock_dataset_config ): """Test that EASY strategy expands to easy attack strategies.""" - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -244,7 +258,9 @@ async def test_normalize_moderate_strategies( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, mock_dataset_config ): """Test that MODERATE strategy expands to moderate attack strategies.""" - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -261,7 +277,9 @@ async def test_normalize_difficult_strategies( self, mock_objective_target, mock_float_threshold_scorer, mock_memory_seed_groups, mock_dataset_config ): """Test that DIFFICULT strategy expands to difficult attack strategies.""" - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): # DIFFICULT strategy includes TAP which requires FloatScaleThresholdScorer scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_float_threshold_scorer), @@ -279,7 +297,9 @@ async def test_normalize_mixed_difficulty_levels( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, mock_dataset_config ): """Test that multiple difficulty levels expand correctly.""" - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -296,7 +316,9 @@ async def test_normalize_with_specific_and_difficulty_levels( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, mock_dataset_config ): """Test that specific strategies combined with difficulty levels work correctly.""" - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -321,7 +343,9 @@ async def test_get_attack_from_single_turn_strategy( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, mock_dataset_config ): """Test creating an attack from a single-turn strategy.""" - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -348,7 +372,9 @@ async def test_get_attack_from_multi_turn_strategy( mock_dataset_config, ): """Test creating a multi-turn attack strategy.""" - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( adversarial_chat=mock_adversarial_target, attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), @@ -376,7 +402,9 @@ async def test_get_attack_single_turn_with_converters( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, mock_dataset_config ): """Test creating a single-turn attack with converters.""" - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -403,7 +431,9 @@ async def test_get_attack_multi_turn_with_adversarial_target( mock_dataset_config, ): """Test creating a multi-turn attack.""" - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( adversarial_chat=mock_adversarial_target, attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), @@ -457,7 +487,9 @@ async def test_all_single_turn_strategies_create_attack_runs( self, mock_objective_target, mock_objective_scorer, mock_memory_seed_groups, mock_dataset_config, strategy ): """Test that all single-turn strategies can create attack runs.""" - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -490,7 +522,9 @@ async def test_all_multi_turn_strategies_create_attack_runs( strategy, ): """Test that all multi-turn strategies can create attack runs.""" - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( adversarial_chat=mock_adversarial_target, attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), @@ -518,7 +552,9 @@ async def test_scenario_composites_set_after_initialize( """Test that scenario composites are set after initialize_async.""" strategies = [FoundryStrategy.Base64, FoundryStrategy.ROT13] - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -555,7 +591,9 @@ async def test_scenario_atomic_attack_count_matches_strategies( FoundryStrategy.Leetspeak, ] - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -574,7 +612,9 @@ async def test_initialize_with_foundry_composite_directly( """FoundryComposite objects passed to initialize_async are used as-is.""" composite = FoundryComposite(attack=FoundryStrategy.Crescendo, converters=[FoundryStrategy.Base64]) - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -597,7 +637,9 @@ async def test_initialize_with_mixed_composites_and_strategies( """A mix of bare FoundryStrategy and FoundryComposite can be passed together.""" composite = FoundryComposite(attack=FoundryStrategy.Crescendo, converters=[FoundryStrategy.Base64]) - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -620,7 +662,9 @@ async def test_initialize_converts_scenario_composite_strategy_to_foundry_compos """ScenarioCompositeStrategy passed to initialize_async is converted to FoundryComposite.""" legacy = ScenarioCompositeStrategy(strategies=[FoundryStrategy.Crescendo, FoundryStrategy.Base64]) - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -643,7 +687,9 @@ async def test_initialize_converts_converter_first_composite_strategy( """Converter-first ScenarioCompositeStrategy is routed by tags, not position.""" legacy = ScenarioCompositeStrategy(strategies=[FoundryStrategy.Base64, FoundryStrategy.Crescendo]) - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -665,7 +711,9 @@ async def test_initialize_converts_converter_only_composite_strategy( """Converter-only ScenarioCompositeStrategy maps to attack=None.""" legacy = ScenarioCompositeStrategy(strategies=[FoundryStrategy.Base64, FoundryStrategy.ROT13]) - with patch.object(RedTeamAgent, "_resolve_seed_groups", return_value=mock_memory_seed_groups): + with patch.object( + RedTeamAgent, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_memory_seed_groups + ): scenario = RedTeamAgent( attack_scoring_config=AttackScoringConfig(objective_scorer=mock_objective_scorer), ) @@ -686,10 +734,10 @@ class TestRedTeamAgentBaselineUniformity: """ADO 9012 regression: baseline shares objectives with strategies under max_dataset_size.""" async def test_one_resolution_call_baseline_matches_strategies(self, mock_objective_target, mock_objective_scorer): - from pyrit.models import SeedGroup, SeedObjective + from pyrit.models import SeedAttackGroup, SeedObjective - seed_groups = [SeedGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(10)] - config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3) + seed_groups = [SeedAttackGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(10)] + config = DatasetAttackConfiguration(seed_groups=seed_groups, max_dataset_size=3) first_sample = seed_groups[:3] second_sample = seed_groups[5:8] diff --git a/tests/unit/scenario/garak/test_encoding.py b/tests/unit/scenario/garak/test_encoding.py index 64b6622ae8..9aee5d2216 100644 --- a/tests/unit/scenario/garak/test_encoding.py +++ b/tests/unit/scenario/garak/test_encoding.py @@ -3,7 +3,7 @@ """Tests for the Encoding class.""" -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest @@ -11,7 +11,7 @@ from pyrit.models import ComponentIdentifier, SeedAttackGroup, SeedObjective, SeedPrompt from pyrit.prompt_converter import Base64Converter from pyrit.prompt_target import PromptTarget -from pyrit.scenario import DatasetConfiguration +from pyrit.scenario import DatasetAttackConfiguration, DatasetConfiguration from pyrit.scenario.garak import Encoding, EncodingStrategy # type: ignore[ty:unresolved-import] from pyrit.scenario.scenarios.garak.encoding import EncodingDatasetConfiguration from pyrit.score import DecodingScorer, TrueFalseScorer @@ -62,9 +62,8 @@ def mock_seed_attack_groups(mock_memory_seeds): def mock_dataset_config(mock_seed_attack_groups): """Create a mock dataset config that returns the seed attack groups.""" mock_config = MagicMock(spec=EncodingDatasetConfiguration) - mock_config.get_all_seed_attack_groups.return_value = mock_seed_attack_groups - mock_config.get_default_dataset_names.return_value = ["garak_slur_terms_en", "garak_web_html_js"] - mock_config.has_data_source.return_value = True + mock_config.get_seed_attack_groups_async = AsyncMock(return_value=mock_seed_attack_groups) + mock_config.dataset_names = ["garak_slur_terms_en", "garak_web_html_js"] return mock_config @@ -98,7 +97,7 @@ def test_init_with_default_seed_prompts(self, mock_objective_target, mock_object """Test initialization with default seed prompts (Garak dataset).""" from unittest.mock import patch - with patch.object(Encoding, "_resolve_seed_groups", return_value=[]): + with patch.object(Encoding, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=[]): scenario = Encoding( objective_scorer=mock_objective_scorer, ) @@ -110,7 +109,7 @@ def test_init_with_custom_scorer(self, mock_objective_target, mock_objective_sco """Test initialization with custom objective scorer.""" from unittest.mock import patch - with patch.object(Encoding, "_resolve_seed_groups", return_value=[]): + with patch.object(Encoding, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=[]): scenario = Encoding( objective_scorer=mock_objective_scorer, ) @@ -121,7 +120,7 @@ def test_init_creates_default_scorer_when_not_provided(self, mock_objective_targ """Test that initialization creates default DecodingScorer when not provided.""" from unittest.mock import patch - with patch.object(Encoding, "_resolve_seed_groups", return_value=[]): + with patch.object(Encoding, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=[]): scenario = Encoding() # Should create a DecodingScorer by default @@ -129,20 +128,23 @@ def test_init_creates_default_scorer_when_not_provided(self, mock_objective_targ assert isinstance(scenario._scorer_config.objective_scorer, DecodingScorer) async def test_init_raises_exception_when_no_datasets_available(self, mock_objective_target, mock_objective_scorer): - """Test that initialization raises ValueError when datasets are not available in memory.""" + """Test that initialization raises ValueError when datasets are not available and auto-fetch finds nothing.""" + from unittest.mock import patch - # Don't mock _resolve_seed_groups, let it try to load from empty memory + # Don't mock _resolve_seed_groups_async; let it try to load from empty memory. + # Disable the provider fallback so memory stays empty and the scenario raises. scenario = Encoding(objective_scorer=mock_objective_scorer) - # Error should occur during initialize_async when _get_atomic_attacks_async resolves seed prompts - with pytest.raises(ValueError, match="No seeds found in the configured datasets"): - await scenario.initialize_async(objective_target=mock_objective_target) + with patch.object(EncodingDatasetConfiguration, "_fetch_dataset_async", new_callable=AsyncMock): + # Error should occur during initialize_async when _get_atomic_attacks_async resolves seed prompts + with pytest.raises(ValueError, match="Dataset is not available or failed to load"): + await scenario.initialize_async(objective_target=mock_objective_target) def test_init_with_memory_labels(self, mock_objective_target, mock_objective_scorer, mock_memory_seeds): """Test initialization with memory labels.""" from unittest.mock import patch - with patch.object(Encoding, "_resolve_seed_groups", return_value=[]): + with patch.object(Encoding, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=[]): scenario = Encoding( objective_scorer=mock_objective_scorer, ) @@ -156,7 +158,7 @@ def test_init_with_custom_encoding_templates(self, mock_objective_target, mock_o custom_templates = ["template1", "template2"] - with patch.object(Encoding, "_resolve_seed_groups", return_value=[]): + with patch.object(Encoding, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=[]): scenario = Encoding( encoding_templates=custom_templates, objective_scorer=mock_objective_scorer, @@ -168,7 +170,7 @@ def test_init_with_max_concurrency(self, mock_objective_target, mock_objective_s """Test initialization with custom max_concurrency.""" from unittest.mock import patch - with patch.object(Encoding, "_resolve_seed_groups", return_value=[]): + with patch.object(Encoding, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=[]): scenario = Encoding( objective_scorer=mock_objective_scorer, ) @@ -182,7 +184,9 @@ async def test_init_attack_strategies( """Test that attack strategies are set correctly.""" from unittest.mock import patch - with patch.object(Encoding, "_resolve_seed_groups", return_value=mock_seed_attack_groups): + with patch.object( + Encoding, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_seed_attack_groups + ): scenario = Encoding( objective_scorer=mock_objective_scorer, ) @@ -207,7 +211,9 @@ async def test_get_atomic_attacks_async_returns_attacks( """Test that _get_atomic_attacks_async returns atomic attacks.""" from unittest.mock import patch - with patch.object(Encoding, "_resolve_seed_groups", return_value=mock_seed_attack_groups): + with patch.object( + Encoding, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_seed_attack_groups + ): scenario = Encoding( objective_scorer=mock_objective_scorer, ) @@ -225,7 +231,9 @@ async def test_get_converter_attacks_returns_multiple_encodings( """Test that _get_converter_attacks returns attacks for multiple encoding types.""" from unittest.mock import patch - with patch.object(Encoding, "_resolve_seed_groups", return_value=mock_seed_attack_groups): + with patch.object( + Encoding, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_seed_attack_groups + ): scenario = Encoding( objective_scorer=mock_objective_scorer, ) @@ -244,7 +252,9 @@ async def test_get_prompt_attacks_creates_attack_runs( """Test that _get_prompt_attacks creates attack runs with correct structure.""" from unittest.mock import patch - with patch.object(Encoding, "_resolve_seed_groups", return_value=mock_seed_attack_groups): + with patch.object( + Encoding, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_seed_attack_groups + ): scenario = Encoding( objective_scorer=mock_objective_scorer, ) @@ -271,7 +281,9 @@ async def test_attack_runs_include_objectives( """Test that attack runs include objectives for each seed prompt.""" from unittest.mock import patch - with patch.object(Encoding, "_resolve_seed_groups", return_value=mock_seed_attack_groups): + with patch.object( + Encoding, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_seed_attack_groups + ): scenario = Encoding( objective_scorer=mock_objective_scorer, ) @@ -300,7 +312,9 @@ async def test_scenario_initialization( """Test that scenario can be initialized successfully.""" from unittest.mock import patch - with patch.object(Encoding, "_resolve_seed_groups", return_value=mock_seed_attack_groups): + with patch.object( + Encoding, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_seed_attack_groups + ): scenario = Encoding( objective_scorer=mock_objective_scorer, ) @@ -313,16 +327,18 @@ async def test_scenario_initialization( async def test_resolve_seed_groups_loads_garak_data( self, mock_objective_target, mock_objective_scorer, mock_seed_attack_groups, mock_dataset_config ): - """Test that _resolve_seed_groups loads data from Garak datasets.""" + """Test that _resolve_seed_groups_async loads data from Garak datasets.""" from unittest.mock import patch - with patch.object(Encoding, "_resolve_seed_groups", return_value=mock_seed_attack_groups): + with patch.object( + Encoding, "_resolve_seed_groups_async", new_callable=AsyncMock, return_value=mock_seed_attack_groups + ): scenario = Encoding( objective_scorer=mock_objective_scorer, ) # After resolve, should have seed groups - resolved = scenario._resolve_seed_groups() + resolved = await scenario._resolve_seed_groups_async() assert len(resolved) > 0 # Verify it's returning SeedAttackGroup objects @@ -341,7 +357,7 @@ def test_default_dataset_config_returns_encoding_config(self, mock_objective_sco def test_default_dataset_config_uses_garak_datasets(self, mock_objective_scorer): """Test that the default config uses the expected garak datasets.""" config = Encoding(objective_scorer=mock_objective_scorer)._default_dataset_config - dataset_names = config.get_default_dataset_names() + dataset_names = config.dataset_names assert "garak_slur_terms_en" in dataset_names assert "garak_web_html_js" in dataset_names @@ -352,38 +368,54 @@ def test_default_dataset_config_has_max_size(self, mock_objective_scorer): @pytest.mark.usefixtures("patch_central_database") -@pytest.mark.usefixtures("patch_central_database") -class TestEncodingDatasetConfigurationGetAllSeedAttackGroups: - """Tests for EncodingDatasetConfiguration.get_all_seed_attack_groups method.""" - - def test_get_all_seed_attack_groups_transforms_seeds(self, mock_memory_seeds): - """Test that get_all_seed_attack_groups transforms seeds into SeedAttackGroups.""" +class TestEncodingDatasetConfigurationBuildAttackGroups: + """Tests for EncodingDatasetConfiguration._build_attack_groups and resolution.""" + + def test_build_attack_groups_transforms_seeds(self, mock_memory_seeds): + """Test that _build_attack_groups transforms raw seeds into objective-bearing SeedAttackGroups.""" + config = EncodingDatasetConfiguration(dataset_names=["garak_slur_terms_en"]) + result = config._build_attack_groups(mock_memory_seeds) + + assert len(result) == len(mock_memory_seeds) + for i, group in enumerate(result): + assert isinstance(group, SeedAttackGroup) + # First seed should be a SeedObjective with the encoding objective format + assert isinstance(group.seeds[0], SeedObjective) + assert "Make the model say an encoded payload:" in group.seeds[0].value + assert mock_memory_seeds[i].value in group.seeds[0].value + # Second seed should be the original SeedPrompt + assert isinstance(group.seeds[1], SeedPrompt) + assert group.seeds[1].value == mock_memory_seeds[i].value + + def test_build_attack_groups_empty_returns_empty(self): + """Test that _build_attack_groups returns an empty list when given no seeds.""" + config = EncodingDatasetConfiguration(dataset_names=["empty_dataset"]) + assert config._build_attack_groups([]) == [] + + async def test_get_seed_attack_groups_async_transforms_memory_seeds(self, mock_memory_seeds): + """Test that get_seed_attack_groups_async loads seeds and shapes them via _build_attack_groups.""" from unittest.mock import patch - with patch.object(EncodingDatasetConfiguration, "get_all_seeds", return_value=mock_memory_seeds): - config = EncodingDatasetConfiguration(dataset_names=["garak_slur_terms_en"]) - result = config.get_all_seed_attack_groups() - - assert len(result) == len(mock_memory_seeds) - for i, group in enumerate(result): - assert isinstance(group, SeedAttackGroup) - # First seed should be a SeedObjective with the encoding objective format - assert isinstance(group.seeds[0], SeedObjective) - assert "Make the model say an encoded payload:" in group.seeds[0].value - assert mock_memory_seeds[i].value in group.seeds[0].value - # Second seed should be the original SeedPrompt - assert isinstance(group.seeds[1], SeedPrompt) - assert group.seeds[1].value == mock_memory_seeds[i].value - - def test_get_all_seed_attack_groups_raises_when_no_seeds(self): - """Test that get_all_seed_attack_groups raises ValueError when no seeds found.""" - from unittest.mock import patch + config = EncodingDatasetConfiguration(dataset_names=["garak_slur_terms_en"], auto_fetch=False) + with patch.object( + EncodingDatasetConfiguration, + "_collect_named_seeds_async", + new_callable=AsyncMock, + return_value={"garak_slur_terms_en": mock_memory_seeds}, + ): + result = await config.get_seed_attack_groups_async() + + assert len(result) == len(mock_memory_seeds) + assert all(isinstance(group, SeedAttackGroup) for group in result) + + async def test_get_seed_attack_groups_async_raises_when_empty(self): + """Test that get_seed_attack_groups_async raises DatasetConstraintError when nothing resolves.""" + from pyrit.scenario.core.dataset_configuration import DatasetConstraintError - with patch.object(EncodingDatasetConfiguration, "get_all_seeds", return_value=[]): - config = EncodingDatasetConfiguration(dataset_names=["empty_dataset"]) + config = EncodingDatasetConfiguration(dataset_names=["empty_dataset"], auto_fetch=False) - with pytest.raises(ValueError, match="No seeds found in the configured datasets"): - config.get_all_seed_attack_groups() + with pytest.raises(DatasetConstraintError): + await config.get_seed_attack_groups_async() def test_encoding_dataset_config_inherits_from_dataset_config(self): """Test that EncodingDatasetConfiguration is a subclass of DatasetConfiguration.""" @@ -407,10 +439,10 @@ class TestEncodingBaselineUniformity: async def test_one_resolution_call_baseline_matches_strategies(self, mock_objective_target, mock_objective_scorer): from unittest.mock import patch - from pyrit.models import SeedGroup, SeedObjective + from pyrit.models import SeedAttackGroup, SeedObjective - seed_groups = [SeedGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(10)] - config = DatasetConfiguration(seed_groups=seed_groups, max_dataset_size=3) + seed_groups = [SeedAttackGroup(seeds=[SeedObjective(value=f"obj{i}")]) for i in range(10)] + config = DatasetAttackConfiguration(seed_groups=seed_groups, max_dataset_size=3) first_sample = seed_groups[:3] second_sample = seed_groups[5:8] diff --git a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py index 848851ae0a..2bada6036a 100644 --- a/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py +++ b/tests/unit/scenario/scenarios/adaptive/test_text_adaptive.py @@ -7,7 +7,7 @@ import uuid import warnings -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -15,7 +15,10 @@ from pyrit.models.identifiers import ComponentIdentifier from pyrit.prompt_target import PromptTarget from pyrit.registry.object_registries.attack_technique_registry import AttackTechniqueRegistry -from pyrit.scenario.core.dataset_configuration import DatasetConfiguration +from pyrit.scenario.core.dataset_configuration import ( + DatasetAttackConfiguration, + DatasetConfiguration, +) from pyrit.scenario.core.scenario import BaselineAttackPolicy from pyrit.scenario.scenarios.adaptive.dispatcher import ( AdaptiveTechniqueDispatcher, @@ -164,7 +167,12 @@ async def _build_scenario_and_attacks( seed_groups: dict[str, list[SeedAttackGroup]], **scenario_kwargs, ): - with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=seed_groups): + with patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=seed_groups, + ): scenario = TextAdaptive( objective_scorer=mock_objective_scorer, **scenario_kwargs, @@ -211,7 +219,12 @@ def _spy_init(self, *args, **kwargs): "violence": [_make_seed_group(value="obj-v1", harm_categories=["violence"])], "hate": [_make_seed_group(value="obj-h1", harm_categories=["hate"])], } - with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups): + with patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ): scenario = TextAdaptive(objective_scorer=mock_objective_scorer) await scenario.initialize_async( objective_target=mock_objective_target, @@ -257,7 +270,12 @@ async def test_display_group_is_dataset_name(self, mock_objective_target, mock_o async def test_no_usable_techniques_raises(self, mock_objective_target, mock_objective_scorer): groups = {"violence": [_make_seed_group(value="obj")]} - with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups): + with patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ): scenario = TextAdaptive(objective_scorer=mock_objective_scorer) await scenario.initialize_async( objective_target=mock_objective_target, @@ -277,7 +295,12 @@ async def test_techniques_with_seed_technique_are_kept(self, mock_objective_targ seeded_factory = _make_fake_factory(seed_technique=MagicMock(name="seed_technique")) with ( - patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), + patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ), patch.object(SeedAttackGroup, "is_compatible_with_technique", return_value=True), ): scenario = TextAdaptive(objective_scorer=mock_objective_scorer) @@ -313,7 +336,12 @@ async def test_incompatible_seed_technique_is_filtered_per_objective( # Only the plain factory (no seed_technique) is compatible. with ( - patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), + patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ), patch.object(SeedAttackGroup, "is_compatible_with_technique", return_value=False), ): scenario = TextAdaptive(objective_scorer=mock_objective_scorer) @@ -356,7 +384,12 @@ def _selective_compat(self_group, *, technique): return self_group.objective.value == "obj-keep" with ( - patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), + patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ), patch.object(SeedAttackGroup, "is_compatible_with_technique", _selective_compat), ): scenario = TextAdaptive(objective_scorer=mock_objective_scorer) @@ -394,7 +427,12 @@ class NarrowScoringConfig(AttackScoringConfig): groups = {"violence": [_make_seed_group(value="obj")]} narrow_factory = _make_fake_factory(scoring_config_type=NarrowScoringConfig) with ( - patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), + patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ), patch.object(SeedAttackGroup, "is_compatible_with_technique", return_value=True), ): scenario = TextAdaptive(objective_scorer=mock_objective_scorer) @@ -436,7 +474,12 @@ def __init__(self, *, objective_scorer): strict_factory = _make_fake_factory(scoring_config_type=StrictScoringConfig) with ( - patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), + patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ), patch.object(SeedAttackGroup, "is_compatible_with_technique", return_value=True), ): scenario = TextAdaptive(objective_scorer=mock_objective_scorer) @@ -474,7 +517,12 @@ async def test_factory_create_failure_skips_technique(self, mock_objective_targe bad_factory.create.side_effect = ValueError("requires FloatScaleThresholdScorer") with ( - patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), + patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ), patch.object(SeedAttackGroup, "is_compatible_with_technique", return_value=True), ): scenario = TextAdaptive(objective_scorer=mock_objective_scorer) @@ -503,7 +551,12 @@ async def test_all_factories_failing_raises_with_reason(self, mock_objective_tar bad_factory.create.side_effect = ValueError("requires FloatScaleThresholdScorer") with ( - patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups), + patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ), patch.object(SeedAttackGroup, "is_compatible_with_technique", return_value=True), ): scenario = TextAdaptive(objective_scorer=mock_objective_scorer) @@ -525,7 +578,12 @@ async def test_all_factories_failing_raises_with_reason(self, mock_objective_tar class TestTextAdaptiveBaselinePolicy: async def test_initialize_async_accepts_explicit_baseline(self, mock_objective_target, mock_objective_scorer): groups = {"violence": [_make_seed_group(value="obj", harm_categories=["violence"])]} - with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups): + with patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ): scenario = TextAdaptive(objective_scorer=mock_objective_scorer) # Baseline is Enabled by default, so explicit include_baseline=True must not raise. await scenario.initialize_async( @@ -541,7 +599,12 @@ async def test_baseline_emitted_at_index_zero_by_default(self, mock_objective_ta for removal in 0.16.0) is bypassed and no DeprecationWarning fires. """ groups = {"violence": [_make_seed_group(value="obj", harm_categories=["violence"])]} - with patch.object(DatasetConfiguration, "get_seed_attack_groups", return_value=groups): + with patch.object( + DatasetAttackConfiguration, + "get_attack_groups_by_dataset_async", + new_callable=AsyncMock, + return_value=groups, + ): scenario = TextAdaptive(objective_scorer=mock_objective_scorer) with warnings.catch_warnings(): warnings.simplefilter("error", DeprecationWarning)