diff --git a/isaaclab_arena/environments/relation_solver_interface.py b/isaaclab_arena/environments/relation_solver_interface.py index aad1019d3..298df3633 100644 --- a/isaaclab_arena/environments/relation_solver_interface.py +++ b/isaaclab_arena/environments/relation_solver_interface.py @@ -68,8 +68,8 @@ def solve_and_apply_relation_placement( if placement_pool.had_fallbacks: print( - "Warning: Relation placement pool accepted best-loss fallback layouts " - "that failed strict placement validation." + "Warning: Relation placement pool served best-loss fallback layouts " + "that did not meet the pool's acceptance criteria." ) return _apply_relation_placement_result( diff --git a/isaaclab_arena/relations/layout_pool_serialization.py b/isaaclab_arena/relations/layout_pool_serialization.py new file mode 100644 index 000000000..548c0ef37 --- /dev/null +++ b/isaaclab_arena/relations/layout_pool_serialization.py @@ -0,0 +1,233 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""JSON serialization for PooledObjectPlacer layout pools. + +Owns the on-disk schema and its validation so the placer keeps only pool-state orchestration. +The pool is regenerable by re-solving, so a stale/incompatible file is meant to be re-saved +rather than migrated; load fails loudly on any structural problem instead of placing wrong poses. + +On-disk schema (PoolDocument.to_dict / from_dict): + { + "placement_seed": int | null, # restored on load so sampling reproduces the saved run + "num_envs": int, + "uses_env_specific_bboxes": bool, + "had_fallbacks": bool, # whether any stored layout was a best-loss fallback + "env_pools": [ # one list per env, outer length == num_envs + [ { # one entry per stored layout (serialize_layout) + "positions": {obj_name: [x, y, z]}, + "orientations": {obj_name: yaw}, + "validation": {check_name: bool}, + "final_loss": float, + "attempts": int, + }, ... ], + ], + } +""" + +from __future__ import annotations + +import json +import math +import os +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +from isaaclab_arena.relations.placement_result import PlacementResult, ValidationReport + +if TYPE_CHECKING: + from isaaclab_arena.assets.object_base import ObjectBase + +_POOL_REQUIRED_KEYS = ("placement_seed", "num_envs", "uses_env_specific_bboxes", "had_fallbacks", "env_pools") +_LAYOUT_REQUIRED_KEYS = ("positions", "orientations", "validation", "final_loss", "attempts") + + +@dataclass(frozen=True) +class PoolDocument: + """Pool-wide metadata plus raw per-env layout dicts. + + env_pools holds serialized per-layout dicts (see serialize_layout); the caller materializes + them with deserialize_layout once it knows the live objects. + """ + + placement_seed: int | None + num_envs: int + uses_env_specific_bboxes: bool + had_fallbacks: bool + env_pools: list[list[dict]] + + def __post_init__(self) -> None: + # Self-validate at every construction path (incl. direct construction in save), so a + # PoolDocument can't exist in a state from_dict would reject. from_dict re-checks the same + # invariants earlier with the file path for better load-time messages. + assert isinstance(self.num_envs, int) and not isinstance(self.num_envs, bool), "num_envs must be an int." + assert self.placement_seed is None or ( + isinstance(self.placement_seed, int) and not isinstance(self.placement_seed, bool) + ), "placement_seed must be int or None." + assert isinstance(self.uses_env_specific_bboxes, bool), "uses_env_specific_bboxes must be a bool." + assert isinstance(self.had_fallbacks, bool), "had_fallbacks must be a bool." + assert self.num_envs == len( + self.env_pools + ), f"PoolDocument num_envs ({self.num_envs}) must match env_pools length ({len(self.env_pools)})." + + def to_dict(self) -> dict: + return { + "placement_seed": self.placement_seed, + "num_envs": self.num_envs, + "uses_env_specific_bboxes": self.uses_env_specific_bboxes, + "had_fallbacks": self.had_fallbacks, + "env_pools": self.env_pools, + } + + @classmethod + def from_dict(cls, data: object, path: Path) -> PoolDocument: + """Structurally validate a parsed document, naming the path on any problem.""" + assert isinstance(data, dict), f"Layout pool file is not a JSON object: {path}" + for key in _POOL_REQUIRED_KEYS: + assert key in data, f"Layout pool file is missing required key '{key}': {path}" + + env_pools = data["env_pools"] + num_envs = data["num_envs"] + seed = data["placement_seed"] + assert isinstance(env_pools, list), f"Layout pool 'env_pools' must be a list: {path}" + assert isinstance(num_envs, int) and not isinstance( + num_envs, bool + ), f"Layout pool 'num_envs' must be an int: {path}" + assert seed is None or ( + isinstance(seed, int) and not isinstance(seed, bool) + ), f"Layout pool 'placement_seed' must be int or null: {path}" + assert isinstance( + data["uses_env_specific_bboxes"], bool + ), f"Layout pool 'uses_env_specific_bboxes' must be a bool: {path}" + assert isinstance(data["had_fallbacks"], bool), f"Layout pool 'had_fallbacks' must be a bool: {path}" + assert num_envs == len(env_pools), f"Corrupt layout pool: num_envs does not match env_pools length: {path}" + for cur_env, env_layouts in enumerate(env_pools): + assert isinstance(env_layouts, list), f"Layout pool env {cur_env} must be a list: {path}" + for entry in env_layouts: + assert isinstance(entry, dict), f"Layout pool env {cur_env} has a non-dict layout entry: {path}" + return cls( + placement_seed=seed, + num_envs=num_envs, + uses_env_specific_bboxes=data["uses_env_specific_bboxes"], + had_fallbacks=data["had_fallbacks"], + env_pools=env_pools, + ) + + +def write_pool_document(path: Path, document: PoolDocument) -> None: + """Atomically write a pool document, failing loudly on non-finite values. + + Serializes with allow_nan=False first, so a NaN/inf pose or loss raises before any file is + touched; then writes a temp file and os.replace, so the destination is never half-written and a + mid-write OSError removes the temp. A hard crash between write and replace may leave a stale + .tmp, but the destination stays intact. + """ + payload = json.dumps(document.to_dict(), indent=2, allow_nan=False) + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_name(f"{path.name}.tmp") + try: + tmp_path.write_text(payload, encoding="utf-8") + os.replace(tmp_path, path) + except OSError: + tmp_path.unlink(missing_ok=True) + raise + + +def read_pool_document(path: Path) -> PoolDocument: + """Read and structurally validate a pool document, naming the path on any problem. + + Validates only path-level structure; the caller still checks caller-dependent invariants + (requested num_envs, heterogeneity, objects). + """ + assert path.is_file(), f"Layout pool file not found: {path}" + + def reject_non_finite(token: str): + # json.loads accepts NaN/Infinity by default; reject here to mirror the allow_nan=False write. + raise ValueError(f"Layout pool file contains non-finite JSON constant '{token}': {path}") + + try: + text = path.read_text(encoding="utf-8") + except OSError as exc: + raise ValueError(f"Layout pool file could not be read: {path}") from exc + try: + data = json.loads(text, parse_constant=reject_non_finite) + except json.JSONDecodeError as exc: + raise ValueError(f"Layout pool file is not valid JSON: {path}") from exc + return PoolDocument.from_dict(data, path) + + +def serialize_layout(result: PlacementResult) -> dict: + """Flatten one layout to JSON-safe primitives, keyed by object name.""" + return { + "positions": {obj.name: list(pos) for obj, pos in result.positions.items()}, + "orientations": {obj.name: yaw for obj, yaw in result.orientations.items()}, + "validation": dict(result.validation.checks), + "final_loss": result.final_loss, + "attempts": result.attempts, + } + + +def deserialize_layout(data: dict, name_to_obj: dict[str, ObjectBase]) -> PlacementResult: + """Rebuild a PlacementResult, re-keying by the live objects that match each saved name.""" + for key in _LAYOUT_REQUIRED_KEYS: + assert key in data, f"Serialized layout is missing required key '{key}'." + + positions_data = data["positions"] + orientations_data = data["orientations"] + checks = data["validation"] + final_loss = data["final_loss"] + attempts = data["attempts"] + assert isinstance( + positions_data, dict + ), f"Serialized layout 'positions' must be a dict, got {type(positions_data).__name__}." + assert isinstance( + orientations_data, dict + ), f"Serialized layout 'orientations' must be a dict, got {type(orientations_data).__name__}." + assert isinstance(checks, dict), f"Serialized layout 'validation' must be a dict, got {type(checks).__name__}." + assert bool(checks), "Serialized layout has an empty validation map; it would load as a failing layout." + for name, ok in checks.items(): + assert isinstance(ok, bool), f"Validation check '{name}' must be a JSON bool, got {type(ok).__name__}." + assert ( + isinstance(final_loss, (int, float)) and not isinstance(final_loss, bool) and math.isfinite(final_loss) + ), f"Serialized 'final_loss' must be a finite number, got {final_loss!r}." + assert isinstance(attempts, int) and not isinstance( + attempts, bool + ), f"Serialized 'attempts' must be an int, got {attempts!r}." + + # Every live object must have a saved pose, and vice versa, so a stale file can't silently + # leave an object at its origin (or reference one no longer in the scene). + assert set(positions_data) == set(name_to_obj), ( + f"Saved layout objects {sorted(positions_data)} do not match the provided objects " + f"{sorted(name_to_obj)}; re-solve instead of loading this cache." + ) + assert set(orientations_data) <= set( + positions_data + ), f"Saved orientations {sorted(orientations_data)} are not a subset of positions {sorted(positions_data)}." + + def parse_position(name: str, pos: object) -> tuple[float, float, float]: + assert ( + isinstance(pos, (list, tuple)) and len(pos) == 3 + ), f"Serialized position for '{name}' must be a length-3 sequence, got {pos!r}." + assert all( + isinstance(c, (int, float)) and not isinstance(c, bool) and math.isfinite(c) for c in pos + ), f"Serialized position for '{name}' must be finite numbers, got {pos!r}." + return (float(pos[0]), float(pos[1]), float(pos[2])) + + def parse_yaw(name: str, yaw: object) -> float: + assert ( + isinstance(yaw, (int, float)) and not isinstance(yaw, bool) and math.isfinite(yaw) + ), f"Serialized orientation for '{name}' must be a finite number, got {yaw!r}." + return float(yaw) + + positions = {name_to_obj[name]: parse_position(name, pos) for name, pos in positions_data.items()} + orientations = {name_to_obj[name]: parse_yaw(name, yaw) for name, yaw in orientations_data.items()} + return PlacementResult( + positions=positions, + orientations=orientations, + validation=ValidationReport(checks=dict(checks)), + final_loss=float(final_loss), + attempts=attempts, + ) diff --git a/isaaclab_arena/relations/object_placer.py b/isaaclab_arena/relations/object_placer.py index b9e6c9b2a..1d2d7d756 100644 --- a/isaaclab_arena/relations/object_placer.py +++ b/isaaclab_arena/relations/object_placer.py @@ -11,7 +11,13 @@ from isaaclab_arena.relations.bounding_box_helpers import assign_variants_for_envs, build_per_env_bounding_boxes from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams -from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult, PlacementResult +from isaaclab_arena.relations.placement_result import ( + LayoutFilter, + MultiEnvPlacementResult, + PlacementResult, + ValidationReport, + default_layout_filter, +) from isaaclab_arena.relations.relation_solver import RelationSolver from isaaclab_arena.relations.relations import ( IsAnchor, @@ -38,8 +44,8 @@ class PlacementCandidate: positions: dict[ObjectBase, tuple[float, float, float]] """Solved positions for each object.""" - is_valid: bool - """Whether the placement passed validation checks.""" + validation: ValidationReport + """Per-check validation outcome for this candidate.""" orientations: dict[ObjectBase, float] = field(default_factory=dict) """Per-object yaw (radians about Z) sampled for this candidate. Empty when unrotated.""" @@ -52,7 +58,7 @@ class ObjectPlacer: 1. Random initialization of candidate positions per environment 2. Running the RelationSolver on all candidates in one batch 3. Validating each candidate - 4. Ranking candidates per environment (valid first, then by loss) + 4. Ranking candidates per environment (accepted first, then by loss) 5. Applying the best layout per environment to the objects Supports single-env (num_envs=1) and batched (num_envs>1) placement. @@ -64,9 +70,12 @@ class ObjectPlacer: position may fall outside the actual surface. """ - def __init__(self, params: ObjectPlacerParams | None = None): + def __init__(self, params: ObjectPlacerParams | None = None, layout_filter: LayoutFilter | None = None): self.params = params or ObjectPlacerParams() self._solver = RelationSolver(params=self.params.solver_params) + # Acceptance predicate that ranking sorts by, so place() returns the best accepted layout. + # Defaults to "all checks pass"; a pool injects its own to keep ranking and storage aligned. + self._accepts = layout_filter or default_layout_filter def place( self, @@ -123,7 +132,7 @@ def place_ranked_per_env( candidate layouts. Use place() for selected placement results. The return value has shape (num_envs, results_per_env): each outer list entry corresponds to a real env, and each inner list is - sorted with valid lower-loss layouts first. + sorted with accepted lower-loss layouts first. """ assert results_per_env > 0, f"results_per_env must be positive, got {results_per_env}" anchor_objects_set, generator = self._prepare_placement(objects) @@ -182,7 +191,7 @@ def _place_ranked( """Solve and rank placement candidates per environment. Each env is solved against its own per-env bounding boxes, and its - candidates are ranked independently (valid first, then by loss), so a + candidates are ranked independently (accepted first, then by loss), so a candidate is never compared against another env's geometry. """ # Variant assignment fixes the env-to-USD mapping before bbox expansion. @@ -214,7 +223,7 @@ def _place_ranked( assert self._solver.last_loss_per_env is not None all_losses: list[float] = self._solver.last_loss_per_env.cpu().tolist() all_validations = [ - self._validate_placement( + self._validate_geometry( positions, self._get_bounding_boxes_for_candidate_index(candidate_bboxes, candidate_idx) ) for candidate_idx, positions in enumerate(all_positions) @@ -235,7 +244,7 @@ def _place_ranked( ranked_results = [ [ PlacementResult( - success=candidate.is_valid, + validation=candidate.validation, positions=candidate.positions, final_loss=candidate.loss, attempts=attempts_per_result, @@ -251,19 +260,19 @@ def _place_ranked( return ranked_results - @staticmethod def _rank_candidates( + self, candidates: list[PlacementCandidate], num_envs: int, candidates_per_env: int, ) -> list[list[PlacementCandidate]]: - """Return one loss-sorted candidate slice per env (valid candidates first).""" + """Return one loss-sorted candidate slice per env (accepted candidates first).""" ranked_candidate_slices: list[list[PlacementCandidate]] = [] for cur_env in range(num_envs): start = cur_env * candidates_per_env env_candidates = candidates[start : start + candidates_per_env] ranked_candidate_slices.append( - sorted(env_candidates, key=lambda candidate: (not candidate.is_valid, candidate.loss)) + sorted(env_candidates, key=lambda candidate: (not self._accepts(candidate.validation), candidate.loss)) ) return ranked_candidate_slices @@ -273,8 +282,10 @@ def _print_ranked_summary( num_candidates: int, num_envs: int, ) -> None: - n_valid = sum(1 for candidate_slice in ranked_candidate_slices if candidate_slice[0].is_valid) - print(f"Solved {num_candidates} candidates in one batch: {n_valid}/{num_envs} env(s) valid") + n_accepted = sum( + 1 for candidate_slice in ranked_candidate_slices if self._accepts(candidate_slice[0].validation) + ) + print(f"Solved {num_candidates} candidates in one batch: {n_accepted}/{num_envs} env(s) accepted") def _generate_initial_positions( self, @@ -598,21 +609,29 @@ def _validate_no_overlap( return False return True - def _validate_placement( + def _validate_geometry( self, positions: dict[ObjectBase, tuple[float, float, float]], env_bboxes: dict[ObjectBase, AxisAlignedBoundingBox], - ) -> bool: - """Validate that no two objects overlap in 3D and On relations are satisfied. + ) -> ValidationReport: + """Run the geometry checks and return them as a per-check ValidationReport. + + This is the geometry validation stage. Further validation checks extend the result with + ValidationReport.with_check rather than adding cases here. Args: positions: Dictionary mapping objects to their solved (x, y, z) positions. env_bboxes: Per-object bboxes for the current env, each with shape (1, 3). Returns: - True if no overlaps exist and On relations hold, False otherwise. + A ValidationReport mapping each geometry check name to whether it passed. """ - return self._validate_no_overlap(positions, env_bboxes) and self._validate_on_relations(positions, env_bboxes) + return ValidationReport( + checks={ + "no_overlap": self._validate_no_overlap(positions, env_bboxes), + "on_relations": self._validate_on_relations(positions, env_bboxes), + } + ) def _apply_poses( self, diff --git a/isaaclab_arena/relations/placement_events.py b/isaaclab_arena/relations/placement_events.py index dfd6ebcaf..601fb4831 100644 --- a/isaaclab_arena/relations/placement_events.py +++ b/isaaclab_arena/relations/placement_events.py @@ -67,10 +67,10 @@ def solve_and_place_objects( for cur_env in reset_env_ids: env_id_tensor = torch.tensor([cur_env], device=env.device) result = results_by_env[cur_env] - if not result.success: + if not placement_pool.accepts(result): print( "Warning: Writing best-loss fallback placement for " - f"env {cur_env}; layout failed strict placement validation." + f"env {cur_env}; layout did not meet the placement pool's acceptance criteria." ) for obj, pos in result.positions.items(): if obj in anchor_objects_set: diff --git a/isaaclab_arena/relations/placement_result.py b/isaaclab_arena/relations/placement_result.py index cc4a9e7b3..c34b0d65c 100644 --- a/isaaclab_arena/relations/placement_result.py +++ b/isaaclab_arena/relations/placement_result.py @@ -5,20 +5,73 @@ from __future__ import annotations +from collections.abc import Callable, Mapping from dataclasses import dataclass, field +from types import MappingProxyType from typing import TYPE_CHECKING if TYPE_CHECKING: from isaaclab_arena.assets.object_base import ObjectBase +@dataclass(frozen=True) +class ValidationReport: + """Per-check outcome of placement validation. + + The check set is open: further validation checks add their own named result by deriving a new + report via with_check, so acceptance can consider more than the built-in geometry checks. + """ + + checks: Mapping[str, bool] + """Each check name (e.g. "no_overlap", "on_relations") mapped to its pass/fail result.""" + + def __post_init__(self) -> None: + # Enforce bool here so passed/failed_checks stay sound: with_check (the path engineers use) + # must not store a truthy non-bool (e.g. a tensor) that silently satisfies all(...). + assert all( + isinstance(v, bool) for v in self.checks.values() + ), f"ValidationReport checks must be bools, got {dict(self.checks)}" + # Read-only snapshot: neither the caller's original dict nor report.checks[...] can mutate it. + object.__setattr__(self, "checks", MappingProxyType(dict(self.checks))) + + def __reduce__(self): + # MappingProxyType can't be pickled/deepcopied, and Isaac Lab deepcopies the EventTermCfg + # params that carry this report; rebuild from a plain dict so copy/pickle round-trip. + return (self.__class__, (dict(self.checks),)) + + @property + def passed(self) -> bool: + """True only when at least one check ran and every check passed (empty fails closed).""" + return bool(self.checks) and all(self.checks.values()) + + @property + def failed_checks(self) -> tuple[str, ...]: + """Names of the checks that failed, in insertion order.""" + return tuple(name for name, ok in self.checks.items() if not ok) + + def with_check(self, name: str, passed: bool) -> ValidationReport: + """Return a new report with one more named check (an existing name is overwritten). + + Reports are immutable, so a further validation check records its outcome by deriving a new + report rather than mutating this one. + """ + assert isinstance(passed, bool), f"with_check('{name}', ...) requires a bool, got {type(passed).__name__}" + return ValidationReport(checks={**self.checks, name: passed}) + + +LayoutFilter = Callable[[ValidationReport], bool] +"""Acceptance predicate: given a layout's ValidationReport, whether the layout is kept.""" + + +def default_layout_filter(report: ValidationReport) -> bool: + """Default acceptance: keep a layout iff every check passed (the built-in geometry checks and any added later).""" + return report.passed + + @dataclass class PlacementResult: """Result of an ObjectPlacer.place() call.""" - success: bool - """Whether placement passed validation checks.""" - positions: dict[ObjectBase, tuple[float, float, float]] """Final positions for each object.""" @@ -28,10 +81,18 @@ class PlacementResult: attempts: int """Number of attempts made.""" + validation: ValidationReport + """Per-check validation outcome; success is derived from it.""" + orientations: dict[ObjectBase, float] = field(default_factory=dict) """Per-object yaw (radians) about the world up (Z) axis, composed on top of each object's base rotation. Keyed by object, like positions. Empty when unrotated.""" + @property + def success(self) -> bool: + """Whether placement passed validation checks.""" + return self.validation.passed + @dataclass class MultiEnvPlacementResult: diff --git a/isaaclab_arena/relations/pooled_object_placer.py b/isaaclab_arena/relations/pooled_object_placer.py index 2fc8f020b..3489050e0 100644 --- a/isaaclab_arena/relations/pooled_object_placer.py +++ b/isaaclab_arena/relations/pooled_object_placer.py @@ -5,26 +5,60 @@ from __future__ import annotations +import random import torch from dataclasses import dataclass, replace +from pathlib import Path from typing import TYPE_CHECKING from isaaclab_arena.relations.bounding_box_helpers import has_heterogeneous_objects +from isaaclab_arena.relations.layout_pool_serialization import ( + PoolDocument, + deserialize_layout, + read_pool_document, + serialize_layout, + write_pool_document, +) from isaaclab_arena.relations.object_placer import ObjectPlacer from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams -from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult, PlacementResult +from isaaclab_arena.relations.placement_result import ( + LayoutFilter, + MultiEnvPlacementResult, + PlacementResult, + default_layout_filter, +) from isaaclab_arena.utils.random import get_rngs if TYPE_CHECKING: from isaaclab_arena.assets.object_base import ObjectBase +@dataclass +class PooledLayout: + """A stored layout plus its use_count. + + use_count lives here, not on PlacementResult, because one stored layout can be drawn repeatedly + by sample_with_replacement (via _draw) and tracks how often this entry was served. + """ + + result: PlacementResult + use_count: int = 0 + """Number of times this layout has been served (consuming or not).""" + + def mark_used(self) -> None: + """Record that this layout was served to a caller.""" + self.use_count += 1 + + @dataclass class EnvLayoutPool: - """Unread layout queue for one absolute environment.""" + """Layout queue for one absolute environment.""" + + layouts: list[PooledLayout] + """All layouts in pool order, including consumed ones, so stored_layouts/save() still see them.""" - layouts: list[PlacementResult] cursor: int = 0 + """Split point: layouts[:cursor] are consumed, layouts[cursor:] are unread.""" @property def available(self) -> int: @@ -34,17 +68,18 @@ def discard_consumed(self) -> None: self.layouts = self.layouts[self.cursor :] self.cursor = 0 - def append(self, layout: PlacementResult) -> None: - self.layouts.append(layout) + def append(self, result: PlacementResult) -> None: + self.layouts.append(PooledLayout(result)) - def extend(self, layouts: list[PlacementResult]) -> None: - self.layouts.extend(layouts) + def extend(self, results: list[PlacementResult]) -> None: + self.layouts.extend(PooledLayout(result) for result in results) def next(self) -> PlacementResult: assert self.cursor < len(self.layouts), "No unread layouts remain in this env pool." layout = self.layouts[self.cursor] self.cursor += 1 - return layout + layout.mark_used() + return layout.result class PooledObjectPlacer: @@ -56,7 +91,7 @@ class PooledObjectPlacer: env ids. Reusable layouts are interchangeable and can be consumed one at a time from the pooled queues. - Strictly valid layouts are preferred. On the final retry batch, best-loss + Accepted layouts are preferred. On the final retry batch, best-loss solver results may be kept as a fallback. The pool is refilled automatically when an env's queue runs out. @@ -67,8 +102,11 @@ class PooledObjectPlacer: * sample_for_envs(env_ids) consumes one layout for each requested absolute env id (used for partial resets). * sample_with_replacement(count) is non-consuming. Env-specific layouts - are sampled from matching env slots; reusable layouts stay marginally - uniform over all stored layouts. + are sampled from matching env slots; reusable layouts are drawn from the + flattened pool, with the per-env RNG selecting the stream. + + save(path) / load(path, ...) persist solved layouts to JSON so a scene can reuse + pre-existing object poses without re-solving. Args: objects: All objects (including anchors) participating in relation solving. @@ -76,6 +114,9 @@ class PooledObjectPlacer: pool_size: Number of layouts to solve per batch. num_envs: Total number of simulation environments. Required when layouts use env-specific object variants and defaults to 1 otherwise. + layout_filter: Predicate over a layout's ValidationReport deciding which layouts to keep. + Defaults to accepting layouts whose checks all pass. A custom predicate must tolerate + missing keys (e.g. report.checks.get(name, False)), since checks can change. """ def __init__( @@ -84,8 +125,11 @@ def __init__( placer_params: ObjectPlacerParams, pool_size: int = 100, num_envs: int | None = None, + layout_filter: LayoutFilter | None = None, + _skip_initial_solve: bool = False, ) -> None: assert pool_size >= 1, f"pool_size must be >= 1, got {pool_size}" + self._layout_filter: LayoutFilter = layout_filter or default_layout_filter self._uses_env_specific_bboxes = has_heterogeneous_objects(objects) assert not ( self._uses_env_specific_bboxes and num_envs is None @@ -94,11 +138,15 @@ def __init__( assert self._num_envs >= 1, f"num_envs must be >= 1, got {self._num_envs}" self._objects = list(objects) - # Pool construction ranks several candidate layouts per env and applies - # poses only when a sampled layout is used. - self._placer = ObjectPlacer(params=replace(placer_params, apply_positions_to_objects=False)) + # The placer ranks by the same filter the pool stores by, so its best layout is one we accept. + # Poses are applied only when a sampled layout is used. + self._placer = ObjectPlacer( + params=replace(placer_params, apply_positions_to_objects=False), layout_filter=self._layout_filter + ) self._pool_size = pool_size self._had_fallbacks = False + # Why the most recent solve rejected layouts (per check, plus "layout_filter"); shown in messages. + self._last_rejection_summary: dict[str, int] = {} self._base_placement_seed = placer_params.placement_seed self._next_seed_offset = 0 # Per-env sampling RNG keyed by (placement_seed, env_id): env i's draws are reproducible @@ -106,18 +154,45 @@ def __init__( self._env_rngs = get_rngs(self._num_envs, placer_params.placement_seed) self._env_pools: list[EnvLayoutPool] = [EnvLayoutPool([]) for _ in range(self._num_envs)] + # load() fills the pools from disk instead of solving, so skip the upfront solve there. + if _skip_initial_solve: + return + + # _solve_and_store fills every env to >= 1 layout or raises with per-env diagnostics, so a + # populated pool is an invariant here rather than a user-facing failure mode. self._solve_and_store(pool_size) - for cur_env, pool in enumerate(self._env_pools): - if not pool.layouts: - raise RuntimeError( - f"Placement pool failed to produce any valid layouts for env {cur_env} " - f"from {pool_size} attempts. Check object relations and constraints." - ) + assert all(pool.layouts for pool in self._env_pools), "Placement pool is empty after solving." + + def accepts(self, result: PlacementResult) -> bool: + """Whether a layout passes the pool's layout_filter. + + Reporting consults this, not result.success: result.success always uses the default predicate + (every check passed), so under a custom filter the two can differ and fallback tracking must + follow the filter the pool actually stores by. + """ + accepted = self._layout_filter(result.validation) + assert isinstance(accepted, bool), f"layout_filter must return a bool, got {type(accepted).__name__}" + return accepted # ------------------------------------------------------------------ # Pool storage internals # ------------------------------------------------------------------ + def _summarize_rejections(self, layouts: list[PlacementResult]) -> dict[str, int]: + """Count rejection causes: each failed named check, or "layout_filter". + + Only filter-rejected layouts are counted. One that passes every built-in check but the + filter still rejects is counted under "layout_filter" rather than vanishing. + """ + counts: dict[str, int] = {} + for layout in layouts: + if self.accepts(layout): + continue + failed_checks = layout.validation.failed_checks + for name in failed_checks or ("layout_filter",): + counts[name] = counts.get(name, 0) + 1 + return counts + def _available_per_env(self) -> list[int]: """Number of unread layouts in each env's pool (length num_envs).""" return [pool.available for pool in self._env_pools] @@ -150,6 +225,9 @@ def _solve_and_store(self, num_layouts: int) -> None: max_solve_batches = max(1, self._placer.params.max_placement_attempts) for batch_idx in range(max_solve_batches): + # Reset each iteration so a fresh solve never carries a prior solve's stale counts and + # terminal diagnostics report only the batch that ultimately failed. + self._last_rejection_summary = {} max_missing = target_per_env - min(self._available_per_env()) if max_missing <= 0: return @@ -173,16 +251,16 @@ def _solve_and_store(self, num_layouts: int) -> None: raise RuntimeError( f"Placement pool could not fill {target_per_env} layouts per env after " - f"{max_solve_batches} solve batches. Available per env: {self._available_per_env()}." + f"{max_solve_batches} solve batches. Available per env: {self._available_per_env()}. " + f"Most recent rejection reasons: {self._last_rejection_summary or 'none recorded'}." ) def _solve_reusable_layouts(self, num_layouts: int, allow_fallback: bool = False) -> list[PlacementResult]: """Solve layouts that can be used by any env pool. - Invalid candidates are discarded when at least one valid layout exists. - If no candidate passes strict validation on the final retry batch, fall - back to best-loss results so environments with imperfect validation can - still run. + Rejected candidates are discarded when at least one accepted layout exists. + If no candidate is accepted on the final retry batch, fall back to best-loss + results so environments with imperfect validation can still run. """ self._prepare_seeded_solve(num_layouts * self._placer.params.max_placement_attempts) with torch.inference_mode(False): @@ -190,22 +268,27 @@ def _solve_reusable_layouts(self, num_layouts: int, allow_fallback: bool = False # place() returns a single PlacementResult only when num_envs == 1. all_layouts = result.results if isinstance(result, MultiEnvPlacementResult) else [result] - valid_layouts = [layout for layout in all_layouts if layout.success] + accepted_layouts = [layout for layout in all_layouts if self.accepts(layout)] - if len(valid_layouts) < num_layouts: + if len(accepted_layouts) < num_layouts: + self._last_rejection_summary = self._summarize_rejections(all_layouts) print( f"Pooled object placer: solved {num_layouts} layouts," - f" {len(valid_layouts)} valid, {num_layouts - len(valid_layouts)} failed validation" + f" {len(accepted_layouts)} accepted, {num_layouts - len(accepted_layouts)} rejected" + f" (rejection reasons: {self._last_rejection_summary})" ) - if valid_layouts: - return valid_layouts + if accepted_layouts: + return accepted_layouts if not allow_fallback: return [] self._had_fallbacks = True - print("Warning: No candidates passed strict validation. Accepting best-loss layouts as fallback.") + print( + "WARNING: No candidates met the pool's acceptance criteria. Accepting best-loss layouts as " + f"fallback. Rejection reasons across {len(all_layouts)} candidates: {self._last_rejection_summary}" + ) return all_layouts def _store_reusable_results(self, layouts: list[PlacementResult]) -> None: @@ -257,34 +340,37 @@ def _store_env_matched_results( """Store env-matched results into their corresponding pools. Each env is filled only up to target_per_env unread layouts, so envs - that already met the target are not overfilled. Successful layouts are - preferred; if allow_fallback is set and an env has no valid layouts, + that already met the target are not overfilled. Accepted layouts are + preferred; if allow_fallback is set and an env has no accepted layouts, fall back to its best-loss results so environments with imperfect validation can still run. """ - total_valid = 0 + total_accepted = 0 fallback_envs = [] for cur_env in range(self._num_envs): env_results = ranked_results_per_env[cur_env][:layouts_per_env] - valid_results = [r for r in env_results if r.success] + accepted_results = [r for r in env_results if self.accepts(r)] missing = target_per_env - self._env_pools[cur_env].available - if valid_results: + if accepted_results: if missing > 0: - enqueued = valid_results[:missing] - total_valid += len(enqueued) + enqueued = accepted_results[:missing] + total_accepted += len(enqueued) self._env_pools[cur_env].extend(enqueued) else: - total_valid += len(valid_results) - elif allow_fallback and missing > 0: + total_accepted += len(accepted_results) + elif allow_fallback and missing > 0 and env_results: self._env_pools[cur_env].extend(env_results[:missing]) fallback_envs.append(cur_env) self._had_fallbacks = True total_solved = sum(min(len(env_results), layouts_per_env) for env_results in ranked_results_per_env) - if total_valid < total_solved: + if total_accepted < total_solved: + considered = [r for env_results in ranked_results_per_env for r in env_results[:layouts_per_env]] + self._last_rejection_summary = self._summarize_rejections(considered) msg = ( f"Placement pool (env-specific bbox layouts) solved {total_solved} candidates," - f" {total_valid} valid, {total_solved - total_valid} failed validation" + f" {total_accepted} accepted, {total_solved - total_accepted} rejected" + f" (rejection reasons: {self._last_rejection_summary})" ) if fallback_envs: msg += f". Falling back to best-loss layouts for envs: {fallback_envs}" @@ -328,8 +414,8 @@ def _sample_env_indexed_without_replacement(self, count: int) -> list[PlacementR pool = self._env_pools[cur_env] if pool.available <= 0: raise RuntimeError( - f"Placement pool: env {cur_env} has no more valid layouts. " - "The solver is not producing enough valid placements." + f"Placement pool: env {cur_env} has no more accepted layouts. " + "The solver is not producing enough accepted placements." ) results.append(pool.next()) return results @@ -351,8 +437,8 @@ def sample_for_envs(self, env_ids: list[int]) -> dict[int, PlacementResult]: pool = self._env_pools[env_id] if pool.available <= 0: raise RuntimeError( - f"Placement pool: env {env_id} has no more valid layouts. " - "The solver is not producing enough valid placements." + f"Placement pool: env {env_id} has no more accepted layouts. " + "The solver is not producing enough accepted placements." ) results[env_id] = pool.next() return results @@ -366,7 +452,7 @@ def _sample_reusable_without_replacement(self, count: int) -> list[PlacementResu if sum(available) < count: raise RuntimeError( f"Placement pool has {sum(available)} reusable layouts but {count} were requested. " - "The solver is not producing enough valid placements." + "The solver is not producing enough accepted placements." ) results: list[PlacementResult] = [] @@ -375,13 +461,44 @@ def _sample_reusable_without_replacement(self, count: int) -> list[PlacementResu pool = self._env_pools[cur_env] if pool.available <= 0: raise RuntimeError( - f"Placement pool: env {cur_env} has no more valid layouts. " - "The solver is not producing enough valid placements." + f"Placement pool: env {cur_env} has no more accepted layouts. " + "The solver is not producing enough accepted placements." ) results.append(pool.next()) available[cur_env] -= 1 return results + def sample_with_replacement(self, count: int) -> list[PlacementResult]: + """Pick count layouts at random with replacement (non-consuming). + + For env-specific layouts, slot i picks from env i % num_envs's pool + so each result matches its absolute env. For reusable layouts, each slot + draws from the flattened pool, with the per-env RNG only selecting the stream. + """ + # Non-consuming (reads pool.layouts, ignoring the cursor). Slot i draws from env + # (i % num_envs)'s RNG, so given identical pool contents each env's sequence replays + # under (placement_seed, env_id), independent of other envs' draws. + if self._uses_env_specific_bboxes: + results: list[PlacementResult] = [] + for i in range(count): + cur_env = i % self._num_envs + pooled = self._env_pools[cur_env].layouts + assert pooled, f"Env {cur_env} has no accepted layouts to sample from." + results.append(self._draw(self._env_rngs[cur_env], pooled)) + return results + # Reusable layouts are interchangeable: draw each slot from the flattened pool, with the + # per-env RNG only selecting which stream the slot draws from. + all_layouts = [layout for pool in self._env_pools for layout in pool.layouts] + assert all_layouts, "No accepted layouts to sample from." + return [self._draw(self._env_rngs[i % self._num_envs], all_layouts) for i in range(count)] + + @staticmethod + def _draw(rng: random.Random, pooled_layouts: list[PooledLayout]) -> PlacementResult: + """Pick a stored layout, record the use, and return its result.""" + layout = rng.choice(pooled_layouts) + layout.mark_used() + return layout.result + @property def requires_env_indexed_layouts(self) -> bool: """Whether sampled layouts must be matched back to absolute env ids.""" @@ -394,47 +511,14 @@ def num_envs(self) -> int: @property def had_fallbacks(self) -> bool: - """Whether any pool refill accepted best-loss layouts that failed strict validation.""" + """Whether any solve stored best-loss fallback layouts (set when stored, even if never drawn).""" return self._had_fallbacks - def sample_with_replacement(self, count: int) -> list[PlacementResult]: - """Pick count layouts at random with replacement (non-consuming). - - For env-specific layouts, slot i picks from env i % num_envs's pool - so each result matches its absolute env. For reusable layouts, each - slot stays marginally uniform over the full pool; the per-env RNGs only - fix which stream a slot draws from, for parity with the env-specific branch. - """ - # Non-consuming: reads pool.layouts directly, ignoring the consumption cursor. - if self._uses_env_specific_bboxes: - results: list[PlacementResult] = [] - for i in range(count): - cur_env = i % self._num_envs - pool = self._env_pools[cur_env].layouts - assert pool, f"Env {cur_env} has no valid layouts to sample from." - results.append(self._env_rngs[cur_env].choice(pool)) - return results - # Serialize all layouts into one flat pool. - all_layouts: list[PlacementResult] = [] - for pool in self._env_pools: - for layout in pool.layouts: - all_layouts.append(layout) - assert all_layouts, "No valid layouts to sample from across any env pool." - - # Draw each slot from its env's RNG over the serialized pool. - results: list[PlacementResult] = [] - for layout_idx in range(count): - rng = self._env_rngs[layout_idx % self._num_envs] - results.append(rng.choice(all_layouts)) - return results - @property def remaining(self) -> int: - """Number of complete env rounds available to :meth:`sample_without_replacement`. + """Complete env rounds available to sample_without_replacement. - Returns the minimum unread count across env pools. A single round - consumes one layout from every env, so the minimum is what limits - without-replacement capacity. + One round consumes a layout from every env, so the per-env minimum is the limit. """ return min(self._available_per_env()) @@ -447,3 +531,106 @@ def pool_size(self) -> int: def total_remaining(self) -> int: """Total unread layouts across all env pools.""" return self._total_available() + + @property + def stored_layouts(self) -> tuple[tuple[PlacementResult, ...], ...]: + """Read-only view of every stored layout, grouped by absolute env id. + + Returns one inner tuple per env (outer length num_envs), each holding that env's stored + PlacementResults in pool order, including already-consumed ones. The grouping is meaningful + for env-specific layouts (see requires_env_indexed_layouts) and an arbitrary partition for + reusable ones, which can be flattened. + + The tuples are a snapshot (later refills do not appear), but the PlacementResult objects are + live, so a post-pool check (e.g. a simulation collision test) records its outcome by + reassigning result.validation = result.validation.with_check(name, passed) (with_check returns + a new report; it does not mutate), and accepts() then reflects it. Resampling is left to the + caller via the existing sample_* methods. + """ + return tuple(tuple(pooled.result for pooled in pool.layouts) for pool in self._env_pools) + + # ------------------------------------------------------------------ + # Persistence: save/load solved layouts to reuse poses without re-solving + # ------------------------------------------------------------------ + + def save(self, path: str | Path) -> None: + """Write all stored layouts to path as JSON for reuse without re-solving. + + Layouts are keyed by object name and include already-consumed ones (unread again on load). + Cursor, use_count, and refill offset are not persisted; placement_seed and had_fallbacks + are, so a loaded pool samples like a fresh pool with the same seed. See + layout_pool_serialization for the on-disk schema and the atomic, fail-loud write. + """ + assert len({obj.name for obj in self._objects}) == len( + self._objects + ), f"Object names must be unique to save a layout pool keyed by name: {path}" + document = PoolDocument( + placement_seed=self._base_placement_seed, + num_envs=self._num_envs, + uses_env_specific_bboxes=self._uses_env_specific_bboxes, + had_fallbacks=self._had_fallbacks, + env_pools=[[serialize_layout(pooled.result) for pooled in pool.layouts] for pool in self._env_pools], + ) + write_pool_document(Path(path), document) + + @classmethod + def load( + cls, + path: str | Path, + objects: list[ObjectBase], + placer_params: ObjectPlacerParams, + *, + num_envs: int | None = None, + layout_filter: LayoutFilter | None = None, + ) -> PooledObjectPlacer: + """Rebuild a pool from a save() file, reusing stored poses instead of solving. + + objects must contain, by name, every object referenced in the file. Malformed files and + env-count/heterogeneity/object-name mismatches fail loudly (see layout_pool_serialization + for the structural checks). The saved placement_seed is restored (placer_params.placement_seed is + overridden by the saved seed after construction) so sampling matches the saved run; refill offset + is not persisted, so + a refill restarts from the first solve batch. pool_size becomes the total loaded layout count + across all envs (not per-env), so a multi-env refill batches larger than the original + per-batch pool_size; harmless because refills on a loaded pool are rare. + """ + path = Path(path) + document = read_pool_document(path) + assert ( + num_envs is None or num_envs == document.num_envs + ), f"num_envs={num_envs} does not match the {document.num_envs} envs saved in {path}." + + name_to_obj = {obj.name: obj for obj in objects} + assert len(name_to_obj) == len(objects), f"Object names must be unique to load a layout pool by name: {path}" + + loaded_count = sum(len(env_layouts) for env_layouts in document.env_pools) + placer = cls( + objects=objects, + placer_params=placer_params, + pool_size=max(1, loaded_count), + num_envs=document.num_envs, + layout_filter=layout_filter, + _skip_initial_solve=True, + ) + assert placer._uses_env_specific_bboxes == document.uses_env_specific_bboxes, ( + "Loaded objects' heterogeneity does not match the saved pool; re-solve instead of loading this cache:" + f" {path}" + ) + + # Reproduce the saved run rather than the freshly-passed seed. + placer._base_placement_seed = document.placement_seed + placer._had_fallbacks = document.had_fallbacks + placer._env_rngs = get_rngs(document.num_envs, document.placement_seed) + placer._env_pools = [] + for cur_env, env_layouts in enumerate(document.env_pools): + pooled = [] + for layout_idx, layout in enumerate(env_layouts): + try: + pooled.append(PooledLayout(deserialize_layout(layout, name_to_obj))) + except (AssertionError, OverflowError) as exc: + # OverflowError: an out-of-range int coordinate passes isinstance but overflows math.isfinite. + raise AssertionError(f"{exc} (env {cur_env}, layout {layout_idx} in {path})") from exc + placer._env_pools.append(EnvLayoutPool(pooled)) + for cur_env, pool in enumerate(placer._env_pools): + assert pool.layouts, f"Loaded layout pool has no layouts for env {cur_env}: {path}" + return placer diff --git a/isaaclab_arena/tests/test_heterogeneous_placement.py b/isaaclab_arena/tests/test_heterogeneous_placement.py index 43bafa956..4d7693d3b 100644 --- a/isaaclab_arena/tests/test_heterogeneous_placement.py +++ b/isaaclab_arena/tests/test_heterogeneous_placement.py @@ -15,16 +15,23 @@ from isaaclab_arena.assets.dummy_object import DummyObject from isaaclab_arena.relations.bounding_box_helpers import build_per_env_bounding_boxes, get_bounding_box_per_env -from isaaclab_arena.relations.object_placer import ObjectPlacer +from isaaclab_arena.relations.object_placer import ObjectPlacer, PlacementCandidate from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams -from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult, PlacementResult -from isaaclab_arena.relations.pooled_object_placer import PooledObjectPlacer +from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult, PlacementResult, ValidationReport +from isaaclab_arena.relations.pooled_object_placer import PooledLayout, PooledObjectPlacer from isaaclab_arena.relations.relation_solver import RelationSolver from isaaclab_arena.relations.relation_solver_params import RelationSolverParams from isaaclab_arena.relations.relations import IsAnchor, On +from isaaclab_arena.tests.utils.placement import layout_signature from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox from isaaclab_arena.utils.pose import Pose + +def _report(passed: bool) -> ValidationReport: + """Build a single-check ValidationReport for fixtures that only care about pass/fail.""" + return ValidationReport(checks={"no_overlap": passed}) + + # --------------------------------------------------------------------------- # Fixture: let HeterogeneousDummyObject trigger the heterogeneous path # --------------------------------------------------------------------------- @@ -397,6 +404,24 @@ def test_object_placer_place_ranked_per_env_returns_sorted_env_lists(): assert hetero.get_initial_pose() is None +def test_object_placer_ranking_uses_layout_filter(): + """Ranking should rank a filter-accepted candidate ahead of a lower-loss rejected one.""" + placer = ObjectPlacer( + params=ObjectPlacerParams(), + layout_filter=lambda report: report.checks.get("on_relations", False), + ) + low_loss_rejected = PlacementCandidate( + loss=0.1, positions={}, validation=ValidationReport(checks={"no_overlap": True, "on_relations": False}) + ) + high_loss_accepted = PlacementCandidate( + loss=0.9, positions={}, validation=ValidationReport(checks={"no_overlap": True, "on_relations": True}) + ) + + ranked = placer._rank_candidates([low_loss_rejected, high_loss_accepted], num_envs=1, candidates_per_env=2) + + assert ranked[0][0] is high_loss_accepted + + def test_object_placer_homogeneous_objects_return_multi_env_result(): """Homogeneous objects return one layout per env (bboxes identical across envs).""" @@ -479,11 +504,13 @@ def test_pooled_placer_sample_for_envs_consumes_only_requested_envs(): for env_id in range(4): pool._env_pools[env_id].layouts = [ - PlacementResult( - success=True, - positions={hetero: (float(env_id), 0.0, 0.0)}, - final_loss=0.0, - attempts=1, + PooledLayout( + PlacementResult( + validation=_report(True), + positions={hetero: (float(env_id), 0.0, 0.0)}, + final_loss=0.0, + attempts=1, + ) ) ] pool._env_pools[env_id].cursor = 0 @@ -503,7 +530,11 @@ def test_pooled_placer_heterogeneous_sample_with_replacement(): for env_id in range(4): pool._env_pools[env_id].layouts = [ - PlacementResult(success=True, positions={hetero: (float(env_id), 0.0, 0.0)}, final_loss=0.0, attempts=1) + PooledLayout( + PlacementResult( + validation=_report(True), positions={hetero: (float(env_id), 0.0, 0.0)}, final_loss=0.0, attempts=1 + ) + ) ] pool._env_pools[env_id].cursor = 0 initial_remaining = pool.remaining @@ -512,6 +543,17 @@ def test_pooled_placer_heterogeneous_sample_with_replacement(): assert pool.remaining == initial_remaining, "sample_with_replacement should not consume layouts" +def test_pooled_placer_heterogeneous_sample_with_replacement_empty_env_pool_asserts(): + """An empty env pool in the env-specific branch should fail loudly rather than draw silently.""" + desk, hetero, placer_params = _make_hetero_pool_objects() + pool = PooledObjectPlacer(objects=[desk, hetero], placer_params=placer_params, pool_size=8, num_envs=2) + pool._env_pools[1].layouts = [] + pool._env_pools[1].cursor = 0 + + with pytest.raises(AssertionError, match="no accepted layouts"): + pool.sample_with_replacement(2) + + def test_pooled_placer_heterogeneous_sample_with_replacement_reproducible_per_env_id(): """sample_with_replacement should reproduce each env's layout under a fixed seed (env-specific branch).""" num_envs = 4 @@ -566,8 +608,8 @@ def _draw_sequence(): assert _draw_sequence() == _draw_sequence() -def test_pooled_placer_env_specific_fallbacks_are_reported(capsys): - """Env-specific best-loss fallbacks should be reported to callers.""" +def test_pooled_placer_env_specific_fallbacks_are_reported(): + """Env-specific best-loss fallbacks should be reported via had_fallbacks.""" desk, hetero, placer_params = _make_hetero_pool_objects() pool = PooledObjectPlacer(objects=[desk, hetero], placer_params=placer_params, pool_size=2, num_envs=2) for env_pool in pool._env_pools: @@ -578,7 +620,7 @@ def test_pooled_placer_env_specific_fallbacks_are_reported(capsys): fallback_results = [ [ PlacementResult( - success=False, + validation=_report(False), positions={hetero: (float(cur_env), 0.0, 0.0)}, final_loss=1.0, attempts=1, @@ -588,14 +630,12 @@ def test_pooled_placer_env_specific_fallbacks_are_reported(capsys): ] pool._store_env_matched_results(fallback_results, layouts_per_env=1, target_per_env=1, allow_fallback=True) - captured = capsys.readouterr() assert pool.had_fallbacks - assert "Falling back to best-loss layouts" in captured.out assert [env_pool.available for env_pool in pool._env_pools] == [1, 1] -def test_pooled_placer_env_specific_fallbacks_wait_for_final_retry(capsys): +def test_pooled_placer_env_specific_fallbacks_wait_for_final_retry(): """Invalid env-specific candidates should not fill pools before fallback is allowed.""" desk, hetero, placer_params = _make_hetero_pool_objects() pool = PooledObjectPlacer(objects=[desk, hetero], placer_params=placer_params, pool_size=2, num_envs=2) @@ -607,7 +647,7 @@ def test_pooled_placer_env_specific_fallbacks_wait_for_final_retry(capsys): fallback_results = [ [ PlacementResult( - success=False, + validation=_report(False), positions={hetero: (float(cur_env), 0.0, 0.0)}, final_loss=1.0, attempts=1, @@ -617,14 +657,12 @@ def test_pooled_placer_env_specific_fallbacks_wait_for_final_retry(capsys): ] pool._store_env_matched_results(fallback_results, layouts_per_env=1, target_per_env=1) - captured = capsys.readouterr() assert not pool.had_fallbacks - assert "Falling back to best-loss layouts" not in captured.out assert [env_pool.available for env_pool in pool._env_pools] == [0, 0] -def test_pooled_placer_env_specific_fallback_only_fills_short_env(capsys): +def test_pooled_placer_env_specific_fallback_only_fills_short_env(): """Final-batch fallback should not overfill env pools that already met the target.""" desk, hetero, placer_params = _make_hetero_pool_objects() pool = PooledObjectPlacer(objects=[desk, hetero], placer_params=placer_params, pool_size=2, num_envs=2) @@ -634,7 +672,7 @@ def test_pooled_placer_env_specific_fallback_only_fills_short_env(capsys): pool._had_fallbacks = False existing_layout = PlacementResult( - success=True, + validation=_report(True), positions={hetero: (10.0, 0.0, 0.0)}, final_loss=0.0, attempts=1, @@ -644,7 +682,7 @@ def test_pooled_placer_env_specific_fallback_only_fills_short_env(capsys): fallback_results = [ [ PlacementResult( - success=False, + validation=_report(False), positions={hetero: (float(cur_env), 0.0, 0.0)}, final_loss=1.0, attempts=1, @@ -659,13 +697,11 @@ def test_pooled_placer_env_specific_fallback_only_fills_short_env(capsys): allow_fallback=True, target_per_env=1, ) - captured = capsys.readouterr() assert pool.had_fallbacks - assert "envs: [1]" in captured.out assert [env_pool.available for env_pool in pool._env_pools] == [1, 1] - assert pool._env_pools[0].layouts == [existing_layout] - assert pool._env_pools[1].layouts[0] is fallback_results[1][0] + assert pool._env_pools[0].layouts[0].result is existing_layout + assert pool._env_pools[1].layouts[0].result is fallback_results[1][0] def test_pooled_placer_env_specific_valid_results_only_fill_short_envs(): @@ -677,7 +713,7 @@ def test_pooled_placer_env_specific_valid_results_only_fill_short_envs(): env_pool.cursor = 0 existing_layout = PlacementResult( - success=True, + validation=_report(True), positions={hetero: (10.0, 0.0, 0.0)}, final_loss=0.0, attempts=1, @@ -687,7 +723,7 @@ def test_pooled_placer_env_specific_valid_results_only_fill_short_envs(): ranked_results = [ [ PlacementResult( - success=True, + validation=_report(True), positions={hetero: (float(cur_env), float(candidate_idx), 0.0)}, final_loss=0.0, attempts=1, @@ -700,8 +736,8 @@ def test_pooled_placer_env_specific_valid_results_only_fill_short_envs(): pool._store_env_matched_results(ranked_results, layouts_per_env=2, target_per_env=1) assert [env_pool.available for env_pool in pool._env_pools] == [1, 1] - assert pool._env_pools[0].layouts == [existing_layout] - assert pool._env_pools[1].layouts == [ranked_results[1][0]] + assert pool._env_pools[0].layouts[0].result is existing_layout + assert pool._env_pools[1].layouts[0].result is ranked_results[1][0] def test_pooled_placer_reusable_layouts_report_complete_env_rounds(): @@ -741,7 +777,7 @@ def test_pooled_placer_reusable_layouts_keep_partial_valid_results(): env_pool.cursor = 0 layouts = [ - PlacementResult(success=True, positions={box: (float(i), 0.0, 0.0)}, final_loss=0.0, attempts=1) + PlacementResult(validation=_report(True), positions={box: (float(i), 0.0, 0.0)}, final_loss=0.0, attempts=1) for i in range(3) ] pool._store_reusable_results(layouts) @@ -938,3 +974,56 @@ def test_real_rigid_object_set_through_pooled_placer(): assert obj_set in draw.positions z = draw.positions[obj_set][2] assert abs(z - 0.11) < 0.05, f"z={z:.4f}, expected ~0.11" + + +# --------------------------------------------------------------------------- +# Save/load round trip — env-specific (variant) layouts +# --------------------------------------------------------------------------- + + +def _seeded_hetero_params() -> ObjectPlacerParams: + return ObjectPlacerParams( + solver_params=RelationSolverParams(max_iters=200, convergence_threshold=1e-3, verbose=False), + placement_seed=7, + ) + + +def test_pooled_placer_heterogeneous_save_load_round_trip(tmp_path): + """Env-specific pools round-trip per env, preserving each env's variant-specific layouts.""" + desk, hetero, _ = _make_hetero_pool_objects() + placer_params = _seeded_hetero_params() + pool = PooledObjectPlacer(objects=[desk, hetero], placer_params=placer_params, pool_size=20, num_envs=4) + + path = tmp_path / "hetero_layouts.json" + pool.save(path) + + desk2, hetero2, _ = _make_hetero_pool_objects() + loaded = PooledObjectPlacer.load(path, [desk2, hetero2], placer_params, num_envs=4) + + original = [[layout_signature(p.result) for p in env.layouts] for env in pool._env_pools] + restored = [[layout_signature(p.result) for p in env.layouts] for env in loaded._env_pools] + assert restored == original + assert loaded.requires_env_indexed_layouts + # Variant geometry differs across envs, so a swapped env mapping would change these signatures. + assert original[0] != original[1] + + +def test_pooled_placer_load_rejects_heterogeneity_mismatch(tmp_path): + """A heterogeneous-saved file loaded against homogeneous objects must fail loudly.""" + desk, hetero, _ = _make_hetero_pool_objects() + placer_params = _seeded_hetero_params() + pool = PooledObjectPlacer(objects=[desk, hetero], placer_params=placer_params, pool_size=20, num_envs=4) + + path = tmp_path / "hetero_layouts.json" + pool.save(path) + + # Homogeneous stand-ins with the same names flip uses_env_specific_bboxes to False. + desk2 = _make_desk() + box = DummyObject( + name="hetero", + bounding_box=AxisAlignedBoundingBox(min_point=(0.0, 0.0, 0.0), max_point=(0.1, 0.1, 0.1)), + ) + box.add_relation(On(desk2, clearance_m=0.01)) + + with pytest.raises(AssertionError, match="heterogeneity does not match"): + PooledObjectPlacer.load(path, [desk2, box], placer_params, num_envs=4) diff --git a/isaaclab_arena/tests/test_layout_pool_serialization.py b/isaaclab_arena/tests/test_layout_pool_serialization.py new file mode 100644 index 000000000..edf745c7c --- /dev/null +++ b/isaaclab_arena/tests/test_layout_pool_serialization.py @@ -0,0 +1,285 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the pure layout-pool JSON serialization helpers (no simulation).""" + +from pathlib import Path + +import pytest + +from isaaclab_arena.relations.layout_pool_serialization import ( + PoolDocument, + deserialize_layout, + read_pool_document, + serialize_layout, + write_pool_document, +) +from isaaclab_arena.relations.placement_result import PlacementResult, ValidationReport + + +class _Obj: + """Minimal stand-in for an ObjectBase: hashable, with a name.""" + + def __init__(self, name: str): + self.name = name + + +def _layout_dict(): + """A valid serialized single-layout dict for objects 'a' and 'b'.""" + return { + "positions": {"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]}, + "orientations": {"a": 0.5}, + "validation": {"no_overlap": True, "on_relations": True}, + "final_loss": 0.25, + "attempts": 3, + } + + +def _name_to_obj(): + return {"a": _Obj("a"), "b": _Obj("b")} + + +def _pool_dict(num_envs: int = 1): + return { + "placement_seed": 42, + "num_envs": num_envs, + "uses_env_specific_bboxes": False, + "had_fallbacks": False, + "env_pools": [[_layout_dict()] for _ in range(num_envs)], + } + + +# --------------------------------------------------------------------------- +# Round trips +# --------------------------------------------------------------------------- + + +def test_serialize_deserialize_layout_round_trip(): + name_to_obj = _name_to_obj() + a, b = name_to_obj["a"], name_to_obj["b"] + result = PlacementResult( + positions={a: (1.0, 2.0, 3.0), b: (4.0, 5.0, 6.0)}, + orientations={a: 0.5}, + validation=ValidationReport(checks={"no_overlap": True}), + final_loss=0.25, + attempts=3, + ) + + restored = deserialize_layout(serialize_layout(result), name_to_obj) + + assert restored.positions == {a: (1.0, 2.0, 3.0), b: (4.0, 5.0, 6.0)} + assert restored.orientations == {a: 0.5} + assert dict(restored.validation.checks) == {"no_overlap": True} + assert restored.final_loss == 0.25 + assert restored.attempts == 3 + + +def test_pool_document_to_from_dict_round_trip(): + document = PoolDocument( + placement_seed=7, num_envs=2, uses_env_specific_bboxes=True, had_fallbacks=True, env_pools=[[], []] + ) + assert PoolDocument.from_dict(document.to_dict(), path="mem") == document + + +def test_write_then_read_round_trip(tmp_path): + path = tmp_path / "pool.json" + document = PoolDocument.from_dict(_pool_dict(num_envs=2), path=path) + write_pool_document(path, document) + assert read_pool_document(path) == document + + +# --------------------------------------------------------------------------- +# PoolDocument.from_dict structural guards +# --------------------------------------------------------------------------- + + +def test_from_dict_rejects_non_dict(): + with pytest.raises(AssertionError, match="not a JSON object"): + PoolDocument.from_dict([1, 2, 3], path="mem") + + +def test_from_dict_rejects_missing_key(): + data = _pool_dict() + del data["had_fallbacks"] + with pytest.raises(AssertionError, match="missing required key 'had_fallbacks'"): + PoolDocument.from_dict(data, path="mem") + + +@pytest.mark.parametrize( + "key, value, message", + [ + ("env_pools", {}, "'env_pools' must be a list"), + ("num_envs", True, "'num_envs' must be an int"), + ("num_envs", "x", "'num_envs' must be an int"), + ("placement_seed", True, "'placement_seed' must be int or null"), + ("placement_seed", 1.5, "'placement_seed' must be int or null"), + ("uses_env_specific_bboxes", "no", "'uses_env_specific_bboxes' must be a bool"), + ("had_fallbacks", 1, "'had_fallbacks' must be a bool"), + ], +) +def test_from_dict_rejects_wrong_top_level_type(key, value, message): + data = _pool_dict() + data[key] = value + with pytest.raises(AssertionError, match=message): + PoolDocument.from_dict(data, path="mem") + + +def test_from_dict_rejects_num_envs_length_mismatch(): + data = _pool_dict(num_envs=1) + data["num_envs"] = 2 + with pytest.raises(AssertionError, match="num_envs does not match env_pools length"): + PoolDocument.from_dict(data, path="mem") + + +def test_from_dict_rejects_non_list_env_entry(): + data = _pool_dict(num_envs=1) + data["env_pools"] = [_layout_dict()] + with pytest.raises(AssertionError, match="env 0 must be a list"): + PoolDocument.from_dict(data, path="mem") + + +def test_from_dict_rejects_non_dict_layout_entry(): + data = _pool_dict(num_envs=1) + data["env_pools"] = [["not a dict"]] + with pytest.raises(AssertionError, match="non-dict layout entry"): + PoolDocument.from_dict(data, path="mem") + + +# --------------------------------------------------------------------------- +# write/read IO +# --------------------------------------------------------------------------- + + +def test_write_pool_document_rejects_non_finite_leaving_no_files(tmp_path): + path = tmp_path / "pool.json" + document = PoolDocument( + placement_seed=0, + num_envs=1, + uses_env_specific_bboxes=False, + had_fallbacks=False, + env_pools=[[{"positions": {"a": [float("inf"), 0.0, 0.0]}}]], + ) + with pytest.raises(ValueError, match="JSON compliant"): + write_pool_document(path, document) + assert not path.exists() + assert not path.with_name(f"{path.name}.tmp").exists() + + +def test_write_pool_document_cleans_up_tmp_on_os_error(tmp_path, monkeypatch): + path = tmp_path / "pool.json" + document = PoolDocument.from_dict(_pool_dict(num_envs=1), path=path) + + def fail_mid_write(self, *args, **kwargs): + self.write_bytes(b"partial") # leave an orphan .tmp, then fail like a full disk + raise OSError("No space left on device") + + monkeypatch.setattr(Path, "write_text", fail_mid_write) + with pytest.raises(OSError, match="No space left on device"): + write_pool_document(path, document) + assert not path.exists() + assert not path.with_name(f"{path.name}.tmp").exists() + + +def test_read_pool_document_missing_file(tmp_path): + with pytest.raises(AssertionError, match="not found"): + read_pool_document(tmp_path / "missing.json") + + +def test_read_pool_document_malformed_json(tmp_path): + path = tmp_path / "bad.json" + path.write_text("{ not json") + with pytest.raises(ValueError, match="not valid JSON"): + read_pool_document(path) + + +def test_read_pool_document_rejects_non_finite_constant(tmp_path): + """json.loads accepts NaN/Infinity by default; read must reject it to mirror the write guard.""" + path = tmp_path / "nan.json" + path.write_text('{"placement_seed": NaN}') + with pytest.raises(ValueError, match="non-finite JSON constant 'NaN'"): + read_pool_document(path) + + +# --------------------------------------------------------------------------- +# deserialize_layout leaf guards +# --------------------------------------------------------------------------- + + +def test_deserialize_rejects_missing_layout_key(): + data = _layout_dict() + del data["final_loss"] + with pytest.raises(AssertionError, match="missing required key 'final_loss'"): + deserialize_layout(data, _name_to_obj()) + + +@pytest.mark.parametrize( + "key, value, message", + [ + ("positions", [1, 2], "'positions' must be a dict"), + ("orientations", [1, 2], "'orientations' must be a dict"), + ("validation", [1, 2], "'validation' must be a dict"), + ("validation", {}, "empty validation map"), + ("final_loss", "x", "'final_loss' must be a finite number"), + ("final_loss", float("nan"), "'final_loss' must be a finite number"), + ("attempts", 1.5, "'attempts' must be an int"), + ("attempts", True, "'attempts' must be an int"), + ], +) +def test_deserialize_rejects_bad_leaf(key, value, message): + data = _layout_dict() + data[key] = value + with pytest.raises(AssertionError, match=message): + deserialize_layout(data, _name_to_obj()) + + +def test_deserialize_rejects_non_bool_validation_value(): + data = _layout_dict() + data["validation"] = {"no_overlap": "true"} + with pytest.raises(AssertionError, match="must be a JSON bool"): + deserialize_layout(data, _name_to_obj()) + + +@pytest.mark.parametrize( + "pos, message", + [ + ([1.0, 2.0], "length-3 sequence"), + ([1.0, 2.0, 3.0, 4.0], "length-3 sequence"), + (["a", "b", "c"], "finite numbers"), + ([float("inf"), 0.0, 0.0], "finite numbers"), + ], +) +def test_deserialize_rejects_bad_position(pos, message): + data = _layout_dict() + data["positions"]["a"] = pos + with pytest.raises(AssertionError, match=message): + deserialize_layout(data, _name_to_obj()) + + +def test_deserialize_rejects_non_finite_yaw(): + data = _layout_dict() + data["orientations"]["a"] = float("nan") + with pytest.raises(AssertionError, match="orientation for 'a' must be a finite number"): + deserialize_layout(data, _name_to_obj()) + + +def test_deserialize_rejects_object_set_mismatch_missing_live_object(): + data = _layout_dict() + name_to_obj = {"a": _Obj("a"), "b": _Obj("b"), "c": _Obj("c")} + with pytest.raises(AssertionError, match="do not match the provided objects"): + deserialize_layout(data, name_to_obj) + + +def test_deserialize_rejects_unknown_saved_object(): + data = _layout_dict() + with pytest.raises(AssertionError, match="do not match the provided objects"): + deserialize_layout(data, {"a": _Obj("a")}) + + +def test_deserialize_rejects_orientation_without_position(): + # positions == live objects {a, b}, but an orientation names 'z' which has no position. + data = _layout_dict() + data["orientations"]["z"] = 0.1 + with pytest.raises(AssertionError, match="not a subset of positions"): + deserialize_layout(data, _name_to_obj()) diff --git a/isaaclab_arena/tests/test_object_placer_reproducibility.py b/isaaclab_arena/tests/test_object_placer_reproducibility.py index 690018e12..6f269bb23 100644 --- a/isaaclab_arena/tests/test_object_placer_reproducibility.py +++ b/isaaclab_arena/tests/test_object_placer_reproducibility.py @@ -5,6 +5,7 @@ """Tests for ObjectPlacer and RelationSolver reproducibility.""" +import json import math import pytest @@ -12,11 +13,12 @@ from isaaclab_arena.assets.dummy_object import DummyObject from isaaclab_arena.relations.object_placer import ObjectPlacer from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams -from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult, PlacementResult -from isaaclab_arena.relations.pooled_object_placer import PooledObjectPlacer +from isaaclab_arena.relations.placement_result import MultiEnvPlacementResult, PlacementResult, ValidationReport +from isaaclab_arena.relations.pooled_object_placer import PooledLayout, PooledObjectPlacer from isaaclab_arena.relations.relation_solver import RelationSolver from isaaclab_arena.relations.relation_solver_params import RelationSolverParams from isaaclab_arena.relations.relations import IsAnchor, NextTo, On, RotateAroundSolution, Side +from isaaclab_arena.tests.utils.placement import layout_signature from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox, get_random_pose_within_bounding_box from isaaclab_arena.utils.pose import Pose, PosePerEnv, rotate_quat_by_yaw, wrap_angle_to_pi @@ -337,6 +339,11 @@ def _positions_by_name(result: PlacementResult) -> dict[str, tuple[float, float, return {obj.name: pos for obj, pos in result.positions.items()} +def _pool_signatures(pool: PooledObjectPlacer) -> list[list[tuple]]: + """All stored layouts per env, as comparable signatures.""" + return [[layout_signature(layout.result) for layout in env_pool.layouts] for env_pool in pool._env_pools] + + # --------------------------------------------------------------------------- # PooledObjectPlacer reproducibility — homogeneous objects. # Heterogeneous-object (per-env variant) counterparts live in test_heterogeneous_placement.py. @@ -455,6 +462,18 @@ def test_pooled_placer_homogeneous_unseeded_does_not_crash(): assert _positions_by_name(sample) +def test_pooled_placer_reusable_sample_with_replacement_empty_pool_asserts(): + """The reusable (homogeneous) branch must also fail loudly when no layouts are available.""" + placer_params = _make_seeded_params() + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=6) + for env_pool in pool._env_pools: + env_pool.layouts = [] + env_pool.cursor = 0 + + with pytest.raises(AssertionError, match="No accepted layouts to sample from"): + pool.sample_with_replacement(1) + + def test_pooled_placer_homogeneous_stored_layouts_have_distinct_positions_dicts(): """Each stored layout must own a distinct positions dict (no aliasing across pool entries).""" solver_params = RelationSolverParams(max_iters=50) @@ -470,6 +489,53 @@ def test_pooled_placer_homogeneous_stored_layouts_have_distinct_positions_dicts( ), f"Layouts {i} and {j} share the same positions dict reference" +def test_pooled_placer_stored_layouts_groups_live_results_by_env(): + """stored_layouts must return a per-env snapshot of the live stored PlacementResults.""" + num_envs = 3 + solver_params = RelationSolverParams(max_iters=50) + placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) + + pool = PooledObjectPlacer( + objects=list(_create_test_objects()), placer_params=placer_params, pool_size=12, num_envs=num_envs + ) + + grouped = pool.stored_layouts + assert len(grouped) == num_envs + assert all(isinstance(env_layouts, tuple) for env_layouts in grouped) + # Each returned result is the same live object the pool stores (no copy), grouped per env. + for cur_env, env_layouts in enumerate(grouped): + assert list(env_layouts) == [pooled.result for pooled in pool._env_pools[cur_env].layouts] + + +def test_pooled_placer_stored_layouts_post_validation_flows_into_accepts(): + """Enriching a stored layout's report in place (the post-validation use case) must change accepts().""" + solver_params = RelationSolverParams(max_iters=50) + placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) + + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4) + result = pool.stored_layouts[0][0] + assert pool.accepts(result) + + # Record a failing post-pool check (e.g. a simulation collision test) on the live layout. + result.validation = result.validation.with_check("sim_collision_free", False) + assert "sim_collision_free" in result.validation.failed_checks + assert not pool.accepts(result) + + +def test_pooled_placer_stored_layouts_snapshot_is_isolated_from_refills(): + """A captured stored_layouts snapshot is frozen: later draws/refills don't change it.""" + placer_params = _make_seeded_params() + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4) + + snapshot = pool.stored_layouts + counts_before = [len(env_layouts) for env_layouts in snapshot] + + pool.sample_without_replacement(pool.total_remaining) + pool.sample_without_replacement(2) # forces a refill that mutates the live pools + + assert [len(env_layouts) for env_layouts in snapshot] == counts_before + + def test_pooled_placer_homogeneous_sample_without_replacement_count_exceeds_pool_size(): """sample_without_replacement(count) where count > pool_size must solve a larger batch in one shot.""" solver_params = RelationSolverParams(max_iters=50) @@ -514,3 +580,512 @@ def test_pooled_placer_homogeneous_rejects_pool_size_below_one(): placer_params = ObjectPlacerParams(placement_seed=42, solver_params=RelationSolverParams(max_iters=10)) with pytest.raises(AssertionError, match="pool_size must be >= 1"): PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=0) + + +def test_pooled_placer_layout_filter_receives_validation_reports(): + """The injected layout_filter should be consulted with each layout's ValidationReport.""" + solver_params = RelationSolverParams(max_iters=50) + placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) + + seen: list[ValidationReport] = [] + + def record(report: ValidationReport) -> bool: + seen.append(report) + return report.passed + + PooledObjectPlacer( + objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4, layout_filter=record + ) + assert seen, "layout_filter should be consulted while filling the pool" + assert all(isinstance(report, ValidationReport) for report in seen) + + +def test_pooled_placer_rejecting_layout_filter_forces_fallback(): + """A layout_filter that rejects everything should drive the pool onto best-loss fallbacks.""" + solver_params = RelationSolverParams(max_iters=50) + placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) + + default_pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4) + assert default_pool.had_fallbacks is False + + strict_pool = PooledObjectPlacer( + objects=list(_create_test_objects()), + placer_params=placer_params, + pool_size=4, + layout_filter=lambda report: False, + ) + assert strict_pool.had_fallbacks is True + + +def test_summarize_rejections_attributes_filter_only_rejections(): + """A layout the filter rejects despite passing every built-in check counts under 'layout_filter'.""" + solver_params = RelationSolverParams(max_iters=50) + placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) + pool = PooledObjectPlacer( + objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4, layout_filter=lambda r: False + ) + + filter_only = PlacementResult( + positions={}, final_loss=0.0, attempts=1, validation=ValidationReport(checks={"no_overlap": True}) + ) + check_failed = PlacementResult( + positions={}, final_loss=0.0, attempts=1, validation=ValidationReport(checks={"no_overlap": False}) + ) + summary = pool._summarize_rejections([filter_only, check_failed]) + assert summary == {"layout_filter": 1, "no_overlap": 1} + + +def test_summarize_rejections_empty_when_all_accepted(): + """All-accepted layouts produce no rejection counts.""" + solver_params = RelationSolverParams(max_iters=50) + placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4) + + accepted = [pooled.result for env_pool in pool._env_pools for pooled in env_pool.layouts] + assert accepted and all(pool.accepts(result) for result in accepted) + assert pool._summarize_rejections(accepted) == {} + + +def test_solve_and_store_resets_stale_rejection_summary(): + """A fresh solve must not surface rejection counts left over from an earlier solve.""" + solver_params = RelationSolverParams(max_iters=50) + placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=4) + + pool._last_rejection_summary = {"stale_check": 99} + pool._solve_and_store(4) + + assert pool._last_rejection_summary == {} + + +def test_pooled_placer_layout_filter_can_accept_a_subset_of_checks(): + """A layout_filter keyed on a single named check should accept/reject via accepts accordingly.""" + solver_params = RelationSolverParams(max_iters=50) + placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) + + # Accept any layout whose no_overlap check passes, even if on_relations failed. + pool = PooledObjectPlacer( + objects=list(_create_test_objects()), + placer_params=placer_params, + pool_size=4, + layout_filter=lambda report: report.checks.get("no_overlap", False), + ) + + overlap_ok = PlacementResult( + positions={}, + final_loss=0.0, + attempts=1, + validation=ValidationReport(checks={"no_overlap": True, "on_relations": False}), + ) + overlap_bad = PlacementResult( + positions={}, + final_loss=0.0, + attempts=1, + validation=ValidationReport(checks={"no_overlap": False, "on_relations": True}), + ) + assert pool.accepts(overlap_ok) is True + assert pool.accepts(overlap_bad) is False + + +def test_pooled_layout_mark_used_increments_use_count(): + """PooledLayout.mark_used should increment use_count without replacing the wrapped result.""" + result = PlacementResult( + positions={}, final_loss=0.0, attempts=1, validation=ValidationReport(checks={"no_overlap": True}) + ) + layout = PooledLayout(result) + assert layout.use_count == 0 + layout.mark_used() + layout.mark_used() + assert layout.use_count == 2 + assert layout.result is result + + +def test_pooled_placer_sample_with_replacement_tracks_use_count(): + """Each sample_with_replacement draw should bump exactly one stored layout's use_count.""" + solver_params = RelationSolverParams(max_iters=50) + placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) + + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=8) + pool.sample_with_replacement(20) + + total_uses = sum(layout.use_count for env_pool in pool._env_pools for layout in env_pool.layouts) + assert total_uses == 20 + + +def test_pooled_placer_reusable_draws_span_multiple_env_pools(): + """Reusable sampling flattens every env pool, so a multi-env pool serves layouts from >1 origin.""" + solver_params = RelationSolverParams(max_iters=50) + placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) + + pool = PooledObjectPlacer( + objects=list(_create_test_objects()), placer_params=placer_params, pool_size=9, num_envs=3 + ) + origin_by_id = { + id(layout.result): env_idx for env_idx, env_pool in enumerate(pool._env_pools) for layout in env_pool.layouts + } + assert len(set(origin_by_id.values())) > 1, "fixture must spread layouts across multiple env pools" + + draws = pool.sample_with_replacement(30) + origins_hit = {origin_by_id[id(result)] for result in draws} + assert len(origins_hit) > 1, "reusable draws should span more than one origin env pool" + + +def test_pooled_placer_sample_without_replacement_marks_consumed_layouts_used(): + """sample_without_replacement should mark exactly the consumed layouts used once each.""" + solver_params = RelationSolverParams(max_iters=50) + placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) + + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=8) + pool.sample_without_replacement(4) + + use_counts = [layout.use_count for env_pool in pool._env_pools for layout in env_pool.layouts] + assert sorted(use_counts, reverse=True)[:4] == [1, 1, 1, 1] + assert sum(use_counts) == 4 + + +def test_pooled_placer_use_count_progression_replays_under_fixed_seed(): + """Per-slot use_count after sampling should replay identically under a fixed seed.""" + solver_params = RelationSolverParams(max_iters=50) + placer_params = ObjectPlacerParams(placement_seed=42, solver_params=solver_params, apply_positions_to_objects=False) + + pool1 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=8) + pool2 = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=8) + + pool1.sample_with_replacement(20) + pool2.sample_with_replacement(20) + + counts1 = [layout.use_count for env_pool in pool1._env_pools for layout in env_pool.layouts] + counts2 = [layout.use_count for env_pool in pool2._env_pools for layout in env_pool.layouts] + assert counts1 == counts2 + assert sum(counts1) == 20 + # Guard against a degenerate _draw (e.g. always slot 0): draws must spread across the pool. + assert sum(1 for count in counts1 if count > 0) > 1 + + +def _make_seeded_params(seed: int = 42) -> ObjectPlacerParams: + return ObjectPlacerParams( + placement_seed=seed, + solver_params=RelationSolverParams(max_iters=50), + apply_positions_to_objects=False, + ) + + +def test_pooled_placer_save_load_round_trip_preserves_layouts(tmp_path): + """save() then load() must reproduce every stored layout, with use_count reset to 0.""" + placer_params = _make_seeded_params() + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=6) + + path = tmp_path / "layouts.json" + pool.save(path) + loaded = PooledObjectPlacer.load(path, list(_create_test_objects()), placer_params) + + assert _pool_signatures(loaded) == _pool_signatures(pool) + assert all(layout.use_count == 0 for env_pool in loaded._env_pools for layout in env_pool.layouts) + + +def test_pooled_placer_save_includes_consumed_layouts(tmp_path): + """save() persists already-consumed layouts, so a loaded pool offers them again (cursor resets).""" + placer_params = _make_seeded_params() + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=6) + total_before = pool.total_remaining + + pool.sample_without_replacement(2) # consume part of the pool before saving + assert pool.total_remaining < total_before + + path = tmp_path / "layouts.json" + pool.save(path) + loaded = PooledObjectPlacer.load(path, list(_create_test_objects()), placer_params) + + assert loaded.total_remaining == total_before + assert _pool_signatures(loaded) == _pool_signatures(pool) + + +def test_pooled_placer_save_load_round_trip_multi_env_homogeneous(tmp_path): + """A multi-env homogeneous pool's per-env layouts survive the round trip (variants covered separately).""" + num_envs = 3 + placer_params = _make_seeded_params() + pool = PooledObjectPlacer( + objects=list(_create_test_objects()), placer_params=placer_params, pool_size=9, num_envs=num_envs + ) + + path = tmp_path / "layouts.json" + pool.save(path) + loaded = PooledObjectPlacer.load(path, list(_create_test_objects()), placer_params, num_envs=num_envs) + + assert loaded.num_envs == num_envs + assert _pool_signatures(loaded) == _pool_signatures(pool) + + +def test_pooled_placer_load_replays_saved_poses_without_resolving(tmp_path): + """load() reuses stored poses: a different solve seed must not change the loaded layouts.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(42), pool_size=6) + path = tmp_path / "layouts.json" + pool.save(path) + + # A fresh solve under seed 999 would differ; loading must replay the seed-42 layouts instead. + loaded = PooledObjectPlacer.load(path, list(_create_test_objects()), _make_seeded_params(999)) + assert _pool_signatures(loaded) == _pool_signatures(pool) + + +def test_pooled_placer_loaded_pool_samples_match_origin(tmp_path): + """A loaded pool draws the same layouts as the saved one under a shared seed.""" + placer_params = _make_seeded_params() + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=6) + path = tmp_path / "layouts.json" + pool.save(path) + loaded = PooledObjectPlacer.load(path, list(_create_test_objects()), placer_params) + + for original, restored in zip(pool.sample_with_replacement(20), loaded.sample_with_replacement(20)): + assert _positions_by_name(original) == _positions_by_name(restored) + + +def test_pooled_placer_loaded_pool_refills_by_resolving_when_drained(tmp_path): + """Draining a loaded pool past its stored layouts triggers a re-solve refill, not a failure.""" + placer_params = _make_seeded_params() + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=2) + path = tmp_path / "layouts.json" + pool.save(path) + + loaded = PooledObjectPlacer.load(path, list(_create_test_objects()), placer_params) + loaded.sample_without_replacement(loaded.total_remaining) + assert loaded.total_remaining == 0 + + refilled = loaded.sample_without_replacement(3) + assert len(refilled) == 3 + assert all(_positions_by_name(result) for result in refilled) + + +def test_pooled_placer_load_rejects_unknown_object(tmp_path): + """A saved layout naming an object absent from the provided objects must fail loudly.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + path = tmp_path / "layouts.json" + pool.save(path) + + desk, box1, _box2 = _create_test_objects() + with pytest.raises(AssertionError, match="do not match the provided objects"): + PooledObjectPlacer.load(path, [desk, box1], _make_seeded_params()) + + +def test_pooled_placer_load_rejects_num_envs_mismatch(tmp_path): + """Requesting a num_envs different from the saved file must fail loudly.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + path = tmp_path / "layouts.json" + pool.save(path) + + with pytest.raises(AssertionError, match="num_envs=3 does not match"): + PooledObjectPlacer.load(path, list(_create_test_objects()), _make_seeded_params(), num_envs=3) + + +def test_pooled_placer_save_load_round_trip_preserves_orientations(tmp_path): + """With random_yaw_init, per-object yaws must survive the round trip (not silently empty).""" + placer_params = ObjectPlacerParams( + placement_seed=42, + solver_params=RelationSolverParams(max_iters=50), + apply_positions_to_objects=False, + random_yaw_init=True, + ) + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=placer_params, pool_size=6) + # Guard the fixture: the round trip must actually exercise non-empty orientations. + assert any(layout.result.orientations for env_pool in pool._env_pools for layout in env_pool.layouts) + + path = tmp_path / "layouts.json" + pool.save(path) + loaded = PooledObjectPlacer.load(path, list(_create_test_objects()), placer_params) + + assert _pool_signatures(loaded) == _pool_signatures(pool) + + +def test_pooled_placer_load_missing_file_raises(tmp_path): + """A missing cache file names the path rather than raising a bare FileNotFoundError.""" + with pytest.raises(AssertionError, match="not found"): + PooledObjectPlacer.load(tmp_path / "nope.json", list(_create_test_objects()), _make_seeded_params()) + + +def test_pooled_placer_load_rejects_malformed_json(tmp_path): + """A truncated/hand-edited file raises an attributable ValueError, not a bare JSONDecodeError.""" + path = tmp_path / "bad.json" + path.write_text("{ not valid json") + with pytest.raises(ValueError, match="not valid JSON"): + PooledObjectPlacer.load(path, list(_create_test_objects()), _make_seeded_params()) + + +def test_pooled_placer_load_rejects_missing_key(tmp_path): + """A file missing a required top-level key names that key.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + path = tmp_path / "layouts.json" + pool.save(path) + + data = json.loads(path.read_text()) + del data["env_pools"] + path.write_text(json.dumps(data)) + + with pytest.raises(AssertionError, match="missing required key 'env_pools'"): + PooledObjectPlacer.load(path, list(_create_test_objects()), _make_seeded_params()) + + +def test_pooled_placer_load_rejects_non_bool_validation(tmp_path): + """A corrupt non-bool validation value fails loudly instead of coercing to passing.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + path = tmp_path / "layouts.json" + pool.save(path) + + data = json.loads(path.read_text()) + checks = data["env_pools"][0][0]["validation"] + checks[next(iter(checks))] = "false" + path.write_text(json.dumps(data)) + + with pytest.raises(AssertionError, match="must be a JSON bool"): + PooledObjectPlacer.load(path, list(_create_test_objects()), _make_seeded_params()) + + +def test_pooled_placer_load_rejects_empty_validation(tmp_path): + """An empty validation map would load as a failing layout, so it must be rejected on load.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + path = tmp_path / "layouts.json" + pool.save(path) + + data = json.loads(path.read_text()) + data["env_pools"][0][0]["validation"] = {} + path.write_text(json.dumps(data)) + + with pytest.raises(AssertionError, match="empty validation map"): + PooledObjectPlacer.load(path, list(_create_test_objects()), _make_seeded_params()) + + +def test_pooled_placer_load_rejects_non_bool_had_fallbacks(tmp_path): + """A non-bool had_fallbacks must fail loudly rather than re-suppressing the fallback warning.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + path = tmp_path / "layouts.json" + pool.save(path) + + data = json.loads(path.read_text()) + data["had_fallbacks"] = "yes" + path.write_text(json.dumps(data)) + + with pytest.raises(AssertionError, match="had_fallbacks' must be a bool"): + PooledObjectPlacer.load(path, list(_create_test_objects()), _make_seeded_params()) + + +def test_pooled_placer_load_rejects_malformed_position(tmp_path): + """A wrong-length position would silently mis-place an object, so load must reject it.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + path = tmp_path / "layouts.json" + pool.save(path) + + data = json.loads(path.read_text()) + positions = data["env_pools"][0][0]["positions"] + positions[next(iter(positions))] = [1.0, 2.0] + path.write_text(json.dumps(data)) + + with pytest.raises(AssertionError, match="length-3 sequence"): + PooledObjectPlacer.load(path, list(_create_test_objects()), _make_seeded_params()) + + +def test_pooled_placer_load_rejects_duplicate_object_names(tmp_path): + """Duplicate object names would collapse to one slot, so load must reject them.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + path = tmp_path / "layouts.json" + pool.save(path) + + objects = list(_create_test_objects()) + objects.append(objects[0]) + with pytest.raises(AssertionError, match="Object names must be unique"): + PooledObjectPlacer.load(path, objects, _make_seeded_params()) + + +def test_pooled_placer_save_rejects_non_finite_pose(tmp_path): + """A non-finite coordinate must fail at save, leaving neither a target nor an orphan temp file.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + result = pool._env_pools[0].layouts[0].result + obj = next(iter(result.positions)) + result.positions[obj] = (float("nan"), 0.0, 0.0) + + path = tmp_path / "layouts.json" + with pytest.raises(ValueError, match="JSON compliant"): + pool.save(path) + assert not path.exists() + assert not path.with_name(f"{path.name}.tmp").exists() + + +def test_pooled_placer_save_rejects_duplicate_object_names(tmp_path): + """Saving with duplicate names would collapse a pose, so it must fail loudly like load does.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + pool._objects.append(pool._objects[0]) + + with pytest.raises(AssertionError, match="must be unique to save"): + pool.save(tmp_path / "layouts.json") + + +def test_pooled_placer_load_rejects_non_numeric_position(tmp_path): + """A non-numeric coordinate must fail loudly rather than crash deep in float().""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + path = tmp_path / "layouts.json" + pool.save(path) + + data = json.loads(path.read_text()) + positions = data["env_pools"][0][0]["positions"] + positions[next(iter(positions))] = ["a", "b", "c"] + path.write_text(json.dumps(data)) + + with pytest.raises(AssertionError, match="must be finite numbers"): + PooledObjectPlacer.load(path, list(_create_test_objects()), _make_seeded_params()) + + +def test_pooled_placer_load_rejects_layout_missing_key(tmp_path): + """A layout missing a required field names that field.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + path = tmp_path / "layouts.json" + pool.save(path) + + data = json.loads(path.read_text()) + del data["env_pools"][0][0]["final_loss"] + path.write_text(json.dumps(data)) + + with pytest.raises(AssertionError, match="missing required key 'final_loss'"): + PooledObjectPlacer.load(path, list(_create_test_objects()), _make_seeded_params()) + + +def test_pooled_placer_load_rejects_non_bool_heterogeneity_flag(tmp_path): + """A non-bool uses_env_specific_bboxes is a structural problem and must fail loudly.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + path = tmp_path / "layouts.json" + pool.save(path) + + data = json.loads(path.read_text()) + data["uses_env_specific_bboxes"] = "no" + path.write_text(json.dumps(data)) + + with pytest.raises(AssertionError, match="'uses_env_specific_bboxes' must be a bool"): + PooledObjectPlacer.load(path, list(_create_test_objects()), _make_seeded_params()) + + +def test_pooled_placer_load_restores_saved_seed_not_passed_seed(tmp_path): + """load() must sample under the saved seed, not the seed in the freshly-passed params.""" + pool_a = PooledObjectPlacer( + objects=list(_create_test_objects()), placer_params=_make_seeded_params(42), pool_size=6 + ) + path = tmp_path / "layouts.json" + pool_a.save(path) + + # Load with a different seed in params; restoration must ignore it. + loaded = PooledObjectPlacer.load(path, list(_create_test_objects()), _make_seeded_params(999)) + assert loaded._base_placement_seed == 42 + + # A fresh pool under seed 42 is the reference: matching it proves the saved seed drives sampling. + pool_a_ref = PooledObjectPlacer( + objects=list(_create_test_objects()), placer_params=_make_seeded_params(42), pool_size=6 + ) + loaded_seq = [_positions_by_name(r) for r in loaded.sample_with_replacement(4)] + ref_seq = [_positions_by_name(r) for r in pool_a_ref.sample_with_replacement(4)] + assert loaded_seq == ref_seq + + +def test_pooled_placer_save_load_preserves_had_fallbacks(tmp_path): + """had_fallbacks must survive the round trip so a post-load caller gating a warning isn't misled.""" + pool = PooledObjectPlacer(objects=list(_create_test_objects()), placer_params=_make_seeded_params(), pool_size=4) + pool._had_fallbacks = True + path = tmp_path / "layouts.json" + pool.save(path) + + loaded = PooledObjectPlacer.load(path, list(_create_test_objects()), _make_seeded_params()) + assert loaded.had_fallbacks diff --git a/isaaclab_arena/tests/test_placement_events.py b/isaaclab_arena/tests/test_placement_events.py index 6af3c1b27..c3009d1d3 100644 --- a/isaaclab_arena/tests/test_placement_events.py +++ b/isaaclab_arena/tests/test_placement_events.py @@ -316,7 +316,7 @@ def test_solve_and_place_objects_writes_invalid_fallback_layout(capsys): """Invalid fallback layouts should still be written, matching pool fallback behavior.""" from isaaclab_arena.relations.placement_events import solve_and_place_objects - from isaaclab_arena.relations.placement_result import PlacementResult + from isaaclab_arena.relations.placement_result import PlacementResult, ValidationReport desk, box1, box2 = _create_test_objects() objects = [desk, box1, box2] @@ -329,13 +329,16 @@ def sample_without_replacement(self, count: int) -> list[PlacementResult]: assert count == 1 return [ PlacementResult( - success=False, + validation=ValidationReport(checks={"no_overlap": False}), positions={box1: (0.0, 0.0, 0.0), box2: (0.0, 0.0, 0.0)}, final_loss=float("nan"), attempts=1, ) ] + def accepts(self, result: PlacementResult) -> bool: + return result.success + solve_and_place_objects(env, torch.tensor([0]), objects, InvalidPool()) captured = capsys.readouterr() @@ -345,11 +348,44 @@ def sample_without_replacement(self, count: int) -> list[PlacementResult]: assert "Writing best-loss fallback placement" in captured.out +def test_solve_and_place_objects_warns_when_filter_rejects_a_passing_layout(capsys): + """A layout that passes built-in checks but misses the pool's stricter filter is a fallback.""" + + from isaaclab_arena.relations.placement_events import solve_and_place_objects + from isaaclab_arena.relations.placement_result import PlacementResult, ValidationReport + + desk, box1, box2 = _create_test_objects() + objects = [desk, box1, box2] + env = _make_mock_env(num_envs=1) + + class StrictFilterPool: + requires_env_indexed_layouts = False + + def sample_without_replacement(self, count: int) -> list[PlacementResult]: + assert count == 1 + return [ + PlacementResult( + validation=ValidationReport(checks={"no_overlap": True}), + positions={box1: (0.0, 0.0, 0.0), box2: (0.0, 0.0, 0.0)}, + final_loss=0.0, + attempts=1, + ) + ] + + def accepts(self, result: PlacementResult) -> bool: + return False + + solve_and_place_objects(env, torch.tensor([0]), objects, StrictFilterPool()) + captured = capsys.readouterr() + + assert "Writing best-loss fallback placement" in captured.out + + def test_solve_and_place_objects_partial_reset_env_indexed_uses_absolute_env_result(): """Env-indexed partial resets should write the result for each absolute env id.""" from isaaclab_arena.relations.placement_events import solve_and_place_objects - from isaaclab_arena.relations.placement_result import PlacementResult + from isaaclab_arena.relations.placement_result import PlacementResult, ValidationReport desk, box1, box2 = _create_test_objects() objects = [desk, box1, box2] @@ -367,7 +403,7 @@ def sample_for_envs(self, env_ids: list[int]) -> dict[int, PlacementResult]: self.requested_env_ids = env_ids return { cur_env: PlacementResult( - success=True, + validation=ValidationReport(checks={"no_overlap": True}), positions={ box1: (float(cur_env), 0.0, 0.0), box2: (float(cur_env), 1.0, 0.0), @@ -378,6 +414,9 @@ def sample_for_envs(self, env_ids: list[int]) -> dict[int, PlacementResult]: for cur_env in env_ids } + def accepts(self, result: PlacementResult) -> bool: + return result.success + pool = EnvIndexedPool() solve_and_place_objects(env, torch.tensor([2]), objects, pool) @@ -512,7 +551,7 @@ def test_env_indexed_pool_seeds_init_state_before_reset_without_event(): from types import SimpleNamespace from isaaclab_arena.environments.relation_solver_interface import _apply_dynamic_spawn_pose - from isaaclab_arena.relations.placement_result import PlacementResult + from isaaclab_arena.relations.placement_result import PlacementResult, ValidationReport class MinimalObject: def __init__(self, name: str): @@ -536,7 +575,7 @@ def sample_with_replacement(self, count: int): assert count == 1 return [ PlacementResult( - success=True, + validation=ValidationReport(checks={"no_overlap": True}), positions={box: (float(env_id), 0.0, 0.1)}, final_loss=0.0, attempts=1, @@ -564,7 +603,7 @@ def test_env_indexed_static_poses_apply_per_env_positions(): """Static initial poses should apply per-env positions from env-indexed layouts.""" from isaaclab_arena.assets.dummy_object import DummyObject from isaaclab_arena.environments.relation_solver_interface import _apply_static_initial_poses - from isaaclab_arena.relations.placement_result import PlacementResult + from isaaclab_arena.relations.placement_result import PlacementResult, ValidationReport from isaaclab_arena.relations.relations import IsAnchor, On from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox from isaaclab_arena.utils.pose import Pose, PosePerEnv @@ -591,7 +630,7 @@ class PerEnvPool: def sample_with_replacement(self, count: int): return [ PlacementResult( - success=True, + validation=ValidationReport(checks={"no_overlap": True}), positions={box: (0.1 * env_id, 0.2 * env_id, 0.11)}, final_loss=0.0, attempts=1, diff --git a/isaaclab_arena/tests/test_relation_solver_interface.py b/isaaclab_arena/tests/test_relation_solver_interface.py index a4a3e0054..df0f0d78c 100644 --- a/isaaclab_arena/tests/test_relation_solver_interface.py +++ b/isaaclab_arena/tests/test_relation_solver_interface.py @@ -87,11 +87,15 @@ def test_static_solve_and_apply_relation_placement_reuses_object_only_placement( def test_dynamic_spawn_pose_skips_objects_missing_from_fallback_layout(): from isaaclab_arena.environments.relation_solver_interface import _apply_dynamic_spawn_pose - from isaaclab_arena.relations.placement_result import PlacementResult + from isaaclab_arena.relations.placement_result import PlacementResult, ValidationReport desk = _make_desk() box = _make_box() - placement_pool = _FakePlacementPool([PlacementResult(success=False, positions={}, final_loss=1.0, attempts=1)]) + placement_pool = _FakePlacementPool([ + PlacementResult( + validation=ValidationReport(checks={"no_overlap": False}), positions={}, final_loss=1.0, attempts=1 + ) + ]) _apply_dynamic_spawn_pose( objects=[desk, box], @@ -104,7 +108,7 @@ def test_dynamic_spawn_pose_skips_objects_missing_from_fallback_layout(): def test_static_initial_poses_skip_object_when_any_layout_is_missing_position(capsys): from isaaclab_arena.environments.relation_solver_interface import _apply_static_initial_poses - from isaaclab_arena.relations.placement_result import PlacementResult + from isaaclab_arena.relations.placement_result import PlacementResult, ValidationReport from isaaclab_arena.utils.pose import PosePerEnv desk = _make_desk() @@ -112,13 +116,13 @@ def test_static_initial_poses_skip_object_when_any_layout_is_missing_position(ca placed_box = _make_box("placed_box") placement_pool = _FakePlacementPool([ PlacementResult( - success=False, + validation=ValidationReport(checks={"no_overlap": False}), positions={placed_box: (0.1, 0.0, 0.2)}, final_loss=1.0, attempts=1, ), PlacementResult( - success=False, + validation=ValidationReport(checks={"no_overlap": False}), positions={placed_box: (0.2, 0.0, 0.2)}, final_loss=1.0, attempts=1, diff --git a/isaaclab_arena/tests/test_validate_placement.py b/isaaclab_arena/tests/test_validate_placement.py index d61ed477c..b3032579e 100644 --- a/isaaclab_arena/tests/test_validate_placement.py +++ b/isaaclab_arena/tests/test_validate_placement.py @@ -3,7 +3,7 @@ # # SPDX-License-Identifier: Apache-2.0 -"""Tests for ObjectPlacer placement validation (_validate_placement, _validate_no_overlap, _validate_on_relations).""" +"""Tests for ObjectPlacer placement validation (_validate_geometry, _validate_no_overlap, _validate_on_relations).""" import math import torch @@ -11,6 +11,7 @@ from isaaclab_arena.assets.dummy_object import DummyObject from isaaclab_arena.relations.object_placer import ObjectPlacer from isaaclab_arena.relations.object_placer_params import ObjectPlacerParams +from isaaclab_arena.relations.placement_result import PlacementResult, ValidationReport from isaaclab_arena.relations.relations import On, RotateAroundSolution from isaaclab_arena.utils.bounding_box import AxisAlignedBoundingBox @@ -52,7 +53,7 @@ def test_no_overlap_returns_true(): a = _make_box("a") b = _make_box("b") positions = {a: (0.0, 0.0, 0.0), b: (1.0, 0.0, 0.0)} - assert placer._validate_placement(positions, _env_bboxes(positions)) is True + assert placer._validate_geometry(positions, _env_bboxes(positions)).passed is True def test_overlapping_returns_false(): @@ -61,7 +62,7 @@ def test_overlapping_returns_false(): a = _make_box("a") b = _make_box("b") positions = {a: (0.0, 0.0, 0.0), b: (0.0, 0.0, 0.0)} - assert placer._validate_placement(positions, _env_bboxes(positions)) is False + assert placer._validate_geometry(positions, _env_bboxes(positions)).passed is False def test_partial_overlap_returns_false(): @@ -70,7 +71,7 @@ def test_partial_overlap_returns_false(): a = _make_box("a", size=0.2) b = _make_box("b", size=0.2) positions = {a: (0.0, 0.0, 0.0), b: (0.1, 0.1, 0.0)} - assert placer._validate_placement(positions, _env_bboxes(positions)) is False + assert placer._validate_geometry(positions, _env_bboxes(positions)).passed is False def test_separated_in_z_passes(): @@ -79,7 +80,7 @@ def test_separated_in_z_passes(): a = _make_box("a") b = _make_box("b") positions = {a: (0.0, 0.0, 0.0), b: (0.0, 0.0, 5.0)} - assert placer._validate_placement(positions, _env_bboxes(positions)) is True + assert placer._validate_geometry(positions, _env_bboxes(positions)).passed is True def test_object_on_surface_no_overlap(): @@ -89,7 +90,7 @@ def test_object_on_surface_no_overlap(): box = _make_box("box", size=0.2) # Desk top at z=0.05; box at z=0.16 → box occupies z=[0.06, 0.26], clear of desk positions = {desk: (0.0, 0.0, 0.0), box: (0.0, 0.0, 0.16)} - assert placer._validate_placement(positions, _env_bboxes(positions)) is True + assert placer._validate_geometry(positions, _env_bboxes(positions)).passed is True def test_colocated_siblings_overlap_rejected(): @@ -99,7 +100,7 @@ def test_colocated_siblings_overlap_rejected(): a = _make_box("a", size=0.2) b = _make_box("b", size=0.2) positions = {desk: (0.0, 0.0, 0.0), a: (0.0, 0.0, 0.15), b: (0.0, 0.0, 0.15)} - assert placer._validate_placement(positions, _env_bboxes(positions)) is False + assert placer._validate_geometry(positions, _env_bboxes(positions)).passed is False def test_rotation_aware_overlap_uses_yaw(): @@ -109,9 +110,9 @@ def test_rotation_aware_overlap_uses_yaw(): b = _make_box("b", size=0.1) positions = {a: (0.0, 0.0, 0.0), b: (0.0, 0.2, 0.0)} axis_aligned = {a: a.get_bounding_box(), b: b.get_bounding_box()} - assert placer._validate_placement(positions, axis_aligned) is True + assert placer._validate_geometry(positions, axis_aligned).passed is True rotated = {a: a.get_bounding_box().rotated_around_z(math.pi / 2), b: b.get_bounding_box()} - assert placer._validate_placement(positions, rotated) is False + assert placer._validate_geometry(positions, rotated).passed is False def test_candidate_bbox_aligns_with_candidate_yaw(): @@ -127,7 +128,7 @@ def test_candidate_bbox_aligns_with_candidate_yaw(): # Mirrors _place_ranked: each candidate validates against its own bbox row. validations = [ - placer._validate_placement(positions, ObjectPlacer._get_bounding_boxes_for_candidate_index(rotated, idx)) + placer._validate_geometry(positions, ObjectPlacer._get_bounding_boxes_for_candidate_index(rotated, idx)).passed for idx in range(2) ] # Axis-aligned `a` clears b; rotated 90° it sweeps into b. A row/candidate swap would flip both. @@ -229,3 +230,108 @@ def test_on_relation_check_child_outside_xy_returns_false(): box.add_relation(On(desk)) positions = {desk: (0.0, 0.0, 0.0), box: (10.0, 10.0, 0.1)} assert placer._validate_on_relations(positions, _env_bboxes(positions)) is False + + +def test_validate_geometry_reports_named_checks_for_valid_placement(): + """_validate_geometry should report both named checks passing for a valid placement.""" + placer = ObjectPlacer(params=ObjectPlacerParams()) + a = _make_box("a") + b = _make_box("b") + positions = {a: (0.0, 0.0, 0.0), b: (1.0, 0.0, 0.0)} + report = placer._validate_geometry(positions, _env_bboxes(positions)) + assert report.checks == {"no_overlap": True, "on_relations": True} + assert report.passed is True + assert report.failed_checks == () + + +def test_validate_geometry_isolates_the_failing_check(): + """_validate_geometry should fail only no_overlap when objects overlap with relations satisfied.""" + placer = ObjectPlacer(params=ObjectPlacerParams()) + a = _make_box("a") + b = _make_box("b") + positions = {a: (0.0, 0.0, 0.0), b: (0.0, 0.0, 0.0)} + report = placer._validate_geometry(positions, _env_bboxes(positions)) + assert report.checks["no_overlap"] is False + assert report.checks["on_relations"] is True + assert report.passed is False + assert report.failed_checks == ("no_overlap",) + + +def test_validation_report_passed_and_failed_checks(): + """ValidationReport should derive passed and failed_checks from its checks map.""" + assert ValidationReport(checks={"a": True, "b": True}).passed is True + mixed = ValidationReport(checks={"a": True, "b": False, "c": False}) + assert mixed.passed is False + assert mixed.failed_checks == ("b", "c") + + +def test_validation_report_empty_fails_closed(): + """An empty report should fail closed so an unvalidated layout is never treated as valid.""" + assert ValidationReport(checks={}).passed is False + assert ValidationReport(checks={}).failed_checks == () + + +def test_validation_report_snapshots_caller_dict(): + """Mutating the caller's dict after construction must not change the report.""" + source = {"no_overlap": True} + report = ValidationReport(checks=source) + source["no_overlap"] = False + source["on_relations"] = False + assert report.passed is True + assert dict(report.checks) == {"no_overlap": True} + + +def test_validation_report_checks_are_read_only(): + """checks is a read-only view, so a frozen report can't be mutated in place.""" + import pytest + + report = ValidationReport(checks={"no_overlap": True}) + with pytest.raises(TypeError): + report.checks["no_overlap"] = False + + +def test_validation_report_survives_deepcopy(): + """Reports must deepcopy/pickle (Isaac Lab deepcopies the EventTermCfg params that carry them).""" + import copy + + report = ValidationReport(checks={"no_overlap": True, "on_relations": False}) + clone = copy.deepcopy(report) + assert dict(clone.checks) == {"no_overlap": True, "on_relations": False} + assert clone.failed_checks == ("on_relations",) + + +def test_validation_report_with_check_adds_sibling_without_mutating_original(): + """with_check derives a new report with a further check, leaving the original intact.""" + geometry = ValidationReport(checks={"no_overlap": True, "on_relations": True}) + extended = geometry.with_check("extra_check", False) + # Original is untouched (immutable); the derived report carries the sibling check. + assert dict(geometry.checks) == {"no_overlap": True, "on_relations": True} + assert dict(extended.checks) == {"no_overlap": True, "on_relations": True, "extra_check": False} + # A failing sibling check flips acceptance, so a custom layout_filter can require it. + assert geometry.passed is True + assert extended.passed is False + assert extended.failed_checks == ("extra_check",) + + +def test_validation_report_rejects_non_bool_check_value(): + """Non-bool check values must fail at construction, else a truthy non-bool would satisfy all(...).""" + import pytest + + with pytest.raises(AssertionError, match="must be bools"): + ValidationReport(checks={"no_overlap": 1}) + + +def test_validation_report_with_check_overwrites_existing_name(): + """with_check is last-write-wins: re-checking a name replaces its prior outcome.""" + report = ValidationReport(checks={"no_overlap": False}) + rechecked = report.with_check("no_overlap", True) + assert dict(rechecked.checks) == {"no_overlap": True} + assert rechecked.passed is True + + +def test_placement_result_success_is_derived_from_validation(): + """PlacementResult.success should mirror its ValidationReport's passed flag.""" + failing = ValidationReport(checks={"no_overlap": False, "on_relations": True}) + result = PlacementResult(positions={}, final_loss=1.0, attempts=1, validation=failing) + assert result.success is False + assert result.validation.failed_checks == ("no_overlap",) diff --git a/isaaclab_arena/tests/utils/placement.py b/isaaclab_arena/tests/utils/placement.py new file mode 100644 index 000000000..36a7902e5 --- /dev/null +++ b/isaaclab_arena/tests/utils/placement.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Shared placement-test helpers.""" + +from __future__ import annotations + +from isaaclab_arena.relations.placement_result import PlacementResult + + +def layout_signature(result: PlacementResult): + """Name-keyed (positions, orientations, validation) tuple for comparing layouts across instances.""" + return ( + {obj.name: tuple(pos) for obj, pos in result.positions.items()}, + {obj.name: yaw for obj, yaw in result.orientations.items()}, + dict(result.validation.checks), + )