From 78bcc184c810e7aac3c289786e79c999baad5087 Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Fri, 9 Jan 2026 15:33:05 -0800 Subject: [PATCH 01/23] added initial version pipeline now correctly reports, removed all modifications to original moved patching logic to seperate file and expanded capability applied workaround for reused input functions bug --- .../rsl_rl/annotate_functions_for_export.py | 201 ++++++++++++++++++ .../reinforcement_learning/rsl_rl/export.py | 200 +++++++++++++++++ .../envs/mdp/actions/joint_actions.py | 2 +- 3 files changed, 402 insertions(+), 1 deletion(-) create mode 100644 scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py create mode 100644 scripts/reinforcement_learning/rsl_rl/export.py diff --git a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py new file mode 100644 index 000000000000..0280deab5d61 --- /dev/null +++ b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py @@ -0,0 +1,201 @@ +from leapp import annotate + +# Global storage for original and annotating ArticulationData properties +_articulation_data_originals = {} +_articulation_data_annotating = {} + + +def _setup_articulation_data_annotations(): + """ + Prepares annotating versions of ArticulationData properties without applying them. + + The annotations will be temporarily applied only during compute_group calls, + avoiding conflicts with rewards, terminations, commands, and actions that also + access these properties. + """ + from isaaclab.assets.articulation.articulation_data import ArticulationData + import torch + + # All observation properties - we can include all of them now since annotations + # are only active during compute_group + observation_properties = { + # Root state (position, orientation, velocities) + 'root_pos_w', # base_pos_z, root_pos_w + 'root_quat_w', # root_quat_w + 'root_lin_vel_b', # base_lin_vel + 'root_ang_vel_b', # base_ang_vel + 'root_lin_vel_w', # root_lin_vel_w + 'root_ang_vel_w', # root_ang_vel_w + 'projected_gravity_b', # 'projected_gravity_b', + + # Body state + 'body_pose_w', # body_pose_w + 'body_quat_w', # body_projected_gravity_b + + # Joint state + 'joint_pos', # joint_pos, joint_pos_rel, joint_pos_limit_normalized + 'joint_vel', # joint_vel, joint_vel_rel + 'applied_torque', # joint_effort + } + + for prop_name in observation_properties: + attr = getattr(ArticulationData, prop_name, None) + + # Skip if attribute doesn't exist or isn't a property + if attr is None or not isinstance(attr, property): + continue + + # Skip properties without a getter + if attr.fget is None: + continue + + # Store the original property + _articulation_data_originals[prop_name] = attr + + # Create annotating getter + original_fget = attr.fget + + def make_annotating_fget(original, name): + """Factory function to properly capture variables in closure.""" + + def annotating_fget(self): + result = original(self) + if isinstance(result, torch.Tensor): + result = annotate.input_tensors( + {name: result}, + node_name='observation_manager' + ) + return result + return annotating_fget + + annotating_fget = make_annotating_fget(original_fget, prop_name) + annotating_fget.__doc__ = original_fget.__doc__ + + # Create annotating property + annotating_property = property( + fget=annotating_fget, + fset=attr.fset, + fdel=attr.fdel, + doc=attr.__doc__ + ) + _articulation_data_annotating[prop_name] = annotating_property + + print(f"Prepared {len(_articulation_data_originals)} ArticulationData properties for temporary annotation") + + +def _apply_articulation_annotations(): + """Temporarily applies annotating versions of ArticulationData properties.""" + from isaaclab.assets.articulation.articulation_data import ArticulationData + for prop_name, annotating_prop in _articulation_data_annotating.items(): + setattr(ArticulationData, prop_name, annotating_prop) + + +def _remove_articulation_annotations(): + """Restores original ArticulationData properties.""" + from isaaclab.assets.articulation.articulation_data import ArticulationData + for prop_name, original_prop in _articulation_data_originals.items(): + setattr(ArticulationData, prop_name, original_prop) + + +def annotate_observation_manager(): + """ + Patches observation-related functions and classes to annotate inputs/outputs. + + This patches: + - ArticulationData properties (temporarily, only during compute_group) + - Observation functions at the module level (last_action, generated_commands, etc.) + - ObservationManager.compute_group to annotate outputs + + IMPORTANT: Must be called BEFORE isaaclab_tasks is imported. + """ + from isaaclab.envs.mdp import observations + from isaaclab.managers.observation_manager import ObservationManager + from leapp.leapp_graph.traced_tensor import TracedTensor + + # Prepare (but don't apply) ArticulationData annotations + _setup_articulation_data_annotations() + + # Patch last_action observation function + original_last_action = observations.last_action + + def patched_last_action(env, action_name=None): + result = original_last_action(env, action_name) + result = annotate.input_tensors({"last_actions": result}, node_name='observation_manager') + return result + + # Patch generated_commands observation function + original_generated_commands = observations.generated_commands + + def patched_generated_commands(env, command_name=None): + result = original_generated_commands(env, command_name) + result = annotate.input_tensors({"pose_command": result}, node_name='observation_manager') + return result + + # Apply observation function patches at module level + # Note: Observation functions that use ArticulationData properties (base_pos_z, root_pos_w, + # root_quat_w, body_projected_gravity_b) don't need patching since the underlying + # ArticulationData properties are temporarily annotated during compute_group. + observations.last_action = patched_last_action + observations.generated_commands = patched_generated_commands + + # Patch ObservationManager.compute_group to: + # 1. Temporarily apply ArticulationData annotations before computing + # 2. Annotate outputs + # 3. Restore original ArticulationData properties after computing + original_compute_group = ObservationManager.compute_group + + def patched_compute_group(self, *args, **kwargs): + # Apply ArticulationData annotations only during observation computation + _apply_articulation_annotations() + try: + output = original_compute_group(self, *args, **kwargs) + annotate.output_tensors(output, node_name='observation_manager', export_with='torch', use_trace=True) + if isinstance(output, TracedTensor): + return output.tensor + else: + return output + finally: + # Always restore original properties, even if an exception occurs + _remove_articulation_annotations() + + ObservationManager.compute_group = patched_compute_group + + +def annotate_action_manager(): + """ + Patches ActionManager.process_action to annotate action inputs/outputs. + + IMPORTANT: Must be called BEFORE isaaclab_tasks is imported. + """ + from isaaclab.managers.action_manager import ActionManager + import torch + + # Patch ActionManager.process_action at class level + original_process_action = ActionManager.process_action + + def patched_process_action(self, action: torch.Tensor): + action = annotate.input_tensors({"actions": action}, node_name='action_manager') + original_process_action(self, action) + tensors = {} + for term_name, term_ in self._terms.items(): + tensors[term_name] = term_.processed_actions + annotate.output_tensors(tensors, node_name='action_manager', export_with='torch') + + ActionManager.process_action = patched_process_action + + print("Patched action manager: ActionManager.process_action") + + +def add_leapp_annotations(): + """ + Adds all leapp annotations for exporting Isaac Lab policies. + + This is the main entry point that patches: + - ObservationManager and related observation functions + - ActionManager.process_action + + IMPORTANT: Must be called BEFORE isaaclab_tasks is imported. + """ + annotate_observation_manager() + annotate_action_manager() + print("All leapp annotations added") diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py new file mode 100644 index 000000000000..7b49b5e9870e --- /dev/null +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -0,0 +1,200 @@ +"""Script to play a checkpoint if an RL agent from RSL-RL.""" + +"""Launch Isaac Sim Simulator first.""" + +import argparse +import sys + +from isaaclab.app import AppLauncher + +# local imports +import cli_args # isort: skip + +from leapp import annotate + +# add argparse arguments +parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.") +parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.") +parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).") +parser.add_argument( + "--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations." +) +parser.add_argument("--task", type=str, default=None, help="Name of the task.") +parser.add_argument( + "--agent", type=str, default="rsl_rl_cfg_entry_point", help="Name of the RL agent configuration entry point." +) +parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment") +parser.add_argument( + "--use_pretrained_checkpoint", + action="store_true", + help="Use the pre-trained checkpoint from Nucleus.", +) +parser.add_argument("--real-time", action="store_true", default=False, help="Run in real-time, if possible.") +# append RSL-RL cli arguments +cli_args.add_rsl_rl_args(parser) +# append AppLauncher cli args +AppLauncher.add_app_launcher_args(parser) +# parse the arguments +args_cli, hydra_args = parser.parse_known_args() +# always enable cameras to record video +if args_cli.video: + args_cli.enable_cameras = True + +# clear out sys.argv for Hydra +sys.argv = [sys.argv[0]] + hydra_args + +# launch omniverse app +app_launcher = AppLauncher(args_cli) +simulation_app = app_launcher.app + +"""Rest everything follows.""" + +import gymnasium as gym +import os +import time +import torch + +from rsl_rl.runners import DistillationRunner, OnPolicyRunner + +from isaaclab.envs import ( + DirectMARLEnv, + DirectMARLEnvCfg, + DirectRLEnvCfg, + ManagerBasedRLEnvCfg, + multi_agent_to_single_agent, +) +from isaaclab.utils.assets import retrieve_file_path +from isaaclab.utils.dict import print_dict + +from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx +from isaaclab_rl.utils.pretrained_checkpoint import get_published_pretrained_checkpoint + +# IMPORTANT: Add leapp annotations BEFORE importing isaaclab_tasks +# This ensures the patched functions are captured when configs are created +from annotate_functions_for_export import add_leapp_annotations +add_leapp_annotations() + +import isaaclab_tasks # noqa: F401 +from isaaclab_tasks.utils import get_checkpoint_path +from isaaclab_tasks.utils.hydra import hydra_task_config + + +@hydra_task_config(args_cli.task, args_cli.agent) +def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg): + """Export a RSL-RL agent.""" + # grab task name for checkpoint path + task_name = args_cli.task.split(":")[-1] + train_task_name = task_name.replace("-Play", "") + + # override configurations with non-hydra CLI arguments + agent_cfg: RslRlBaseRunnerCfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli) + env_cfg.scene.num_envs = 1 + + # set the environment seed + # note: certain randomizations occur in the environment initialization so we set the seed here + env_cfg.seed = agent_cfg.seed + env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device + + # specify directory for logging experiments + log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name) + log_root_path = os.path.abspath(log_root_path) + print(f"[INFO] Loading experiment from directory: {log_root_path}") + if args_cli.use_pretrained_checkpoint: + resume_path = get_published_pretrained_checkpoint("rsl_rl", train_task_name) + if not resume_path: + print("[INFO] Unfortunately a pre-trained checkpoint is currently unavailable for this task.") + return + elif args_cli.checkpoint: + resume_path = retrieve_file_path(args_cli.checkpoint) + else: + resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint) + + log_dir = os.path.dirname(resume_path) + + # set the log directory for the environment (works for all environment types) + env_cfg.log_dir = log_dir + + # create isaac environment + # Note: observation functions are already patched at module level (before isaaclab_tasks import) + env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) + + # convert to single-agent instance if required by the RL algorithm + if isinstance(env.unwrapped, DirectMARLEnv): + env = multi_agent_to_single_agent(env) + + # wrap around environment for rsl-rl + env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions) + + print(f"[INFO]: Loading model checkpoint from: {resume_path}") + # load previously trained model + if agent_cfg.class_name == "OnPolicyRunner": + runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device) + elif agent_cfg.class_name == "DistillationRunner": + runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device) + else: + raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}") + runner.load(resume_path) + + # obtain the trained policy for inference + policy = runner.get_inference_policy(device=env.unwrapped.device) + + # extract the neural network module + # we do this in a try-except to maintain backwards compatibility. + try: + # version 2.3 onwards + policy_nn = runner.alg.policy + except AttributeError: + # version 2.2 and below + policy_nn = runner.alg.actor_critic + + # extract the normalizer + if hasattr(policy_nn, "actor_obs_normalizer"): + normalizer = policy_nn.actor_obs_normalizer + elif hasattr(policy_nn, "student_obs_normalizer"): + normalizer = policy_nn.student_obs_normalizer + else: + normalizer = None + + # export policy to onnx/jit + export_model_dir = os.path.join(os.path.dirname(resume_path), "exported") + export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt") + export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx") + jit_path = os.path.join(export_model_dir, "policy.pt") + onnx_path = os.path.join(export_model_dir, "policy.onnx") + print(f"[INFO]: Exported policy to: jit {jit_path}, onnx {onnx_path}") + + # start annotation tracing + # Note: all patching is done at module/class level before isaaclab_tasks import + annotate.start(task_name) + obs = env.get_observations() + for key, val in obs.items(): + if hasattr(val, 'leapp_tag'): + print('FRANK DEBUG', key, 'tag:', val.leapp_tag) + else: + print('FRANK DEBUG', key, 'no tag') + # simulate environment + while not simulation_app.is_running(): + time.sleep(0.5) + + for _ in range(5): + # run everything in inference mode + with torch.inference_mode(): + # agent stepping + with annotate.block('policy', inputs=['obs'], outputs=['actions'], + backend_params={'model_path': onnx_path, 'copy_original_model': True}): + actions = policy(obs) + # env stepping + obs, _, _, _ = env.step(actions) + + annotate.stop() + annotate.compile_graph() + + # close the simulator + env.close() + + +if __name__ == "__main__": + # run the main function + main() + # close sim app + simulation_app.close() diff --git a/source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py b/source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py index c32e501b7591..dbddeb8bc3a5 100644 --- a/source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py +++ b/source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py @@ -169,7 +169,7 @@ def process_actions(self, actions: torch.Tensor): # store the raw actions self._raw_actions[:] = actions # apply the affine transformations - self._processed_actions = self._raw_actions * self._scale + self._offset + self._processed_actions = actions * self._scale + self._offset # clip actions if self.cfg.clip is not None: self._processed_actions = torch.clamp( From 39f5dbc0bb65f58cdc981ccd15833351292da0b8 Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Wed, 14 Jan 2026 16:41:07 -0800 Subject: [PATCH 02/23] added static outputs to graph. addressd comments --- .../rsl_rl/annotate_functions_for_export.py | 53 ++++++++++++++----- .../reinforcement_learning/rsl_rl/export.py | 9 +--- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py index 0280deab5d61..dae73b0d8825 100644 --- a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py +++ b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py @@ -1,4 +1,10 @@ from leapp import annotate +from isaaclab.managers.action_manager import ActionManager +from isaaclab.assets.articulation.articulation_data import ArticulationData +from isaaclab.envs.mdp import observations +from isaaclab.managers.observation_manager import ObservationManager +from leapp.leapp_graph.traced_tensor import TracedTensor +import torch # Global storage for original and annotating ArticulationData properties _articulation_data_originals = {} @@ -13,8 +19,7 @@ def _setup_articulation_data_annotations(): avoiding conflicts with rewards, terminations, commands, and actions that also access these properties. """ - from isaaclab.assets.articulation.articulation_data import ArticulationData - import torch + # All observation properties - we can include all of them now since annotations # are only active during compute_group @@ -43,11 +48,11 @@ def _setup_articulation_data_annotations(): # Skip if attribute doesn't exist or isn't a property if attr is None or not isinstance(attr, property): - continue + raise ValueError(f"Attribute {prop_name} does not exist or is not a property") # Skip properties without a getter if attr.fget is None: - continue + raise ValueError(f"Attribute {prop_name} does not have a getter") # Store the original property _articulation_data_originals[prop_name] = attr @@ -85,14 +90,12 @@ def annotating_fget(self): def _apply_articulation_annotations(): """Temporarily applies annotating versions of ArticulationData properties.""" - from isaaclab.assets.articulation.articulation_data import ArticulationData for prop_name, annotating_prop in _articulation_data_annotating.items(): setattr(ArticulationData, prop_name, annotating_prop) def _remove_articulation_annotations(): """Restores original ArticulationData properties.""" - from isaaclab.assets.articulation.articulation_data import ArticulationData for prop_name, original_prop in _articulation_data_originals.items(): setattr(ArticulationData, prop_name, original_prop) @@ -108,9 +111,6 @@ def annotate_observation_manager(): IMPORTANT: Must be called BEFORE isaaclab_tasks is imported. """ - from isaaclab.envs.mdp import observations - from isaaclab.managers.observation_manager import ObservationManager - from leapp.leapp_graph.traced_tensor import TracedTensor # Prepare (but don't apply) ArticulationData annotations _setup_articulation_data_annotations() @@ -128,7 +128,7 @@ def patched_last_action(env, action_name=None): def patched_generated_commands(env, command_name=None): result = original_generated_commands(env, command_name) - result = annotate.input_tensors({"pose_command": result}, node_name='observation_manager') + result = annotate.input_tensors({"commands": result}, node_name='observation_manager') return result # Apply observation function patches at module level @@ -149,7 +149,7 @@ def patched_compute_group(self, *args, **kwargs): _apply_articulation_annotations() try: output = original_compute_group(self, *args, **kwargs) - annotate.output_tensors(output, node_name='observation_manager', export_with='torch', use_trace=True) + annotate.output_tensors('observation_manager', output, export_with='torch', use_trace=True) if isinstance(output, TracedTensor): return output.tensor else: @@ -165,10 +165,11 @@ def annotate_action_manager(): """ Patches ActionManager.process_action to annotate action inputs/outputs. + Also collects static values (default_joint_stiffness and default_joint_damping) + from action terms that have them. + IMPORTANT: Must be called BEFORE isaaclab_tasks is imported. """ - from isaaclab.managers.action_manager import ActionManager - import torch # Patch ActionManager.process_action at class level original_process_action = ActionManager.process_action @@ -176,10 +177,34 @@ def annotate_action_manager(): def patched_process_action(self, action: torch.Tensor): action = annotate.input_tensors({"actions": action}, node_name='action_manager') original_process_action(self, action) + annotate.mirror_leapp_tags(action, self._action) tensors = {} + static_values = {} + for term_name, term_ in self._terms.items(): tensors[term_name] = term_.processed_actions - annotate.output_tensors(tensors, node_name='action_manager', export_with='torch') + + # Collect static values (kp/kd gains) if available + asset = getattr(term_, '_asset', None) + if asset is not None and hasattr(asset, 'data'): + data = asset.data + joint_ids = getattr(term_, '_joint_ids', None) + + # Get default_joint_stiffness (kp gains) + if hasattr(data, 'default_joint_stiffness') and data.default_joint_stiffness is not None: + if joint_ids is not None: + static_values[f"{term_name}_kp_gains"] = data.default_joint_stiffness[:, joint_ids] + else: + static_values[f"{term_name}_kp_gains"] = data.default_joint_stiffness + + # Get default_joint_damping (kd gains) + if hasattr(data, 'default_joint_damping') and data.default_joint_damping is not None: + if joint_ids is not None: + static_values[f"{term_name}_kd_gains"] = data.default_joint_damping[:, joint_ids] + else: + static_values[f"{term_name}_kd_gains"] = data.default_joint_damping + + annotate.output_tensors('action_manager', tensors, static_outputs=static_values, export_with='torch') ActionManager.process_action = patched_process_action diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py index 7b49b5e9870e..316daa175207 100644 --- a/scripts/reinforcement_learning/rsl_rl/export.py +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -1,4 +1,4 @@ -"""Script to play a checkpoint if an RL agent from RSL-RL.""" +"""Script to export a checkpoint if an RL agent from RSL-RL.""" """Launch Isaac Sim Simulator first.""" @@ -167,11 +167,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen # Note: all patching is done at module/class level before isaaclab_tasks import annotate.start(task_name) obs = env.get_observations() - for key, val in obs.items(): - if hasattr(val, 'leapp_tag'): - print('FRANK DEBUG', key, 'tag:', val.leapp_tag) - else: - print('FRANK DEBUG', key, 'no tag') # simulate environment while not simulation_app.is_running(): time.sleep(0.5) @@ -187,7 +182,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen obs, _, _, _ = env.step(actions) annotate.stop() - annotate.compile_graph() + annotate.compile_graph(validate=False) # close the simulator env.close() From 6194308468fc3aae081567a343f52855b5abb4ae Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Wed, 14 Jan 2026 16:48:20 -0800 Subject: [PATCH 03/23] added precommit --- .../rsl_rl/annotate_functions_for_export.py | 73 +++++++++---------- .../reinforcement_learning/rsl_rl/export.py | 23 ++++-- 2 files changed, 51 insertions(+), 45 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py index dae73b0d8825..edf8a7c859f7 100644 --- a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py +++ b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py @@ -1,10 +1,17 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import torch + from leapp import annotate -from isaaclab.managers.action_manager import ActionManager +from leapp.leapp_graph.traced_tensor import TracedTensor + from isaaclab.assets.articulation.articulation_data import ArticulationData from isaaclab.envs.mdp import observations +from isaaclab.managers.action_manager import ActionManager from isaaclab.managers.observation_manager import ObservationManager -from leapp.leapp_graph.traced_tensor import TracedTensor -import torch # Global storage for original and annotating ArticulationData properties _articulation_data_originals = {} @@ -20,27 +27,24 @@ def _setup_articulation_data_annotations(): access these properties. """ - # All observation properties - we can include all of them now since annotations # are only active during compute_group observation_properties = { # Root state (position, orientation, velocities) - 'root_pos_w', # base_pos_z, root_pos_w - 'root_quat_w', # root_quat_w - 'root_lin_vel_b', # base_lin_vel - 'root_ang_vel_b', # base_ang_vel - 'root_lin_vel_w', # root_lin_vel_w - 'root_ang_vel_w', # root_ang_vel_w - 'projected_gravity_b', # 'projected_gravity_b', - + "root_pos_w", # base_pos_z, root_pos_w + "root_quat_w", # root_quat_w + "root_lin_vel_b", # base_lin_vel + "root_ang_vel_b", # base_ang_vel + "root_lin_vel_w", # root_lin_vel_w + "root_ang_vel_w", # root_ang_vel_w + "projected_gravity_b", # 'projected_gravity_b', # Body state - 'body_pose_w', # body_pose_w - 'body_quat_w', # body_projected_gravity_b - + "body_pose_w", # body_pose_w + "body_quat_w", # body_projected_gravity_b # Joint state - 'joint_pos', # joint_pos, joint_pos_rel, joint_pos_limit_normalized - 'joint_vel', # joint_vel, joint_vel_rel - 'applied_torque', # joint_effort + "joint_pos", # joint_pos, joint_pos_rel, joint_pos_limit_normalized + "joint_vel", # joint_vel, joint_vel_rel + "applied_torque", # joint_effort } for prop_name in observation_properties: @@ -66,23 +70,16 @@ def make_annotating_fget(original, name): def annotating_fget(self): result = original(self) if isinstance(result, torch.Tensor): - result = annotate.input_tensors( - {name: result}, - node_name='observation_manager' - ) + result = annotate.input_tensors({name: result}, node_name="observation_manager") return result + return annotating_fget annotating_fget = make_annotating_fget(original_fget, prop_name) annotating_fget.__doc__ = original_fget.__doc__ # Create annotating property - annotating_property = property( - fget=annotating_fget, - fset=attr.fset, - fdel=attr.fdel, - doc=attr.__doc__ - ) + annotating_property = property(fget=annotating_fget, fset=attr.fset, fdel=attr.fdel, doc=attr.__doc__) _articulation_data_annotating[prop_name] = annotating_property print(f"Prepared {len(_articulation_data_originals)} ArticulationData properties for temporary annotation") @@ -120,7 +117,7 @@ def annotate_observation_manager(): def patched_last_action(env, action_name=None): result = original_last_action(env, action_name) - result = annotate.input_tensors({"last_actions": result}, node_name='observation_manager') + result = annotate.input_tensors({"last_actions": result}, node_name="observation_manager") return result # Patch generated_commands observation function @@ -128,7 +125,7 @@ def patched_last_action(env, action_name=None): def patched_generated_commands(env, command_name=None): result = original_generated_commands(env, command_name) - result = annotate.input_tensors({"commands": result}, node_name='observation_manager') + result = annotate.input_tensors({"commands": result}, node_name="observation_manager") return result # Apply observation function patches at module level @@ -149,7 +146,7 @@ def patched_compute_group(self, *args, **kwargs): _apply_articulation_annotations() try: output = original_compute_group(self, *args, **kwargs) - annotate.output_tensors('observation_manager', output, export_with='torch', use_trace=True) + annotate.output_tensors("observation_manager", output, export_with="torch", use_trace=True) if isinstance(output, TracedTensor): return output.tensor else: @@ -175,7 +172,7 @@ def annotate_action_manager(): original_process_action = ActionManager.process_action def patched_process_action(self, action: torch.Tensor): - action = annotate.input_tensors({"actions": action}, node_name='action_manager') + action = annotate.input_tensors({"actions": action}, node_name="action_manager") original_process_action(self, action) annotate.mirror_leapp_tags(action, self._action) tensors = {} @@ -185,26 +182,26 @@ def patched_process_action(self, action: torch.Tensor): tensors[term_name] = term_.processed_actions # Collect static values (kp/kd gains) if available - asset = getattr(term_, '_asset', None) - if asset is not None and hasattr(asset, 'data'): + asset = getattr(term_, "_asset", None) + if asset is not None and hasattr(asset, "data"): data = asset.data - joint_ids = getattr(term_, '_joint_ids', None) + joint_ids = getattr(term_, "_joint_ids", None) # Get default_joint_stiffness (kp gains) - if hasattr(data, 'default_joint_stiffness') and data.default_joint_stiffness is not None: + if hasattr(data, "default_joint_stiffness") and data.default_joint_stiffness is not None: if joint_ids is not None: static_values[f"{term_name}_kp_gains"] = data.default_joint_stiffness[:, joint_ids] else: static_values[f"{term_name}_kp_gains"] = data.default_joint_stiffness # Get default_joint_damping (kd gains) - if hasattr(data, 'default_joint_damping') and data.default_joint_damping is not None: + if hasattr(data, "default_joint_damping") and data.default_joint_damping is not None: if joint_ids is not None: static_values[f"{term_name}_kd_gains"] = data.default_joint_damping[:, joint_ids] else: static_values[f"{term_name}_kd_gains"] = data.default_joint_damping - annotate.output_tensors('action_manager', tensors, static_outputs=static_values, export_with='torch') + annotate.output_tensors("action_manager", tensors, static_outputs=static_values, export_with="torch") ActionManager.process_action = patched_process_action diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py index 316daa175207..a2a6a29dc5dc 100644 --- a/scripts/reinforcement_learning/rsl_rl/export.py +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -1,3 +1,8 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + """Script to export a checkpoint if an RL agent from RSL-RL.""" """Launch Isaac Sim Simulator first.""" @@ -5,12 +10,13 @@ import argparse import sys +from leapp import annotate + from isaaclab.app import AppLauncher # local imports import cli_args # isort: skip -from leapp import annotate # add argparse arguments parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.") @@ -54,6 +60,9 @@ import time import torch +# IMPORTANT: Add leapp annotations BEFORE importing isaaclab_tasks +# This ensures the patched functions are captured when configs are created +from annotate_functions_for_export import add_leapp_annotations from rsl_rl.runners import DistillationRunner, OnPolicyRunner from isaaclab.envs import ( @@ -64,14 +73,10 @@ multi_agent_to_single_agent, ) from isaaclab.utils.assets import retrieve_file_path -from isaaclab.utils.dict import print_dict from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx from isaaclab_rl.utils.pretrained_checkpoint import get_published_pretrained_checkpoint -# IMPORTANT: Add leapp annotations BEFORE importing isaaclab_tasks -# This ensures the patched functions are captured when configs are created -from annotate_functions_for_export import add_leapp_annotations add_leapp_annotations() import isaaclab_tasks # noqa: F401 @@ -175,8 +180,12 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen # run everything in inference mode with torch.inference_mode(): # agent stepping - with annotate.block('policy', inputs=['obs'], outputs=['actions'], - backend_params={'model_path': onnx_path, 'copy_original_model': True}): + with annotate.block( + "policy", + inputs=["obs"], + outputs=["actions"], + backend_params={"model_path": onnx_path, "copy_original_model": True}, + ): actions = policy(obs) # env stepping obs, _, _, _ = env.step(actions) From a69f55f612c96380d820847263350c365bc86473 Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Wed, 14 Jan 2026 17:51:20 -0800 Subject: [PATCH 04/23] added static buffers to graph. we can now trace class tensor variables that copy data using [:] --- .../rsl_rl/annotate_functions_for_export.py | 60 +++++++++++++++++-- .../envs/mdp/actions/joint_actions.py | 2 +- 2 files changed, 56 insertions(+), 6 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py index edf8a7c859f7..9714e3044504 100644 --- a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py +++ b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py @@ -9,6 +9,7 @@ from leapp.leapp_graph.traced_tensor import TracedTensor from isaaclab.assets.articulation.articulation_data import ArticulationData +from isaaclab.controllers.operational_space import OperationalSpaceController from isaaclab.envs.mdp import observations from isaaclab.managers.action_manager import ActionManager from isaaclab.managers.observation_manager import ObservationManager @@ -44,7 +45,6 @@ def _setup_articulation_data_annotations(): # Joint state "joint_pos", # joint_pos, joint_pos_rel, joint_pos_limit_normalized "joint_vel", # joint_vel, joint_vel_rel - "applied_torque", # joint_effort } for prop_name in observation_properties: @@ -162,17 +162,57 @@ def annotate_action_manager(): """ Patches ActionManager.process_action to annotate action inputs/outputs. - Also collects static values (default_joint_stiffness and default_joint_damping) - from action terms that have them. + Also patches OperationalSpaceController.set_command for variable impedance tracing. + + Collects static values (default_joint_stiffness and default_joint_damping) + from action terms that have them. For variable impedance controllers, the gains + are captured as dynamic outputs instead. IMPORTANT: Must be called BEFORE isaaclab_tasks is imported. """ + # Patch OperationalSpaceController.set_command for variable impedance modes + # This must be done before ActionManager.process_action is patched + original_osc_set_command = OperationalSpaceController.set_command + + def patched_osc_set_command( + self, + command: torch.Tensor, + current_ee_pose_b: torch.Tensor | None = None, + current_task_frame_pose_b: torch.Tensor | None = None, + ): + # For variable impedance modes, register gain buffers before in-place assignment + if self.cfg.impedance_mode in ["variable_kp", "variable"]: + # Register the gain tensors as buffers so in-place assignment is traced + self._motion_p_gains_task, self._motion_d_gains_task = annotate.register_buffer( + "action_manager", + { + "motion_p_gains_task": self._motion_p_gains_task, + "motion_d_gains_task": self._motion_d_gains_task, + }, + ) + + # Call original set_command - in-place assignments will now be traced + return original_osc_set_command(self, command, current_ee_pose_b, current_task_frame_pose_b) + + OperationalSpaceController.set_command = patched_osc_set_command + # Patch ActionManager.process_action at class level original_process_action = ActionManager.process_action def patched_process_action(self, action: torch.Tensor): action = annotate.input_tensors({"actions": action}, node_name="action_manager") + + # Register _raw_actions buffers for each action term before processing + # This enables tracing through in-place assignments like: self._raw_actions[:] = actions + for term_name, term_ in self._terms.items(): + if hasattr(term_, "_raw_actions") and term_._raw_actions is not None: + buffers = annotate.register_buffer( + "action_manager", + {"raw_actions": term_._raw_actions}, + ) + term_._raw_actions = buffers["raw_actions"] + original_process_action(self, action) annotate.mirror_leapp_tags(action, self._action) tensors = {} @@ -181,6 +221,16 @@ def patched_process_action(self, action: torch.Tensor): for term_name, term_ in self._terms.items(): tensors[term_name] = term_.processed_actions + # Check for dynamic gains from OperationalSpaceControllerAction + osc = getattr(term_, "_osc", None) + if osc is not None and hasattr(osc, "cfg"): + if osc.cfg.impedance_mode in ["variable", "variable_kp"]: + # Dynamic gains - these are now traced due to register_buffer + # Extract diagonal elements (the actual gain values) + tensors[f"{term_name}_kp_gains"] = torch.diagonal(osc._motion_p_gains_task, dim1=-2, dim2=-1) + tensors[f"{term_name}_kd_gains"] = torch.diagonal(osc._motion_d_gains_task, dim1=-2, dim2=-1) + continue # Skip static gain collection for this term + # Collect static values (kp/kd gains) if available asset = getattr(term_, "_asset", None) if asset is not None and hasattr(asset, "data"): @@ -205,7 +255,7 @@ def patched_process_action(self, action: torch.Tensor): ActionManager.process_action = patched_process_action - print("Patched action manager: ActionManager.process_action") + print("Patched action manager: ActionManager.process_action, OperationalSpaceController.set_command") def add_leapp_annotations(): @@ -214,7 +264,7 @@ def add_leapp_annotations(): This is the main entry point that patches: - ObservationManager and related observation functions - - ActionManager.process_action + - ActionManager.process_action (includes OperationalSpaceController.set_command) IMPORTANT: Must be called BEFORE isaaclab_tasks is imported. """ diff --git a/source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py b/source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py index dbddeb8bc3a5..13e062119a75 100644 --- a/source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py +++ b/source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py @@ -169,7 +169,7 @@ def process_actions(self, actions: torch.Tensor): # store the raw actions self._raw_actions[:] = actions # apply the affine transformations - self._processed_actions = actions * self._scale + self._offset + self._processed_actions = self._raw_actions[:] * self._scale + self._offset # clip actions if self.cfg.clip is not None: self._processed_actions = torch.clamp( From 74dbb8fdaab0ae996b0a82c1e36e8932e3b7973f Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Thu, 15 Jan 2026 13:30:04 -0800 Subject: [PATCH 05/23] removed random observation corruption for during export --- .../rsl_rl/annotate_functions_for_export.py | 40 +++++++++++++++++++ .../reinforcement_learning/rsl_rl/export.py | 2 +- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py index 9714e3044504..d9cdb942fd76 100644 --- a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py +++ b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py @@ -97,6 +97,44 @@ def _remove_articulation_annotations(): setattr(ArticulationData, prop_name, original_prop) +def configure_for_export(): + """ + Configures the environment managers for deterministic export. + + This patches ObservationManager to disable noise/corruption during export. + Random operations like torch.rand_like in noise models cause validation + failures because they produce different values on each run. + + IMPORTANT: Must be called BEFORE isaaclab_tasks is imported and before + annotate_observation_manager() and annotate_action_manager(). + """ + from isaaclab.managers.manager_term_cfg import ObservationGroupCfg + + # Patch ObservationManager._prepare_terms to force disable noise + # This ensures deterministic outputs for export validation + original_prepare_terms = ObservationManager._prepare_terms + + def patched_prepare_terms(self): + # Force disable corruption on all observation groups before preparing terms + # Iterate over config items the same way _prepare_terms does + if isinstance(self.cfg, dict): + group_cfg_items = self.cfg.items() + else: + group_cfg_items = self.cfg.__dict__.items() + + for group_name, group_cfg in group_cfg_items: + if group_cfg is None: + continue + if isinstance(group_cfg, ObservationGroupCfg): + group_cfg.enable_corruption = False + + # Call original _prepare_terms + return original_prepare_terms(self) + + ObservationManager._prepare_terms = patched_prepare_terms + print("Configured for export: Disabled observation noise/corruption for deterministic export") + + def annotate_observation_manager(): """ Patches observation-related functions and classes to annotate inputs/outputs. @@ -268,6 +306,8 @@ def add_leapp_annotations(): IMPORTANT: Must be called BEFORE isaaclab_tasks is imported. """ + # Configure for deterministic export first (disables noise/corruption) + configure_for_export() annotate_observation_manager() annotate_action_manager() print("All leapp annotations added") diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py index a2a6a29dc5dc..8278a4dad233 100644 --- a/scripts/reinforcement_learning/rsl_rl/export.py +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -191,7 +191,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen obs, _, _, _ = env.step(actions) annotate.stop() - annotate.compile_graph(validate=False) + annotate.compile_graph() # close the simulator env.close() From a83a6004bc6f79680af692f6c7fa85975ccb2f4e Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Thu, 15 Jan 2026 16:41:28 -0800 Subject: [PATCH 06/23] added io descriptor info into leapp generated yaml --- .../rsl_rl/annotate_functions_for_export.py | 23 ++++++++++-- .../reinforcement_learning/rsl_rl/export.py | 37 +++++++++++++++++++ 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py index d9cdb942fd76..70c3d1351e53 100644 --- a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py +++ b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py @@ -3,6 +3,7 @@ # # SPDX-License-Identifier: BSD-3-Clause +import inspect import torch from leapp import annotate @@ -153,19 +154,33 @@ def annotate_observation_manager(): # Patch last_action observation function original_last_action = observations.last_action - def patched_last_action(env, action_name=None): - result = original_last_action(env, action_name) + def patched_last_action(env, action_name=None, **kwargs): + # Pass through kwargs (including 'inspect' for IO descriptors) + result = original_last_action(env, action_name, **kwargs) result = annotate.input_tensors({"last_actions": result}, node_name="observation_manager") return result + # Preserve original signature and IO descriptor to pass manager validation checks + patched_last_action.__signature__ = inspect.signature(original_last_action) + if hasattr(original_last_action, "_descriptor"): + patched_last_action._descriptor = original_last_action._descriptor + patched_last_action._has_descriptor = original_last_action._has_descriptor + # Patch generated_commands observation function original_generated_commands = observations.generated_commands - def patched_generated_commands(env, command_name=None): - result = original_generated_commands(env, command_name) + def patched_generated_commands(env, command_name=None, **kwargs): + # Pass through kwargs (including 'inspect' for IO descriptors) + result = original_generated_commands(env, command_name, **kwargs) result = annotate.input_tensors({"commands": result}, node_name="observation_manager") return result + # Preserve original signature and IO descriptor to pass manager validation checks + patched_generated_commands.__signature__ = inspect.signature(original_generated_commands) + if hasattr(original_generated_commands, "_descriptor"): + patched_generated_commands._descriptor = original_generated_commands._descriptor + patched_generated_commands._has_descriptor = original_generated_commands._has_descriptor + # Apply observation function patches at module level # Note: Observation functions that use ArticulationData properties (base_pos_z, root_pos_w, # root_quat_w, body_projected_gravity_b) don't need patching since the underlying diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py index 8278a4dad233..7ce72004251a 100644 --- a/scripts/reinforcement_learning/rsl_rl/export.py +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -9,6 +9,7 @@ import argparse import sys +import yaml from leapp import annotate @@ -193,6 +194,42 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen annotate.stop() annotate.compile_graph() + outs = env.unwrapped.get_IO_descriptors + # only export the policy observations + out_observations = outs["observations"]["policy"] + out_actions = outs["actions"] + out_scene = outs["scene"] + + observations = [] + for k in out_observations: + observation = { + "name": k["name"], + "full_path": k["full_path"], + } + if "joint_names" in k: + observation["joint_names"] = k["joint_names"] + if "units" in k["extras"]: + observation["units"] = k["extras"]["units"] + observations.append(observation) + + actions = [] + for k in out_actions: + action = { + "name": k["name"], + "full_path": k["full_path"], + } + if "joint_names" in k: + action["joint_names"] = k["joint_names"] + actions.append(action) + semantic = { + "observations": observations, + "actions": actions, + "scene": out_scene, + } + + with open(annotate.config_path, "a") as f: + yaml.dump({"semantic": semantic}, f) + # close the simulator env.close() From d6584290e1cfc3db6accc720954835ac6709acf4 Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Thu, 15 Jan 2026 17:12:30 -0800 Subject: [PATCH 07/23] ugly implementation for automatic mapping from io descriptors to leapp --- .../rsl_rl/annotate_functions_for_export.py | 65 ++++++++++++++++++- .../reinforcement_learning/rsl_rl/export.py | 16 +++-- 2 files changed, 74 insertions(+), 7 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py index 70c3d1351e53..2b60031df0bc 100644 --- a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py +++ b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py @@ -19,6 +19,54 @@ _articulation_data_originals = {} _articulation_data_annotating = {} +# Global mapping built during execution: observation_function_name -> [articulation_property_names] +OBSERVATION_TO_ARTICULATION_MAP: dict[str, list[str]] = {} + + +def _record_articulation_access(observation_name: str, articulation_property: str): + """Record that an observation function accessed an ArticulationData property.""" + if observation_name not in OBSERVATION_TO_ARTICULATION_MAP: + OBSERVATION_TO_ARTICULATION_MAP[observation_name] = [] + if articulation_property not in OBSERVATION_TO_ARTICULATION_MAP[observation_name]: + OBSERVATION_TO_ARTICULATION_MAP[observation_name].append(articulation_property) + + +def get_observation_to_articulation_map() -> dict[str, list[str]]: + """Get a copy of the observation-to-articulation mapping. + + Returns: + A dictionary mapping observation function names to lists of + ArticulationData property names (leapp input names) they access. + + Example: + { + 'base_lin_vel': ['root_lin_vel_b'], + 'joint_pos_rel': ['joint_pos'], + 'last_action': ['last_actions'], + 'generated_commands': ['commands'], + } + """ + return OBSERVATION_TO_ARTICULATION_MAP.copy() + + +def _find_calling_observation_function() -> str | None: + """Walk up the call stack to find the observation function that triggered this access.""" + for frame_info in inspect.stack(): + # Look for frames in the observations module + if "isaaclab/envs/mdp/observations" in frame_info.filename: + func_name = frame_info.function + # Skip internal/wrapper functions + if not func_name.startswith("_"): + return func_name + + # Also check for custom observation functions in user code + # Could look for functions with _has_descriptor attribute + frame_locals = frame_info.frame.f_locals + if "self" in frame_locals and hasattr(frame_locals.get("self"), "_has_descriptor"): + return frame_info.function + + return None + def _setup_articulation_data_annotations(): """ @@ -65,13 +113,19 @@ def _setup_articulation_data_annotations(): # Create annotating getter original_fget = attr.fget - def make_annotating_fget(original, name): + def make_annotating_fget(original, prop_name): """Factory function to properly capture variables in closure.""" def annotating_fget(self): result = original(self) + + # Find which observation function called us and record the mapping + observation_name = _find_calling_observation_function() + if observation_name: + _record_articulation_access(observation_name, prop_name) + if isinstance(result, torch.Tensor): - result = annotate.input_tensors({name: result}, node_name="observation_manager") + result = annotate.input_tensors({prop_name: result}, node_name="observation_manager") return result return annotating_fget @@ -157,6 +211,8 @@ def annotate_observation_manager(): def patched_last_action(env, action_name=None, **kwargs): # Pass through kwargs (including 'inspect' for IO descriptors) result = original_last_action(env, action_name, **kwargs) + # Record the mapping for this custom observation + _record_articulation_access("last_action", "last_actions") result = annotate.input_tensors({"last_actions": result}, node_name="observation_manager") return result @@ -172,7 +228,10 @@ def patched_last_action(env, action_name=None, **kwargs): def patched_generated_commands(env, command_name=None, **kwargs): # Pass through kwargs (including 'inspect' for IO descriptors) result = original_generated_commands(env, command_name, **kwargs) - result = annotate.input_tensors({"commands": result}, node_name="observation_manager") + # Record the mapping for this custom observation (use command_name as the leapp input name) + leapp_input_name = command_name if command_name else "commands" + _record_articulation_access("generated_commands", leapp_input_name) + result = annotate.input_tensors({leapp_input_name: result}, node_name="observation_manager") return result # Preserve original signature and IO descriptor to pass manager validation checks diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py index 7ce72004251a..4f21df94d076 100644 --- a/scripts/reinforcement_learning/rsl_rl/export.py +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -63,7 +63,7 @@ # IMPORTANT: Add leapp annotations BEFORE importing isaaclab_tasks # This ensures the patched functions are captured when configs are created -from annotate_functions_for_export import add_leapp_annotations +from annotate_functions_for_export import add_leapp_annotations, get_observation_to_articulation_map from rsl_rl.runners import DistillationRunner, OnPolicyRunner from isaaclab.envs import ( @@ -200,12 +200,18 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen out_actions = outs["actions"] out_scene = outs["scene"] + # Get the auto-discovered mapping from observation functions to leapp inputs + obs_to_leapp_map = get_observation_to_articulation_map() + observations = [] for k in out_observations: + obs_name = k["name"] observation = { - "name": k["name"], - "full_path": k["full_path"], + "name": obs_name, } + # Add the leapp input names this observation maps to (copy list to avoid YAML anchors) + if obs_name in obs_to_leapp_map: + observation["leapp_inputs"] = list(obs_to_leapp_map[obs_name]) if "joint_names" in k: observation["joint_names"] = k["joint_names"] if "units" in k["extras"]: @@ -216,11 +222,13 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen for k in out_actions: action = { "name": k["name"], - "full_path": k["full_path"], } if "joint_names" in k: action["joint_names"] = k["joint_names"] + if "units" in k["extras"]: + observation["units"] = k["extras"]["units"] actions.append(action) + semantic = { "observations": observations, "actions": actions, From 7a1a0462f20121152404e0b24ef8a9d2055c108c Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Fri, 16 Jan 2026 11:05:27 -0800 Subject: [PATCH 08/23] added action mapping for io_descriptor output --- .../rsl_rl/annotate_functions_for_export.py | 51 +++++++++++++++++-- .../reinforcement_learning/rsl_rl/export.py | 15 ++++-- 2 files changed, 57 insertions(+), 9 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py index 2b60031df0bc..e335f1220b4e 100644 --- a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py +++ b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py @@ -20,15 +20,28 @@ _articulation_data_annotating = {} # Global mapping built during execution: observation_function_name -> [articulation_property_names] -OBSERVATION_TO_ARTICULATION_MAP: dict[str, list[str]] = {} +OBSERVATION_TO_ARTICULATION_MAP: dict[str, set[str]] = {} + +# Global mapping built during execution: io_descriptor_name -> term_name (leapp output name) +# Maps the semantic action name (from IO descriptor) to the action_manager output name +ACTION_IO_TO_TERM_MAP: dict[str, str] = {} def _record_articulation_access(observation_name: str, articulation_property: str): """Record that an observation function accessed an ArticulationData property.""" if observation_name not in OBSERVATION_TO_ARTICULATION_MAP: - OBSERVATION_TO_ARTICULATION_MAP[observation_name] = [] - if articulation_property not in OBSERVATION_TO_ARTICULATION_MAP[observation_name]: - OBSERVATION_TO_ARTICULATION_MAP[observation_name].append(articulation_property) + OBSERVATION_TO_ARTICULATION_MAP[observation_name] = set() + OBSERVATION_TO_ARTICULATION_MAP[observation_name].add(articulation_property) + + +def _record_action_io_mapping(io_descriptor_name: str, term_name: str): + """Record the mapping from IO descriptor name to action term name. + + Args: + io_descriptor_name: The name from the action's IO descriptor (e.g., 'joint_position_action') + term_name: The action term name used in action_manager outputs (e.g., 'arm_action') + """ + ACTION_IO_TO_TERM_MAP[io_descriptor_name] = [term_name] def get_observation_to_articulation_map() -> dict[str, list[str]]: @@ -46,7 +59,25 @@ def get_observation_to_articulation_map() -> dict[str, list[str]]: 'generated_commands': ['commands'], } """ - return OBSERVATION_TO_ARTICULATION_MAP.copy() + return {k: list(v) for k, v in OBSERVATION_TO_ARTICULATION_MAP.items()} + + +def get_action_io_to_term_map() -> dict[str, str]: + """Get a copy of the action IO descriptor to term name mapping. + + Returns: + A dictionary mapping action IO descriptor names to action term names + (which are used as leapp output names in action_manager). + + Example: + { + 'joint_position_action': 'arm_action', + } + + This allows the semantic.actions section in the YAML to map: + - name: joint_position_action -> leapp_outputs: [arm_action] + """ + return ACTION_IO_TO_TERM_MAP.copy() def _find_calling_observation_function() -> str | None: @@ -325,7 +356,9 @@ def patched_process_action(self, action: torch.Tensor): ) term_._raw_actions = buffers["raw_actions"] + # run the original process action original_process_action(self, action) + annotate.mirror_leapp_tags(action, self._action) tensors = {} static_values = {} @@ -333,6 +366,14 @@ def patched_process_action(self, action: torch.Tensor): for term_name, term_ in self._terms.items(): tensors[term_name] = term_.processed_actions + # Record the mapping: IO descriptor name -> term name (leapp output) + # e.g., 'joint_position_action' -> 'arm_action' + try: + io_descriptor_name = term_.IO_descriptor.name + _record_action_io_mapping(io_descriptor_name, term_name) + except Exception: + pass # Skip if IO descriptor is not available + # Check for dynamic gains from OperationalSpaceControllerAction osc = getattr(term_, "_osc", None) if osc is not None and hasattr(osc, "cfg"): diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py index 4f21df94d076..3c75e93ca104 100644 --- a/scripts/reinforcement_learning/rsl_rl/export.py +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -63,7 +63,11 @@ # IMPORTANT: Add leapp annotations BEFORE importing isaaclab_tasks # This ensures the patched functions are captured when configs are created -from annotate_functions_for_export import add_leapp_annotations, get_observation_to_articulation_map +from annotate_functions_for_export import ( + add_leapp_annotations, + get_action_io_to_term_map, + get_observation_to_articulation_map, +) from rsl_rl.runners import DistillationRunner, OnPolicyRunner from isaaclab.envs import ( @@ -78,8 +82,6 @@ from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx from isaaclab_rl.utils.pretrained_checkpoint import get_published_pretrained_checkpoint -add_leapp_annotations() - import isaaclab_tasks # noqa: F401 from isaaclab_tasks.utils import get_checkpoint_path from isaaclab_tasks.utils.hydra import hydra_task_config @@ -88,6 +90,7 @@ @hydra_task_config(args_cli.task, args_cli.agent) def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg): """Export a RSL-RL agent.""" + add_leapp_annotations() # grab task name for checkpoint path task_name = args_cli.task.split(":")[-1] train_task_name = task_name.replace("-Play", "") @@ -202,6 +205,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen # Get the auto-discovered mapping from observation functions to leapp inputs obs_to_leapp_map = get_observation_to_articulation_map() + action_to_leapp_map = get_action_io_to_term_map() observations = [] for k in out_observations: @@ -220,9 +224,12 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen actions = [] for k in out_actions: + action_name = k["name"] action = { - "name": k["name"], + "name": action_name, } + if action_name in action_to_leapp_map: + action["leapp_inputs"] = list(action_to_leapp_map[action_name]) if "joint_names" in k: action["joint_names"] = k["joint_names"] if "units" in k["extras"]: From cc69ff0e98c3e335569d358c41fc54d9a01b2829 Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Fri, 16 Jan 2026 13:37:39 -0800 Subject: [PATCH 09/23] cleaned up utility to annotate environment, added protection for using the same observation multiple times --- .../rsl_rl/annotate_functions_for_export.py | 428 ------------------ .../reinforcement_learning/rsl_rl/export.py | 64 +-- .../rsl_rl/export_annotator.py | 415 +++++++++++++++++ 3 files changed, 427 insertions(+), 480 deletions(-) delete mode 100644 scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py create mode 100644 scripts/reinforcement_learning/rsl_rl/export_annotator.py diff --git a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py b/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py deleted file mode 100644 index e335f1220b4e..000000000000 --- a/scripts/reinforcement_learning/rsl_rl/annotate_functions_for_export.py +++ /dev/null @@ -1,428 +0,0 @@ -# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: BSD-3-Clause - -import inspect -import torch - -from leapp import annotate -from leapp.leapp_graph.traced_tensor import TracedTensor - -from isaaclab.assets.articulation.articulation_data import ArticulationData -from isaaclab.controllers.operational_space import OperationalSpaceController -from isaaclab.envs.mdp import observations -from isaaclab.managers.action_manager import ActionManager -from isaaclab.managers.observation_manager import ObservationManager - -# Global storage for original and annotating ArticulationData properties -_articulation_data_originals = {} -_articulation_data_annotating = {} - -# Global mapping built during execution: observation_function_name -> [articulation_property_names] -OBSERVATION_TO_ARTICULATION_MAP: dict[str, set[str]] = {} - -# Global mapping built during execution: io_descriptor_name -> term_name (leapp output name) -# Maps the semantic action name (from IO descriptor) to the action_manager output name -ACTION_IO_TO_TERM_MAP: dict[str, str] = {} - - -def _record_articulation_access(observation_name: str, articulation_property: str): - """Record that an observation function accessed an ArticulationData property.""" - if observation_name not in OBSERVATION_TO_ARTICULATION_MAP: - OBSERVATION_TO_ARTICULATION_MAP[observation_name] = set() - OBSERVATION_TO_ARTICULATION_MAP[observation_name].add(articulation_property) - - -def _record_action_io_mapping(io_descriptor_name: str, term_name: str): - """Record the mapping from IO descriptor name to action term name. - - Args: - io_descriptor_name: The name from the action's IO descriptor (e.g., 'joint_position_action') - term_name: The action term name used in action_manager outputs (e.g., 'arm_action') - """ - ACTION_IO_TO_TERM_MAP[io_descriptor_name] = [term_name] - - -def get_observation_to_articulation_map() -> dict[str, list[str]]: - """Get a copy of the observation-to-articulation mapping. - - Returns: - A dictionary mapping observation function names to lists of - ArticulationData property names (leapp input names) they access. - - Example: - { - 'base_lin_vel': ['root_lin_vel_b'], - 'joint_pos_rel': ['joint_pos'], - 'last_action': ['last_actions'], - 'generated_commands': ['commands'], - } - """ - return {k: list(v) for k, v in OBSERVATION_TO_ARTICULATION_MAP.items()} - - -def get_action_io_to_term_map() -> dict[str, str]: - """Get a copy of the action IO descriptor to term name mapping. - - Returns: - A dictionary mapping action IO descriptor names to action term names - (which are used as leapp output names in action_manager). - - Example: - { - 'joint_position_action': 'arm_action', - } - - This allows the semantic.actions section in the YAML to map: - - name: joint_position_action -> leapp_outputs: [arm_action] - """ - return ACTION_IO_TO_TERM_MAP.copy() - - -def _find_calling_observation_function() -> str | None: - """Walk up the call stack to find the observation function that triggered this access.""" - for frame_info in inspect.stack(): - # Look for frames in the observations module - if "isaaclab/envs/mdp/observations" in frame_info.filename: - func_name = frame_info.function - # Skip internal/wrapper functions - if not func_name.startswith("_"): - return func_name - - # Also check for custom observation functions in user code - # Could look for functions with _has_descriptor attribute - frame_locals = frame_info.frame.f_locals - if "self" in frame_locals and hasattr(frame_locals.get("self"), "_has_descriptor"): - return frame_info.function - - return None - - -def _setup_articulation_data_annotations(): - """ - Prepares annotating versions of ArticulationData properties without applying them. - - The annotations will be temporarily applied only during compute_group calls, - avoiding conflicts with rewards, terminations, commands, and actions that also - access these properties. - """ - - # All observation properties - we can include all of them now since annotations - # are only active during compute_group - observation_properties = { - # Root state (position, orientation, velocities) - "root_pos_w", # base_pos_z, root_pos_w - "root_quat_w", # root_quat_w - "root_lin_vel_b", # base_lin_vel - "root_ang_vel_b", # base_ang_vel - "root_lin_vel_w", # root_lin_vel_w - "root_ang_vel_w", # root_ang_vel_w - "projected_gravity_b", # 'projected_gravity_b', - # Body state - "body_pose_w", # body_pose_w - "body_quat_w", # body_projected_gravity_b - # Joint state - "joint_pos", # joint_pos, joint_pos_rel, joint_pos_limit_normalized - "joint_vel", # joint_vel, joint_vel_rel - } - - for prop_name in observation_properties: - attr = getattr(ArticulationData, prop_name, None) - - # Skip if attribute doesn't exist or isn't a property - if attr is None or not isinstance(attr, property): - raise ValueError(f"Attribute {prop_name} does not exist or is not a property") - - # Skip properties without a getter - if attr.fget is None: - raise ValueError(f"Attribute {prop_name} does not have a getter") - - # Store the original property - _articulation_data_originals[prop_name] = attr - - # Create annotating getter - original_fget = attr.fget - - def make_annotating_fget(original, prop_name): - """Factory function to properly capture variables in closure.""" - - def annotating_fget(self): - result = original(self) - - # Find which observation function called us and record the mapping - observation_name = _find_calling_observation_function() - if observation_name: - _record_articulation_access(observation_name, prop_name) - - if isinstance(result, torch.Tensor): - result = annotate.input_tensors({prop_name: result}, node_name="observation_manager") - return result - - return annotating_fget - - annotating_fget = make_annotating_fget(original_fget, prop_name) - annotating_fget.__doc__ = original_fget.__doc__ - - # Create annotating property - annotating_property = property(fget=annotating_fget, fset=attr.fset, fdel=attr.fdel, doc=attr.__doc__) - _articulation_data_annotating[prop_name] = annotating_property - - print(f"Prepared {len(_articulation_data_originals)} ArticulationData properties for temporary annotation") - - -def _apply_articulation_annotations(): - """Temporarily applies annotating versions of ArticulationData properties.""" - for prop_name, annotating_prop in _articulation_data_annotating.items(): - setattr(ArticulationData, prop_name, annotating_prop) - - -def _remove_articulation_annotations(): - """Restores original ArticulationData properties.""" - for prop_name, original_prop in _articulation_data_originals.items(): - setattr(ArticulationData, prop_name, original_prop) - - -def configure_for_export(): - """ - Configures the environment managers for deterministic export. - - This patches ObservationManager to disable noise/corruption during export. - Random operations like torch.rand_like in noise models cause validation - failures because they produce different values on each run. - - IMPORTANT: Must be called BEFORE isaaclab_tasks is imported and before - annotate_observation_manager() and annotate_action_manager(). - """ - from isaaclab.managers.manager_term_cfg import ObservationGroupCfg - - # Patch ObservationManager._prepare_terms to force disable noise - # This ensures deterministic outputs for export validation - original_prepare_terms = ObservationManager._prepare_terms - - def patched_prepare_terms(self): - # Force disable corruption on all observation groups before preparing terms - # Iterate over config items the same way _prepare_terms does - if isinstance(self.cfg, dict): - group_cfg_items = self.cfg.items() - else: - group_cfg_items = self.cfg.__dict__.items() - - for group_name, group_cfg in group_cfg_items: - if group_cfg is None: - continue - if isinstance(group_cfg, ObservationGroupCfg): - group_cfg.enable_corruption = False - - # Call original _prepare_terms - return original_prepare_terms(self) - - ObservationManager._prepare_terms = patched_prepare_terms - print("Configured for export: Disabled observation noise/corruption for deterministic export") - - -def annotate_observation_manager(): - """ - Patches observation-related functions and classes to annotate inputs/outputs. - - This patches: - - ArticulationData properties (temporarily, only during compute_group) - - Observation functions at the module level (last_action, generated_commands, etc.) - - ObservationManager.compute_group to annotate outputs - - IMPORTANT: Must be called BEFORE isaaclab_tasks is imported. - """ - - # Prepare (but don't apply) ArticulationData annotations - _setup_articulation_data_annotations() - - # Patch last_action observation function - original_last_action = observations.last_action - - def patched_last_action(env, action_name=None, **kwargs): - # Pass through kwargs (including 'inspect' for IO descriptors) - result = original_last_action(env, action_name, **kwargs) - # Record the mapping for this custom observation - _record_articulation_access("last_action", "last_actions") - result = annotate.input_tensors({"last_actions": result}, node_name="observation_manager") - return result - - # Preserve original signature and IO descriptor to pass manager validation checks - patched_last_action.__signature__ = inspect.signature(original_last_action) - if hasattr(original_last_action, "_descriptor"): - patched_last_action._descriptor = original_last_action._descriptor - patched_last_action._has_descriptor = original_last_action._has_descriptor - - # Patch generated_commands observation function - original_generated_commands = observations.generated_commands - - def patched_generated_commands(env, command_name=None, **kwargs): - # Pass through kwargs (including 'inspect' for IO descriptors) - result = original_generated_commands(env, command_name, **kwargs) - # Record the mapping for this custom observation (use command_name as the leapp input name) - leapp_input_name = command_name if command_name else "commands" - _record_articulation_access("generated_commands", leapp_input_name) - result = annotate.input_tensors({leapp_input_name: result}, node_name="observation_manager") - return result - - # Preserve original signature and IO descriptor to pass manager validation checks - patched_generated_commands.__signature__ = inspect.signature(original_generated_commands) - if hasattr(original_generated_commands, "_descriptor"): - patched_generated_commands._descriptor = original_generated_commands._descriptor - patched_generated_commands._has_descriptor = original_generated_commands._has_descriptor - - # Apply observation function patches at module level - # Note: Observation functions that use ArticulationData properties (base_pos_z, root_pos_w, - # root_quat_w, body_projected_gravity_b) don't need patching since the underlying - # ArticulationData properties are temporarily annotated during compute_group. - observations.last_action = patched_last_action - observations.generated_commands = patched_generated_commands - - # Patch ObservationManager.compute_group to: - # 1. Temporarily apply ArticulationData annotations before computing - # 2. Annotate outputs - # 3. Restore original ArticulationData properties after computing - original_compute_group = ObservationManager.compute_group - - def patched_compute_group(self, *args, **kwargs): - # Apply ArticulationData annotations only during observation computation - _apply_articulation_annotations() - try: - output = original_compute_group(self, *args, **kwargs) - annotate.output_tensors("observation_manager", output, export_with="torch", use_trace=True) - if isinstance(output, TracedTensor): - return output.tensor - else: - return output - finally: - # Always restore original properties, even if an exception occurs - _remove_articulation_annotations() - - ObservationManager.compute_group = patched_compute_group - - -def annotate_action_manager(): - """ - Patches ActionManager.process_action to annotate action inputs/outputs. - - Also patches OperationalSpaceController.set_command for variable impedance tracing. - - Collects static values (default_joint_stiffness and default_joint_damping) - from action terms that have them. For variable impedance controllers, the gains - are captured as dynamic outputs instead. - - IMPORTANT: Must be called BEFORE isaaclab_tasks is imported. - """ - - # Patch OperationalSpaceController.set_command for variable impedance modes - # This must be done before ActionManager.process_action is patched - original_osc_set_command = OperationalSpaceController.set_command - - def patched_osc_set_command( - self, - command: torch.Tensor, - current_ee_pose_b: torch.Tensor | None = None, - current_task_frame_pose_b: torch.Tensor | None = None, - ): - # For variable impedance modes, register gain buffers before in-place assignment - if self.cfg.impedance_mode in ["variable_kp", "variable"]: - # Register the gain tensors as buffers so in-place assignment is traced - self._motion_p_gains_task, self._motion_d_gains_task = annotate.register_buffer( - "action_manager", - { - "motion_p_gains_task": self._motion_p_gains_task, - "motion_d_gains_task": self._motion_d_gains_task, - }, - ) - - # Call original set_command - in-place assignments will now be traced - return original_osc_set_command(self, command, current_ee_pose_b, current_task_frame_pose_b) - - OperationalSpaceController.set_command = patched_osc_set_command - - # Patch ActionManager.process_action at class level - original_process_action = ActionManager.process_action - - def patched_process_action(self, action: torch.Tensor): - action = annotate.input_tensors({"actions": action}, node_name="action_manager") - - # Register _raw_actions buffers for each action term before processing - # This enables tracing through in-place assignments like: self._raw_actions[:] = actions - for term_name, term_ in self._terms.items(): - if hasattr(term_, "_raw_actions") and term_._raw_actions is not None: - buffers = annotate.register_buffer( - "action_manager", - {"raw_actions": term_._raw_actions}, - ) - term_._raw_actions = buffers["raw_actions"] - - # run the original process action - original_process_action(self, action) - - annotate.mirror_leapp_tags(action, self._action) - tensors = {} - static_values = {} - - for term_name, term_ in self._terms.items(): - tensors[term_name] = term_.processed_actions - - # Record the mapping: IO descriptor name -> term name (leapp output) - # e.g., 'joint_position_action' -> 'arm_action' - try: - io_descriptor_name = term_.IO_descriptor.name - _record_action_io_mapping(io_descriptor_name, term_name) - except Exception: - pass # Skip if IO descriptor is not available - - # Check for dynamic gains from OperationalSpaceControllerAction - osc = getattr(term_, "_osc", None) - if osc is not None and hasattr(osc, "cfg"): - if osc.cfg.impedance_mode in ["variable", "variable_kp"]: - # Dynamic gains - these are now traced due to register_buffer - # Extract diagonal elements (the actual gain values) - tensors[f"{term_name}_kp_gains"] = torch.diagonal(osc._motion_p_gains_task, dim1=-2, dim2=-1) - tensors[f"{term_name}_kd_gains"] = torch.diagonal(osc._motion_d_gains_task, dim1=-2, dim2=-1) - continue # Skip static gain collection for this term - - # Collect static values (kp/kd gains) if available - asset = getattr(term_, "_asset", None) - if asset is not None and hasattr(asset, "data"): - data = asset.data - joint_ids = getattr(term_, "_joint_ids", None) - - # Get default_joint_stiffness (kp gains) - if hasattr(data, "default_joint_stiffness") and data.default_joint_stiffness is not None: - if joint_ids is not None: - static_values[f"{term_name}_kp_gains"] = data.default_joint_stiffness[:, joint_ids] - else: - static_values[f"{term_name}_kp_gains"] = data.default_joint_stiffness - - # Get default_joint_damping (kd gains) - if hasattr(data, "default_joint_damping") and data.default_joint_damping is not None: - if joint_ids is not None: - static_values[f"{term_name}_kd_gains"] = data.default_joint_damping[:, joint_ids] - else: - static_values[f"{term_name}_kd_gains"] = data.default_joint_damping - - annotate.output_tensors("action_manager", tensors, static_outputs=static_values, export_with="torch") - - ActionManager.process_action = patched_process_action - - print("Patched action manager: ActionManager.process_action, OperationalSpaceController.set_command") - - -def add_leapp_annotations(): - """ - Adds all leapp annotations for exporting Isaac Lab policies. - - This is the main entry point that patches: - - ObservationManager and related observation functions - - ActionManager.process_action (includes OperationalSpaceController.set_command) - - IMPORTANT: Must be called BEFORE isaaclab_tasks is imported. - """ - # Configure for deterministic export first (disables noise/corruption) - configure_for_export() - annotate_observation_manager() - annotate_action_manager() - print("All leapp annotations added") diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py index 3c75e93ca104..c3ac56cda278 100644 --- a/scripts/reinforcement_learning/rsl_rl/export.py +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -61,13 +61,7 @@ import time import torch -# IMPORTANT: Add leapp annotations BEFORE importing isaaclab_tasks -# This ensures the patched functions are captured when configs are created -from annotate_functions_for_export import ( - add_leapp_annotations, - get_action_io_to_term_map, - get_observation_to_articulation_map, -) +from export_annotator import ExportAnnotator from rsl_rl.runners import DistillationRunner, OnPolicyRunner from isaaclab.envs import ( @@ -86,11 +80,18 @@ from isaaclab_tasks.utils import get_checkpoint_path from isaaclab_tasks.utils.hydra import hydra_task_config +# IMPORTANT: Add leapp annotations BEFORE importing isaaclab_tasks +# This ensures the patched functions are captured when configs are created +# from annotate_functions_for_export import ( +# add_leapp_annotations, +# get_action_io_to_term_map, +# get_observation_to_articulation_map, +# ) + @hydra_task_config(args_cli.task, args_cli.agent) def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg): """Export a RSL-RL agent.""" - add_leapp_annotations() # grab task name for checkpoint path task_name = args_cli.task.split(":")[-1] train_task_name = task_name.replace("-Play", "") @@ -126,6 +127,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen # create isaac environment # Note: observation functions are already patched at module level (before isaaclab_tasks import) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) + annotator = ExportAnnotator(env) + annotator.setup() # convert to single-agent instance if required by the RL algorithm if isinstance(env.unwrapped, DirectMARLEnv): @@ -197,50 +200,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen annotate.stop() annotate.compile_graph() - outs = env.unwrapped.get_IO_descriptors - # only export the policy observations - out_observations = outs["observations"]["policy"] - out_actions = outs["actions"] - out_scene = outs["scene"] - - # Get the auto-discovered mapping from observation functions to leapp inputs - obs_to_leapp_map = get_observation_to_articulation_map() - action_to_leapp_map = get_action_io_to_term_map() - - observations = [] - for k in out_observations: - obs_name = k["name"] - observation = { - "name": obs_name, - } - # Add the leapp input names this observation maps to (copy list to avoid YAML anchors) - if obs_name in obs_to_leapp_map: - observation["leapp_inputs"] = list(obs_to_leapp_map[obs_name]) - if "joint_names" in k: - observation["joint_names"] = k["joint_names"] - if "units" in k["extras"]: - observation["units"] = k["extras"]["units"] - observations.append(observation) - - actions = [] - for k in out_actions: - action_name = k["name"] - action = { - "name": action_name, - } - if action_name in action_to_leapp_map: - action["leapp_inputs"] = list(action_to_leapp_map[action_name]) - if "joint_names" in k: - action["joint_names"] = k["joint_names"] - if "units" in k["extras"]: - observation["units"] = k["extras"]["units"] - actions.append(action) - - semantic = { - "observations": observations, - "actions": actions, - "scene": out_scene, - } + semantic = annotator.get_semantic with open(annotate.config_path, "a") as f: yaml.dump({"semantic": semantic}, f) diff --git a/scripts/reinforcement_learning/rsl_rl/export_annotator.py b/scripts/reinforcement_learning/rsl_rl/export_annotator.py new file mode 100644 index 000000000000..6541a60a842b --- /dev/null +++ b/scripts/reinforcement_learning/rsl_rl/export_annotator.py @@ -0,0 +1,415 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +# Copyright (c) 2022-2026, The Isaac Lab Project Developers +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + + +"""Export annotations for Isaac Lab policies using instance-level patching.""" + + +from __future__ import annotations + +import inspect +import torch +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from leapp import annotate +from leapp.leapp_graph.traced_tensor import TracedTensor + +from isaaclab.assets.articulation.articulation_data import ArticulationData + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedEnv + + +@dataclass +class ExportAnnotator: + """Encapsulates all leapp annotation logic for exporting Isaac Lab policies. + + Usage: + env = gym.make(...) + annotator = ExportAnnotator(env) + annotator.setup() + # ... run policy ... + obs_map = annotator.observation_to_articulation_map + action_map = annotator.action_io_to_term_map + annotator.cleanup() + """ + + env: ManagerBasedEnv + + io_descriptor_observations: list[Any] = field(default_factory=list) + io_descriptor_actions: list[Any] = field(default_factory=list) + io_descriptor_scene: dict[str, Any] = field(default_factory=dict) + + # Mappings built during execution + observation_to_articulation_map: dict[str, set[str]] = field(default_factory=dict) + action_io_to_term_map: dict[str, list[str]] = field(default_factory=dict) + + # Original methods for restoration + _original_compute_group: callable = field(default=None, repr=False) + _original_process_action: callable = field(default=None, repr=False) + + # ArticulationData patching state + _articulation_originals: dict[str, property] = field(default_factory=dict, repr=False) + _articulation_annotating: dict[str, property] = field(default_factory=dict, repr=False) + _annotations_active: bool = field(default=False, repr=False) + # Cache for annotated tensors within a single compute_group call + # Prevents duplicate input tensors when same property is accessed multiple times + _annotated_tensor_cache: dict[str, torch.Tensor] = field(default_factory=dict, repr=False) + + # Articulation properties to annotate + OBSERVATION_PROPERTIES: frozenset[str] = frozenset({ + "root_pos_w", + "root_quat_w", + "root_lin_vel_b", + "root_ang_vel_b", + "root_lin_vel_w", + "root_ang_vel_w", + "projected_gravity_b", + "body_pose_w", + "body_quat_w", + "joint_pos", + "joint_vel", + }) + + def setup(self): + """Set up all annotations. Call after env is created.""" + self._collect_io_descriptors() + self._disable_observation_noise() + self._prepare_articulation_annotations() + self._patch_observation_functions() + self._patch_observation_manager() + self._patch_action_manager() + + def cleanup(self): + """Restore all original methods and properties.""" + self._restore_observation_functions() + self._restore_observation_manager() + self._restore_action_manager() + self._remove_articulation_annotations() + + # ────────────────────────────────────────────────────────────────── + # IO Descriptor Collection (before patching) + # ────────────────────────────────────────────────────────────────── + + def _collect_io_descriptors(self): + outs = self.env.unwrapped.get_IO_descriptors + self.io_descriptor_observations = outs["observations"]["policy"] + self.io_descriptor_actions = outs["actions"] + self.io_descriptor_scene = outs["scene"] + + # Build action IO descriptor name -> term name mapping + # e.g., 'joint_position_action' -> ['arm_action'] + action_manager = self.env.env.unwrapped.action_manager + for term_name, term in action_manager._terms.items(): + try: + io_name = term.IO_descriptor.name + self.action_io_to_term_map[io_name] = [term_name] + except Exception: + pass # Skip if IO descriptor not available + + # ────────────────────────────────────────────────────────────────── + # Observation Manager + # ────────────────────────────────────────────────────────────────── + + def _disable_observation_noise(self): + """Disable noise/corruption for deterministic export. + + Since we patch after env creation, we need to set term_cfg.noise = None + directly on each term config (not just the group config). + """ + obs_manager = self.env.env.unwrapped.observation_manager + + # Disable noise on each individual term config + for _, term_cfgs in obs_manager._group_obs_term_cfgs.items(): + for term_cfg in term_cfgs: + term_cfg.noise = None + + def _patch_observation_functions(self): + """Patch observation functions inside the observation manager's term configs. + + These functions (last_action, generated_commands) don't access ArticulationData + properties, so they need separate patching to record their mappings and annotate + their outputs. + + We patch the term_cfg.func directly because the observation manager stores + references to these functions at creation time. + """ + obs_manager = self.env.env.unwrapped.observation_manager + + # Store original functions for restoration: (group_name, term_idx) -> original_func + self._original_obs_funcs: dict[tuple[str, int], callable] = {} + + for group_name, term_cfgs in obs_manager._group_obs_term_cfgs.items(): + for term_idx, term_cfg in enumerate(term_cfgs): + original_func = term_cfg.func + func_name = getattr(original_func, "__name__", None) + + if func_name == "last_action": + self._original_obs_funcs[(group_name, term_idx)] = original_func + term_cfg.func = self._make_patched_last_action(original_func) + + elif func_name == "generated_commands": + self._original_obs_funcs[(group_name, term_idx)] = original_func + term_cfg.func = self._make_patched_generated_commands(original_func, term_cfg) + + def _make_patched_last_action(self, original_func): + """Create a patched version of last_action that records mappings.""" + + def patched_last_action(env, action_name=None, **kwargs): + result = original_func(env, action_name, **kwargs) + self._record_articulation_access("last_action", "last_actions") + result = annotate.input_tensors({"last_actions": result}, node_name="observation_manager") + return result + + patched_last_action.__name__ = original_func.__name__ + return patched_last_action + + def _make_patched_generated_commands(self, original_func, term_cfg): + """Create a patched version of generated_commands that records mappings.""" + # Get the command_name from term_cfg.params if available + command_name_from_cfg = term_cfg.params.get("command_name") + + def patched_generated_commands(env, command_name=None, **kwargs): + result = original_func(env, command_name, **kwargs) + # Use command_name parameter, or fall back to config, or default + leapp_input_name = command_name or command_name_from_cfg or "commands" + self._record_articulation_access("generated_commands", leapp_input_name) + result = annotate.input_tensors({leapp_input_name: result}, node_name="observation_manager") + return result + + patched_generated_commands.__name__ = original_func.__name__ + return patched_generated_commands + + def _restore_observation_functions(self): + """Restore original observation functions in term configs.""" + if not hasattr(self, "_original_obs_funcs"): + return + + obs_manager = self.env.env.unwrapped.observation_manager + + for (group_name, term_idx), original_func in self._original_obs_funcs.items(): + obs_manager._group_obs_term_cfgs[group_name][term_idx].func = original_func + + def _patch_observation_manager(self): + """Patch the observation manager instance's compute_group method.""" + obs_manager = self.env.env.unwrapped.observation_manager + self._original_compute_group = obs_manager.compute_group + + def patched_compute_group(*args, **kwargs): + self._apply_articulation_annotations() + try: + output = self._original_compute_group(*args, **kwargs) + annotate.output_tensors("observation_manager", output, export_with="torch", use_trace=True) + return output.tensor if isinstance(output, TracedTensor) else output + finally: + self._remove_articulation_annotations() + + obs_manager.compute_group = patched_compute_group + + def _restore_observation_manager(self): + """Restore original compute_group method.""" + if self._original_compute_group: + self.env.env.unwrapped.observation_manager.compute_group = self._original_compute_group + + # ────────────────────────────────────────────────────────────────── + # Action Manager + # ────────────────────────────────────────────────────────────────── + + def _patch_action_manager(self): + """Patch the action manager instance's process_action method.""" + action_manager = self.env.env.unwrapped.action_manager + self._original_process_action = action_manager.process_action + + def patched_process_action(action: torch.Tensor): + action = annotate.input_tensors({"actions": action}, node_name="action_manager") + + # Register raw_actions buffers for tracing + for term_name, term in action_manager._terms.items(): + if hasattr(term, "_raw_actions") and term._raw_actions is not None: + buffers = annotate.register_buffer("action_manager", {"raw_actions": term._raw_actions}) + term._raw_actions = buffers["raw_actions"] + + self._original_process_action(action) + annotate.mirror_leapp_tags(action, action_manager._action) + + tensors, static_values = self._collect_action_outputs(action_manager) + annotate.output_tensors("action_manager", tensors, static_outputs=static_values, export_with="torch") + + action_manager.process_action = patched_process_action + + def _collect_action_outputs(self, action_manager) -> tuple[dict, dict]: + """Collect action tensors and static values from all terms.""" + tensors = {} + static_values = {} + + for term_name, term in action_manager._terms.items(): + tensors[term_name] = term.processed_actions + + # Handle variable impedance (dynamic gains) + osc = getattr(term, "_osc", None) + if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in ["variable", "variable_kp"]: + tensors[f"{term_name}_kp_gains"] = torch.diagonal(osc._motion_p_gains_task, dim1=-2, dim2=-1) + tensors[f"{term_name}_kd_gains"] = torch.diagonal(osc._motion_d_gains_task, dim1=-2, dim2=-1) + continue + + # Collect static gains + asset = getattr(term, "_asset", None) + if asset and hasattr(asset, "data"): + self._collect_static_gains(term_name, asset.data, getattr(term, "_joint_ids", None), static_values) + + return tensors, static_values + + def _collect_static_gains(self, term_name: str, data, joint_ids, static_values: dict): + """Extract static kp/kd gains from asset data.""" + if hasattr(data, "default_joint_stiffness") and data.default_joint_stiffness is not None: + gains = data.default_joint_stiffness + static_values[f"{term_name}_kp_gains"] = gains[:, joint_ids] if joint_ids else gains + + if hasattr(data, "default_joint_damping") and data.default_joint_damping is not None: + gains = data.default_joint_damping + static_values[f"{term_name}_kd_gains"] = gains[:, joint_ids] if joint_ids else gains + + def _restore_action_manager(self): + """Restore original process_action method.""" + if self._original_process_action: + self.env.env.unwrapped.action_manager.process_action = self._original_process_action + + # ────────────────────────────────────────────────────────────────── + # ArticulationData Property Annotations + # ────────────────────────────────────────────────────────────────── + + def _prepare_articulation_annotations(self): + """Prepare annotating versions of ArticulationData properties.""" + for prop_name in self.OBSERVATION_PROPERTIES: + original_prop = getattr(ArticulationData, prop_name, None) + if not isinstance(original_prop, property) or original_prop.fget is None: + continue + + self._articulation_originals[prop_name] = original_prop + self._articulation_annotating[prop_name] = self._make_annotating_property(original_prop, prop_name) + + def _make_annotating_property(self, original: property, prop_name: str) -> property: + """Create an annotating version of an ArticulationData property.""" + original_fget = original.fget + assert original_fget is not None # Checked in _prepare_articulation_annotations + + def annotating_fget(data_self): + result = original_fget(data_self) + obs_name = self._find_calling_observation() + if obs_name: + self._record_articulation_access(obs_name, prop_name) + + if isinstance(result, torch.Tensor): + # Check if this property was already annotated in this compute_group call + if prop_name in self._annotated_tensor_cache: + # Return a clone of the cached tensor to avoid duplicate input annotations + return self._annotated_tensor_cache[prop_name].clone() + + # First access - annotate and cache + result = annotate.input_tensors({prop_name: result}, node_name="observation_manager") + self._annotated_tensor_cache[prop_name] = result + + return result + + return property(fget=annotating_fget, fset=original.fset, fdel=original.fdel, doc=original.__doc__) + + def _apply_articulation_annotations(self): + """Temporarily apply annotating properties.""" + if not self._annotations_active: + # Clear the tensor cache at the start of each compute_group call + self._annotated_tensor_cache.clear() + for prop_name, prop in self._articulation_annotating.items(): + setattr(ArticulationData, prop_name, prop) + self._annotations_active = True + + def _remove_articulation_annotations(self): + """Restore original properties.""" + if self._annotations_active: + for prop_name, prop in self._articulation_originals.items(): + setattr(ArticulationData, prop_name, prop) + self._annotations_active = False + # Clear the tensor cache when done + self._annotated_tensor_cache.clear() + + # ────────────────────────────────────────────────────────────────── + # Helpers + # ────────────────────────────────────────────────────────────────── + + def _record_articulation_access(self, obs_name: str, prop_name: str): + """Record that an observation accessed an articulation property.""" + if obs_name not in self.observation_to_articulation_map: + self.observation_to_articulation_map[obs_name] = set() + self.observation_to_articulation_map[obs_name].add(prop_name) + + def _find_calling_observation(self) -> str | None: + """Walk the stack to find the observation function that triggered access. + + Returns the IO descriptor name if available, otherwise the function name. + """ + for frame_info in inspect.stack(): + if "isaaclab/envs/mdp/observations" in frame_info.filename: + func_name = frame_info.function + if func_name.startswith("_"): + continue + + # Try to get the IO descriptor name from the function's descriptor + # The function object should be in the frame's global namespace + frame_globals = frame_info.frame.f_globals + if func_name in frame_globals: + func = frame_globals[func_name] + if hasattr(func, "_descriptor") and hasattr(func._descriptor, "name"): + return func._descriptor.name + + # Fallback to function name (which is what descriptor.name is set to anyway) + return func_name + return None + + # ────────────────────────────────────────────────────────────────── + # Public API for accessing mappings + # ────────────────────────────────────────────────────────────────── + + @property + def get_semantic(self) -> dict[str, Any]: + observations = [] + for k in self.io_descriptor_observations: + obs_name = k["name"] + observation = { + "name": obs_name, + } + # Add the leapp input names this observation maps to (copy list to avoid YAML anchors) + if obs_name in self.observation_to_articulation_map: + observation["leapp_mapping"] = list(self.observation_to_articulation_map[obs_name]) + if "joint_names" in k: + observation["joint_names"] = k["joint_names"] + if "units" in k["extras"]: + observation["units"] = k["extras"]["units"] + observations.append(observation) + + actions = [] + for k in self.io_descriptor_actions: + action_name = k["name"] + action = { + "name": action_name, + } + if action_name in self.action_io_to_term_map: + action["leapp_mapping"] = list(self.action_io_to_term_map[action_name]) + if "joint_names" in k: + action["joint_names"] = k["joint_names"] + if "units" in k["extras"]: + action["units"] = k["extras"]["units"] + actions.append(action) + + scene = self.io_descriptor_scene + + return { + "observations": observations, + "actions": actions, + "scene": scene, + } From fa8ffd9717a62deb92f364b2a8cc74f40a7792f0 Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Fri, 16 Jan 2026 13:50:55 -0800 Subject: [PATCH 10/23] added documentation --- .../rsl_rl/LEAPP_annotations_for_isaac_lab.md | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 scripts/reinforcement_learning/rsl_rl/LEAPP_annotations_for_isaac_lab.md diff --git a/scripts/reinforcement_learning/rsl_rl/LEAPP_annotations_for_isaac_lab.md b/scripts/reinforcement_learning/rsl_rl/LEAPP_annotations_for_isaac_lab.md new file mode 100644 index 000000000000..3b9108e30316 --- /dev/null +++ b/scripts/reinforcement_learning/rsl_rl/LEAPP_annotations_for_isaac_lab.md @@ -0,0 +1,41 @@ +# LEAPP Export for Isaac Lab + +Export RSL-RL reinforcement learning pipelines as portable processing graphs using [LEAPP](https://gitlab-master.nvidia.com/Isaac/leapp). + +## Exported Artifacts + +| File | Description | +|------|-------------| +| `observation_manager.pt` | Observation processing (TorchScript) | +| `policy.onnx` | Policy network (ONNX) | +| `action_manager.pt` | Action processing (TorchScript) | +| `.yaml` | Pipeline configuration and metadata | +| `.png` | Visualization of the processing graph | + +The YAML file includes semantic metadata (joint names, units, etc.) extracted from IO descriptors. For details on the YAML format, see the [LEAPP documentation](https://gitlab-master.nvidia.com/Isaac/leapp/-/blob/main/docs/0_getting_started.md). + +## Usage + +### 1. Install LEAPP + +```bash +git clone ssh://git@gitlab-master.nvidia.com:12051/Isaac/leapp.git +cd leapp +git checkout develop +pip install -e . +``` + +### 2. Export a Policy + +```bash +./isaaclab.sh -p scripts/reinforcement_learning/rsl_rl/export.py \ + --task Isaac-Reach-Franka-v0 \ + --use_pretrained_checkpoint \ + --headless +``` + +> **Note:** Export runs with a single environment instance. + +### 3. View Results + +Artifacts are saved to `.//`. From 619b4b3f64012d3cab0912f61631a4f4ccbe94cd Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Fri, 16 Jan 2026 13:53:56 -0800 Subject: [PATCH 11/23] added sample generated yaml file --- .../rsl_rl/LEAPP_annotations_for_isaac_lab.md | 150 ++++++++++++++++++ 1 file changed, 150 insertions(+) diff --git a/scripts/reinforcement_learning/rsl_rl/LEAPP_annotations_for_isaac_lab.md b/scripts/reinforcement_learning/rsl_rl/LEAPP_annotations_for_isaac_lab.md index 3b9108e30316..4d5a9c3d4499 100644 --- a/scripts/reinforcement_learning/rsl_rl/LEAPP_annotations_for_isaac_lab.md +++ b/scripts/reinforcement_learning/rsl_rl/LEAPP_annotations_for_isaac_lab.md @@ -39,3 +39,153 @@ pip install -e . ### 3. View Results Artifacts are saved to `.//`. + + + +sample exported `Isaac-Reach-Franka-v0.yaml`: + +```yaml +models: + observation_manager: + inputs: + - name: joint_pos + dtype: float32 + shape: [1, 9] + type: tensor + - name: joint_vel + dtype: float32 + shape: [1, 9] + type: tensor + - name: ee_pose + dtype: float32 + shape: [1, 7] + type: tensor + - name: last_actions + dtype: float32 + shape: [1, 7] + type: tensor + outputs: + - name: obs_policy + dtype: float32 + shape: [1, 32] + type: tensor + parameters: + model_path: observation_manager.pt + md5sum: 3e44b3d2942d5fc3c6a88f28ef3d7b5a + sha256sum: 8d5761e8830be584ec09863775e9bef135e2ad081bcad3029ac1ab50a7fcf819 + device: cuda + backend: torch + policy: + inputs: + - name: obs_policy + dtype: float32 + shape: [1, 32] + type: tensor + outputs: + - name: actions + dtype: float32 + shape: [1, 7] + type: tensor + parameters: + model_path: policy.onnx + md5sum: 848384ae8e4d22052d6e87719d8cb42c + sha256sum: e69cf132746a570e504eb23071ad9cddd146eafaa787adf5bc0951ed948a4bcc + device: cuda + backend: onnx + action_manager: + inputs: + - name: actions + dtype: float32 + shape: [1, 7] + type: tensor + outputs: + - name: arm_action + dtype: float32 + shape: [1, 7] + type: tensor + - name: arm_action_kp_gains + dtype: float32 + shape: [1, 7] + type: tensor + - name: arm_action_kd_gains + dtype: float32 + shape: [1, 7] + type: tensor + parameters: + model_path: action_manager.pt + md5sum: cbbed1862042f23bd285da4c0ddaa946 + sha256sum: 20d31b74ab7dc686f29b4b973ef43b9950076dc5a3da40f417839d581c5b328e + device: cuda + backend: torch + +pipeline: + data_flow: + observation_manager/obs_policy: [policy/obs_policy] + policy/actions: [action_manager/actions] + feedback_flow: + policy/actions: [observation_manager/last_actions] + inputs: + observation_manager: [joint_pos, joint_vel, ee_pose] + outputs: + action_manager: [arm_action, arm_action_kp_gains, arm_action_kd_gains] + +system information: + cuda version: '12.8' + leapp version: 0.3.0 + os: Linux + python version: 3.11.14 + torch version: 2.7.0+cu128 + +semantic: + actions: + - joint_names: + - panda_joint1 + - panda_joint2 + - panda_joint3 + - panda_joint4 + - panda_joint5 + - panda_joint6 + - panda_joint7 + leapp_mapping: + - arm_action + name: joint_position_action + observations: + - joint_names: + - panda_joint1 + - panda_joint2 + - panda_joint3 + - panda_joint4 + - panda_joint5 + - panda_joint6 + - panda_joint7 + - panda_finger_joint1 + - panda_finger_joint2 + leapp_mapping: + - joint_pos + name: joint_pos_rel + units: rad + - joint_names: + - panda_joint1 + - panda_joint2 + - panda_joint3 + - panda_joint4 + - panda_joint5 + - panda_joint6 + - panda_joint7 + - panda_finger_joint1 + - panda_finger_joint2 + leapp_mapping: + - joint_vel + name: joint_vel_rel + units: rad/s + - leapp_mapping: + - ee_pose + name: generated_commands + - leapp_mapping: + - last_actions + name: last_action + scene: + decimation: 2 + dt: 0.03333333333333333 + physics_dt: 0.016666666666666666 +``` From d05a2d2807f38fb38efcb61659756d652969fd9d Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Fri, 16 Jan 2026 13:56:30 -0800 Subject: [PATCH 12/23] fixed bug in action_manager --- source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py b/source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py index 13e062119a75..c32e501b7591 100644 --- a/source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py +++ b/source/isaaclab/isaaclab/envs/mdp/actions/joint_actions.py @@ -169,7 +169,7 @@ def process_actions(self, actions: torch.Tensor): # store the raw actions self._raw_actions[:] = actions # apply the affine transformations - self._processed_actions = self._raw_actions[:] * self._scale + self._offset + self._processed_actions = self._raw_actions * self._scale + self._offset # clip actions if self.cfg.clip is not None: self._processed_actions = torch.clamp( From b10a247fc6357ca8e2e6f3c24338b58cb889c02b Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Wed, 28 Jan 2026 10:16:52 -0800 Subject: [PATCH 13/23] switched to export with onnx --- scripts/reinforcement_learning/rsl_rl/export_annotator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/export_annotator.py b/scripts/reinforcement_learning/rsl_rl/export_annotator.py index 6541a60a842b..9b7accbe98f0 100644 --- a/scripts/reinforcement_learning/rsl_rl/export_annotator.py +++ b/scripts/reinforcement_learning/rsl_rl/export_annotator.py @@ -206,7 +206,7 @@ def patched_compute_group(*args, **kwargs): self._apply_articulation_annotations() try: output = self._original_compute_group(*args, **kwargs) - annotate.output_tensors("observation_manager", output, export_with="torch", use_trace=True) + annotate.output_tensors("observation_manager", output, export_with="onnx", use_trace=True) return output.tensor if isinstance(output, TracedTensor) else output finally: self._remove_articulation_annotations() @@ -240,7 +240,7 @@ def patched_process_action(action: torch.Tensor): annotate.mirror_leapp_tags(action, action_manager._action) tensors, static_values = self._collect_action_outputs(action_manager) - annotate.output_tensors("action_manager", tensors, static_outputs=static_values, export_with="torch") + annotate.output_tensors("action_manager", tensors, static_outputs=static_values, export_with="onnx") action_manager.process_action = patched_process_action From f85131b5b4a6ad53c19e74e75391a66cfd86cacd Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Thu, 29 Jan 2026 13:16:35 -0800 Subject: [PATCH 14/23] exports as one single policy --- .../reinforcement_learning/rsl_rl/export.py | 11 ++------ .../rsl_rl/export_annotator.py | 26 +++++++------------ 2 files changed, 12 insertions(+), 25 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py index c3ac56cda278..482a4276b69d 100644 --- a/scripts/reinforcement_learning/rsl_rl/export.py +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -127,7 +127,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen # create isaac environment # Note: observation functions are already patched at module level (before isaaclab_tasks import) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) - annotator = ExportAnnotator(env) + annotator = ExportAnnotator(env, task_name=task_name) annotator.setup() # convert to single-agent instance if required by the RL algorithm @@ -186,14 +186,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen for _ in range(5): # run everything in inference mode with torch.inference_mode(): - # agent stepping - with annotate.block( - "policy", - inputs=["obs"], - outputs=["actions"], - backend_params={"model_path": onnx_path, "copy_original_model": True}, - ): - actions = policy(obs) + actions = policy(obs) # env stepping obs, _, _, _ = env.step(actions) diff --git a/scripts/reinforcement_learning/rsl_rl/export_annotator.py b/scripts/reinforcement_learning/rsl_rl/export_annotator.py index 9b7accbe98f0..497cc3c5d38d 100644 --- a/scripts/reinforcement_learning/rsl_rl/export_annotator.py +++ b/scripts/reinforcement_learning/rsl_rl/export_annotator.py @@ -3,11 +3,6 @@ # # SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2022-2026, The Isaac Lab Project Developers -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - - """Export annotations for Isaac Lab policies using instance-level patching.""" @@ -19,7 +14,6 @@ from typing import TYPE_CHECKING, Any from leapp import annotate -from leapp.leapp_graph.traced_tensor import TracedTensor from isaaclab.assets.articulation.articulation_data import ArticulationData @@ -42,6 +36,7 @@ class ExportAnnotator: """ env: ManagerBasedEnv + task_name: str io_descriptor_observations: list[Any] = field(default_factory=list) io_descriptor_actions: list[Any] = field(default_factory=list) @@ -165,7 +160,7 @@ def _make_patched_last_action(self, original_func): def patched_last_action(env, action_name=None, **kwargs): result = original_func(env, action_name, **kwargs) self._record_articulation_access("last_action", "last_actions") - result = annotate.input_tensors({"last_actions": result}, node_name="observation_manager") + result = annotate.input_tensors({"last_actions": result}, node_name=self.task_name) return result patched_last_action.__name__ = original_func.__name__ @@ -181,7 +176,7 @@ def patched_generated_commands(env, command_name=None, **kwargs): # Use command_name parameter, or fall back to config, or default leapp_input_name = command_name or command_name_from_cfg or "commands" self._record_articulation_access("generated_commands", leapp_input_name) - result = annotate.input_tensors({leapp_input_name: result}, node_name="observation_manager") + result = annotate.input_tensors({leapp_input_name: result}, node_name=self.task_name) return result patched_generated_commands.__name__ = original_func.__name__ @@ -205,9 +200,7 @@ def _patch_observation_manager(self): def patched_compute_group(*args, **kwargs): self._apply_articulation_annotations() try: - output = self._original_compute_group(*args, **kwargs) - annotate.output_tensors("observation_manager", output, export_with="onnx", use_trace=True) - return output.tensor if isinstance(output, TracedTensor) else output + return self._original_compute_group(*args, **kwargs) finally: self._remove_articulation_annotations() @@ -228,19 +221,20 @@ def _patch_action_manager(self): self._original_process_action = action_manager.process_action def patched_process_action(action: torch.Tensor): - action = annotate.input_tensors({"actions": action}, node_name="action_manager") # Register raw_actions buffers for tracing for term_name, term in action_manager._terms.items(): if hasattr(term, "_raw_actions") and term._raw_actions is not None: - buffers = annotate.register_buffer("action_manager", {"raw_actions": term._raw_actions}) + buffers = annotate.register_buffer(self.task_name, {"raw_actions": term._raw_actions}) term._raw_actions = buffers["raw_actions"] self._original_process_action(action) - annotate.mirror_leapp_tags(action, action_manager._action) + # this is stored differently inside the original process action method that would loose tracing. this step preserves it. + action_manager._action = action.clone() tensors, static_values = self._collect_action_outputs(action_manager) - annotate.output_tensors("action_manager", tensors, static_outputs=static_values, export_with="onnx") + tensors["last_action"] = action_manager._action + annotate.output_tensors(self.task_name, tensors, static_outputs=static_values, export_with="onnx") action_manager.process_action = patched_process_action @@ -313,7 +307,7 @@ def annotating_fget(data_self): return self._annotated_tensor_cache[prop_name].clone() # First access - annotate and cache - result = annotate.input_tensors({prop_name: result}, node_name="observation_manager") + result = annotate.input_tensors({prop_name: result}, node_name=self.task_name) self._annotated_tensor_cache[prop_name] = result return result From 847adbc7f8549fff5a4b1e9b86d0c3b3156b0285 Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Thu, 29 Jan 2026 16:42:31 -0800 Subject: [PATCH 15/23] fixed readme --- .../rsl_rl/LEAPP_annotations_for_isaac_lab.md | 59 ++++--------------- 1 file changed, 13 insertions(+), 46 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/LEAPP_annotations_for_isaac_lab.md b/scripts/reinforcement_learning/rsl_rl/LEAPP_annotations_for_isaac_lab.md index 4d5a9c3d4499..d8ae03f5a6dd 100644 --- a/scripts/reinforcement_learning/rsl_rl/LEAPP_annotations_for_isaac_lab.md +++ b/scripts/reinforcement_learning/rsl_rl/LEAPP_annotations_for_isaac_lab.md @@ -6,9 +6,7 @@ Export RSL-RL reinforcement learning pipelines as portable processing graphs usi | File | Description | |------|-------------| -| `observation_manager.pt` | Observation processing (TorchScript) | -| `policy.onnx` | Policy network (ONNX) | -| `action_manager.pt` | Action processing (TorchScript) | +| `.onnx` | Policy network (ONNX) | | `.yaml` | Pipeline configuration and metadata | | `.png` | Visualization of the processing graph | @@ -46,7 +44,7 @@ sample exported `Isaac-Reach-Franka-v0.yaml`: ```yaml models: - observation_manager: + Isaac-Reach-Franka-v0: inputs: - name: joint_pos dtype: float32 @@ -65,41 +63,11 @@ models: shape: [1, 7] type: tensor outputs: - - name: obs_policy - dtype: float32 - shape: [1, 32] - type: tensor - parameters: - model_path: observation_manager.pt - md5sum: 3e44b3d2942d5fc3c6a88f28ef3d7b5a - sha256sum: 8d5761e8830be584ec09863775e9bef135e2ad081bcad3029ac1ab50a7fcf819 - device: cuda - backend: torch - policy: - inputs: - - name: obs_policy - dtype: float32 - shape: [1, 32] - type: tensor - outputs: - - name: actions - dtype: float32 - shape: [1, 7] - type: tensor - parameters: - model_path: policy.onnx - md5sum: 848384ae8e4d22052d6e87719d8cb42c - sha256sum: e69cf132746a570e504eb23071ad9cddd146eafaa787adf5bc0951ed948a4bcc - device: cuda - backend: onnx - action_manager: - inputs: - - name: actions + - name: arm_action dtype: float32 shape: [1, 7] type: tensor - outputs: - - name: arm_action + - name: last_action dtype: float32 shape: [1, 7] type: tensor @@ -112,22 +80,20 @@ models: shape: [1, 7] type: tensor parameters: - model_path: action_manager.pt - md5sum: cbbed1862042f23bd285da4c0ddaa946 - sha256sum: 20d31b74ab7dc686f29b4b973ef43b9950076dc5a3da40f417839d581c5b328e + model_path: Isaac-Reach-Franka-v0.onnx + md5sum: 38ee55fa7828b5068b86024206bd5ddb + sha256sum: c605a7076fde5c0d03a36f548d458d24bd543df67aac7675d463d29f870a7eb3 device: cuda - backend: torch + backend: onnx pipeline: - data_flow: - observation_manager/obs_policy: [policy/obs_policy] - policy/actions: [action_manager/actions] + data_flow: {} feedback_flow: - policy/actions: [observation_manager/last_actions] + Isaac-Reach-Franka-v0/last_action: [Isaac-Reach-Franka-v0/last_actions] inputs: - observation_manager: [joint_pos, joint_vel, ee_pose] + Isaac-Reach-Franka-v0: [joint_pos, joint_vel, ee_pose] outputs: - action_manager: [arm_action, arm_action_kp_gains, arm_action_kd_gains] + Isaac-Reach-Franka-v0: [arm_action, arm_action_kp_gains, arm_action_kd_gains] system information: cuda version: '12.8' @@ -188,4 +154,5 @@ semantic: decimation: 2 dt: 0.03333333333333333 physics_dt: 0.016666666666666666 + ``` From b84b3124c09eb9fb8d16fa34ac88fe78f881de8b Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Mon, 9 Mar 2026 14:35:36 -0700 Subject: [PATCH 16/23] changes for export to work after v0.5.0 update --- scripts/reinforcement_learning/rsl_rl/export.py | 11 ++++++----- .../reinforcement_learning/rsl_rl/export_annotator.py | 10 +++++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py index 482a4276b69d..fe623af3dd01 100644 --- a/scripts/reinforcement_learning/rsl_rl/export.py +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -11,7 +11,7 @@ import sys import yaml -from leapp import annotate +import leapp from isaaclab.app import AppLauncher @@ -177,7 +177,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen # start annotation tracing # Note: all patching is done at module/class level before isaaclab_tasks import - annotate.start(task_name) + leapp.start(task_name, save_path=export_model_dir) obs = env.get_observations() # simulate environment while not simulation_app.is_running(): @@ -190,12 +190,13 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen # env stepping obs, _, _, _ = env.step(actions) - annotate.stop() - annotate.compile_graph() + leapp.stop() + leapp.compile_graph() semantic = annotator.get_semantic + config_path = os.path.join(export_model_dir, task_name, f"{task_name}.yaml") - with open(annotate.config_path, "a") as f: + with open(config_path, "a") as f: yaml.dump({"semantic": semantic}, f) # close the simulator diff --git a/scripts/reinforcement_learning/rsl_rl/export_annotator.py b/scripts/reinforcement_learning/rsl_rl/export_annotator.py index 497cc3c5d38d..5acadcaf641b 100644 --- a/scripts/reinforcement_learning/rsl_rl/export_annotator.py +++ b/scripts/reinforcement_learning/rsl_rl/export_annotator.py @@ -160,7 +160,7 @@ def _make_patched_last_action(self, original_func): def patched_last_action(env, action_name=None, **kwargs): result = original_func(env, action_name, **kwargs) self._record_articulation_access("last_action", "last_actions") - result = annotate.input_tensors({"last_actions": result}, node_name=self.task_name) + result = annotate.input_tensors(self.task_name, {"last_actions": result}) return result patched_last_action.__name__ = original_func.__name__ @@ -176,7 +176,7 @@ def patched_generated_commands(env, command_name=None, **kwargs): # Use command_name parameter, or fall back to config, or default leapp_input_name = command_name or command_name_from_cfg or "commands" self._record_articulation_access("generated_commands", leapp_input_name) - result = annotate.input_tensors({leapp_input_name: result}, node_name=self.task_name) + result = annotate.input_tensors(self.task_name, {leapp_input_name: result}) return result patched_generated_commands.__name__ = original_func.__name__ @@ -225,8 +225,7 @@ def patched_process_action(action: torch.Tensor): # Register raw_actions buffers for tracing for term_name, term in action_manager._terms.items(): if hasattr(term, "_raw_actions") and term._raw_actions is not None: - buffers = annotate.register_buffer(self.task_name, {"raw_actions": term._raw_actions}) - term._raw_actions = buffers["raw_actions"] + term._raw_actions = annotate.register_buffer(self.task_name, {"raw_actions": term._raw_actions}) self._original_process_action(action) # this is stored differently inside the original process action method that would loose tracing. this step preserves it. @@ -255,6 +254,7 @@ def _collect_action_outputs(self, action_manager) -> tuple[dict, dict]: # Collect static gains asset = getattr(term, "_asset", None) + if asset and hasattr(asset, "data"): self._collect_static_gains(term_name, asset.data, getattr(term, "_joint_ids", None), static_values) @@ -307,7 +307,7 @@ def annotating_fget(data_self): return self._annotated_tensor_cache[prop_name].clone() # First access - annotate and cache - result = annotate.input_tensors({prop_name: result}, node_name=self.task_name) + result = annotate.input_tensors(self.task_name, {prop_name: result}) self._annotated_tensor_cache[prop_name] = result return result From e2d203841576bf2c6cca2152f05c70d350991072 Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Tue, 10 Mar 2026 16:54:17 -0700 Subject: [PATCH 17/23] refactor and added semantic data --- .../reinforcement_learning/rsl_rl/export.py | 7 - .../rsl_rl/export_annotator.py | 477 +++++++++++------- .../assets/articulation/articulation.py | 5 + .../assets/articulation/articulation_data.py | 23 +- .../envs/mdp/commands/pose_2d_command.py | 3 + .../envs/mdp/commands/pose_command.py | 3 + .../envs/mdp/commands/velocity_command.py | 3 + .../isaaclab/managers/manager_term_cfg.py | 3 + .../isaaclab/utils/leapp_semantics.py | 100 ++++ .../dexsuite/mdp/commands/pose_commands.py | 3 + .../mdp/commands/orientation_command.py | 3 + 11 files changed, 426 insertions(+), 204 deletions(-) create mode 100644 source/isaaclab/isaaclab/utils/leapp_semantics.py diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py index fe623af3dd01..c71aa5fbbad3 100644 --- a/scripts/reinforcement_learning/rsl_rl/export.py +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -9,7 +9,6 @@ import argparse import sys -import yaml import leapp @@ -193,12 +192,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen leapp.stop() leapp.compile_graph() - semantic = annotator.get_semantic - config_path = os.path.join(export_model_dir, task_name, f"{task_name}.yaml") - - with open(config_path, "a") as f: - yaml.dump({"semantic": semantic}, f) - # close the simulator env.close() diff --git a/scripts/reinforcement_learning/rsl_rl/export_annotator.py b/scripts/reinforcement_learning/rsl_rl/export_annotator.py index 5acadcaf641b..8de5387207b1 100644 --- a/scripts/reinforcement_learning/rsl_rl/export_annotator.py +++ b/scripts/reinforcement_learning/rsl_rl/export_annotator.py @@ -11,16 +11,24 @@ import inspect import torch from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from leapp import annotate +from leapp.utils.tensor_description import TensorSemantics +from isaaclab.assets.articulation.articulation import Articulation from isaaclab.assets.articulation.articulation_data import ArticulationData +from isaaclab.utils.leapp_semantics import resolve_leapp_element_names if TYPE_CHECKING: from isaaclab.envs import ManagerBasedEnv +# class ObservationPatcher: + +# class ActionPatcher: + + @dataclass class ExportAnnotator: """Encapsulates all leapp annotation logic for exporting Isaac Lab policies. @@ -30,88 +38,63 @@ class ExportAnnotator: annotator = ExportAnnotator(env) annotator.setup() # ... run policy ... - obs_map = annotator.observation_to_articulation_map - action_map = annotator.action_io_to_term_map annotator.cleanup() """ env: ManagerBasedEnv task_name: str - io_descriptor_observations: list[Any] = field(default_factory=list) - io_descriptor_actions: list[Any] = field(default_factory=list) - io_descriptor_scene: dict[str, Any] = field(default_factory=dict) - - # Mappings built during execution - observation_to_articulation_map: dict[str, set[str]] = field(default_factory=dict) - action_io_to_term_map: dict[str, list[str]] = field(default_factory=dict) - # Original methods for restoration _original_compute_group: callable = field(default=None, repr=False) _original_process_action: callable = field(default=None, repr=False) + _original_apply_action: callable = field(default=None, repr=False) # ArticulationData patching state _articulation_originals: dict[str, property] = field(default_factory=dict, repr=False) _articulation_annotating: dict[str, property] = field(default_factory=dict, repr=False) _annotations_active: bool = field(default=False, repr=False) + + # Action writer patching state + _action_write_originals: dict[str, callable] = field(default_factory=dict, repr=False) + _action_write_annotating: dict[str, callable] = field(default_factory=dict, repr=False) + _action_write_annotations_active: bool = field(default=False, repr=False) + # Cache for annotated tensors within a single compute_group call # Prevents duplicate input tensors when same property is accessed multiple times _annotated_tensor_cache: dict[str, torch.Tensor] = field(default_factory=dict, repr=False) - # Articulation properties to annotate - OBSERVATION_PROPERTIES: frozenset[str] = frozenset({ - "root_pos_w", - "root_quat_w", - "root_lin_vel_b", - "root_ang_vel_b", - "root_lin_vel_w", - "root_ang_vel_w", - "projected_gravity_b", - "body_pose_w", - "body_quat_w", - "joint_pos", - "joint_vel", - }) + _action_output_cache: list[TensorSemantics] = field(default_factory=list, repr=False) + _active_action_term_name: str | None = field(default=None, repr=False) + _pending_action_output_export: bool = field(default=False, repr=False) def setup(self): """Set up all annotations. Call after env is created.""" - self._collect_io_descriptors() - self._disable_observation_noise() - self._prepare_articulation_annotations() - self._patch_observation_functions() - self._patch_observation_manager() + self._setup_observation_annotations() + self._prepare_action_write_annotations() self._patch_action_manager() def cleanup(self): """Restore all original methods and properties.""" - self._restore_observation_functions() - self._restore_observation_manager() + self._restore_observation_annotations() self._restore_action_manager() - self._remove_articulation_annotations() + self._remove_action_write_annotations() # ────────────────────────────────────────────────────────────────── - # IO Descriptor Collection (before patching) + # Observation Annotations # ────────────────────────────────────────────────────────────────── - def _collect_io_descriptors(self): - outs = self.env.unwrapped.get_IO_descriptors - self.io_descriptor_observations = outs["observations"]["policy"] - self.io_descriptor_actions = outs["actions"] - self.io_descriptor_scene = outs["scene"] - - # Build action IO descriptor name -> term name mapping - # e.g., 'joint_position_action' -> ['arm_action'] - action_manager = self.env.env.unwrapped.action_manager - for term_name, term in action_manager._terms.items(): - try: - io_name = term.IO_descriptor.name - self.action_io_to_term_map[io_name] = [term_name] - except Exception: - pass # Skip if IO descriptor not available + def _setup_observation_annotations(self): + """Set up all observation-side annotations.""" + self._disable_observation_noise() + self._prepare_articulation_annotations() + self._patch_observation_functions() + self._patch_observation_manager() - # ────────────────────────────────────────────────────────────────── - # Observation Manager - # ────────────────────────────────────────────────────────────────── + def _restore_observation_annotations(self): + """Restore all observation-side patches and temporary annotations.""" + self._restore_observation_functions() + self._restore_observation_manager() + self._remove_articulation_annotations() def _disable_observation_noise(self): """Disable noise/corruption for deterministic export. @@ -141,6 +124,7 @@ def _patch_observation_functions(self): # Store original functions for restoration: (group_name, term_idx) -> original_func self._original_obs_funcs: dict[tuple[str, int], callable] = {} + # find and patch all other known non-articulation data properties for group_name, term_cfgs in obs_manager._group_obs_term_cfgs.items(): for term_idx, term_cfg in enumerate(term_cfgs): original_func = term_cfg.func @@ -155,11 +139,10 @@ def _patch_observation_functions(self): term_cfg.func = self._make_patched_generated_commands(original_func, term_cfg) def _make_patched_last_action(self, original_func): - """Create a patched version of last_action that records mappings.""" + """Create a patched version of last_action for LEAPP tracing.""" def patched_last_action(env, action_name=None, **kwargs): result = original_func(env, action_name, **kwargs) - self._record_articulation_access("last_action", "last_actions") result = annotate.input_tensors(self.task_name, {"last_actions": result}) return result @@ -167,7 +150,7 @@ def patched_last_action(env, action_name=None, **kwargs): return patched_last_action def _make_patched_generated_commands(self, original_func, term_cfg): - """Create a patched version of generated_commands that records mappings.""" + """Create a patched version of generated_commands for LEAPP tracing.""" # Get the command_name from term_cfg.params if available command_name_from_cfg = term_cfg.params.get("command_name") @@ -175,8 +158,20 @@ def patched_generated_commands(env, command_name=None, **kwargs): result = original_func(env, command_name, **kwargs) # Use command_name parameter, or fall back to config, or default leapp_input_name = command_name or command_name_from_cfg or "commands" - self._record_articulation_access("generated_commands", leapp_input_name) - result = annotate.input_tensors(self.task_name, {leapp_input_name: result}) + command_cfg = None + try: + command_cfg = env.command_manager.get_term(leapp_input_name).cfg + except (AttributeError, KeyError): + # Keep export working even if the observation term doesn't point to a registered command term. + command_cfg = None + + semantics = TensorSemantics( + name=leapp_input_name, + ref=result, + kind=getattr(command_cfg, "cmd_hint", None), + element_names=getattr(command_cfg, "element_names", None), + ) + result = annotate.input_tensors(self.task_name, semantics) return result patched_generated_commands.__name__ = original_func.__name__ @@ -211,77 +206,13 @@ def _restore_observation_manager(self): if self._original_compute_group: self.env.env.unwrapped.observation_manager.compute_group = self._original_compute_group - # ────────────────────────────────────────────────────────────────── - # Action Manager - # ────────────────────────────────────────────────────────────────── - - def _patch_action_manager(self): - """Patch the action manager instance's process_action method.""" - action_manager = self.env.env.unwrapped.action_manager - self._original_process_action = action_manager.process_action - - def patched_process_action(action: torch.Tensor): - - # Register raw_actions buffers for tracing - for term_name, term in action_manager._terms.items(): - if hasattr(term, "_raw_actions") and term._raw_actions is not None: - term._raw_actions = annotate.register_buffer(self.task_name, {"raw_actions": term._raw_actions}) - - self._original_process_action(action) - # this is stored differently inside the original process action method that would loose tracing. this step preserves it. - action_manager._action = action.clone() - - tensors, static_values = self._collect_action_outputs(action_manager) - tensors["last_action"] = action_manager._action - annotate.output_tensors(self.task_name, tensors, static_outputs=static_values, export_with="onnx") - - action_manager.process_action = patched_process_action - - def _collect_action_outputs(self, action_manager) -> tuple[dict, dict]: - """Collect action tensors and static values from all terms.""" - tensors = {} - static_values = {} - - for term_name, term in action_manager._terms.items(): - tensors[term_name] = term.processed_actions - - # Handle variable impedance (dynamic gains) - osc = getattr(term, "_osc", None) - if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in ["variable", "variable_kp"]: - tensors[f"{term_name}_kp_gains"] = torch.diagonal(osc._motion_p_gains_task, dim1=-2, dim2=-1) - tensors[f"{term_name}_kd_gains"] = torch.diagonal(osc._motion_d_gains_task, dim1=-2, dim2=-1) - continue - - # Collect static gains - asset = getattr(term, "_asset", None) - - if asset and hasattr(asset, "data"): - self._collect_static_gains(term_name, asset.data, getattr(term, "_joint_ids", None), static_values) - - return tensors, static_values - - def _collect_static_gains(self, term_name: str, data, joint_ids, static_values: dict): - """Extract static kp/kd gains from asset data.""" - if hasattr(data, "default_joint_stiffness") and data.default_joint_stiffness is not None: - gains = data.default_joint_stiffness - static_values[f"{term_name}_kp_gains"] = gains[:, joint_ids] if joint_ids else gains - - if hasattr(data, "default_joint_damping") and data.default_joint_damping is not None: - gains = data.default_joint_damping - static_values[f"{term_name}_kd_gains"] = gains[:, joint_ids] if joint_ids else gains - - def _restore_action_manager(self): - """Restore original process_action method.""" - if self._original_process_action: - self.env.env.unwrapped.action_manager.process_action = self._original_process_action - # ────────────────────────────────────────────────────────────────── # ArticulationData Property Annotations # ────────────────────────────────────────────────────────────────── def _prepare_articulation_annotations(self): """Prepare annotating versions of ArticulationData properties.""" - for prop_name in self.OBSERVATION_PROPERTIES: + for prop_name in self._get_semantic_articulation_properties(): original_prop = getattr(ArticulationData, prop_name, None) if not isinstance(original_prop, property) or original_prop.fget is None: continue @@ -296,9 +227,6 @@ def _make_annotating_property(self, original: property, prop_name: str) -> prope def annotating_fget(data_self): result = original_fget(data_self) - obs_name = self._find_calling_observation() - if obs_name: - self._record_articulation_access(obs_name, prop_name) if isinstance(result, torch.Tensor): # Check if this property was already annotated in this compute_group call @@ -306,8 +234,10 @@ def annotating_fget(data_self): # Return a clone of the cached tensor to avoid duplicate input annotations return self._annotated_tensor_cache[prop_name].clone() + sem = self._make_property_semantics(prop_name, data_self, result) + # First access - annotate and cache - result = annotate.input_tensors(self.task_name, {prop_name: result}) + result = annotate.input_tensors(self.task_name, sem) self._annotated_tensor_cache[prop_name] = result return result @@ -333,77 +263,240 @@ def _remove_articulation_annotations(self): self._annotated_tensor_cache.clear() # ────────────────────────────────────────────────────────────────── - # Helpers + # Action Manager # ────────────────────────────────────────────────────────────────── - def _record_articulation_access(self, obs_name: str, prop_name: str): - """Record that an observation accessed an articulation property.""" - if obs_name not in self.observation_to_articulation_map: - self.observation_to_articulation_map[obs_name] = set() - self.observation_to_articulation_map[obs_name].add(prop_name) + def _patch_action_manager(self): + """Patch the action manager instance's action processing methods.""" + action_manager = self.env.env.unwrapped.action_manager + self._original_process_action = action_manager.process_action + self._original_apply_action = action_manager.apply_action - def _find_calling_observation(self) -> str | None: - """Walk the stack to find the observation function that triggered access. + def patched_process_action(action: torch.Tensor): + # Register raw_actions buffers for tracing + for term_name, term in action_manager._terms.items(): + if hasattr(term, "_raw_actions") and term._raw_actions is not None: + term._raw_actions = annotate.register_buffer(self.task_name, {"raw_actions": term._raw_actions}) - Returns the IO descriptor name if available, otherwise the function name. - """ - for frame_info in inspect.stack(): - if "isaaclab/envs/mdp/observations" in frame_info.filename: - func_name = frame_info.function - if func_name.startswith("_"): - continue - - # Try to get the IO descriptor name from the function's descriptor - # The function object should be in the frame's global namespace - frame_globals = frame_info.frame.f_globals - if func_name in frame_globals: - func = frame_globals[func_name] - if hasattr(func, "_descriptor") and hasattr(func._descriptor, "name"): - return func._descriptor.name - - # Fallback to function name (which is what descriptor.name is set to anyway) - return func_name - return None + self._original_process_action(action) + # this is stored differently inside the original process action method that would loose tracing. this step preserves it. + action_manager._action = action.clone() + self._pending_action_output_export = True + + def patched_apply_action(): + if not self._pending_action_output_export: + return self._original_apply_action() + + original_term_apply_actions: dict[str, callable] = {} + self._action_output_cache.clear() + self._apply_action_write_annotations() + + try: + for term_name, term in action_manager._terms.items(): + original_term_apply_actions[term_name] = term.apply_actions + term.apply_actions = self._make_patched_term_apply_actions(term.apply_actions, term_name) + + self._original_apply_action() + + self._action_output_cache.extend(self._collect_action_outputs(action_manager)) + self._action_output_cache.append(TensorSemantics(name="last_action", ref=action_manager._action)) + static_values = self._collect_action_static_outputs(action_manager) + annotate.output_tensors( + self.task_name, + self._action_output_cache, + static_outputs=static_values, + export_with="onnx", + ) + self._pending_action_output_export = False + finally: + for term_name, original_apply_actions in original_term_apply_actions.items(): + action_manager._terms[term_name].apply_actions = original_apply_actions + self._active_action_term_name = None + self._remove_action_write_annotations() + self._action_output_cache.clear() + + action_manager.process_action = patched_process_action + action_manager.apply_action = patched_apply_action + + def _make_patched_term_apply_actions(self, original_func, term_name: str): + """Wrap an action term's apply call to keep the current term context.""" + + def patched_apply_actions(): + self._active_action_term_name = term_name + try: + return original_func() + finally: + self._active_action_term_name = None + + return patched_apply_actions + + def _collect_action_outputs(self, action_manager) -> list[TensorSemantics]: + """Collect non-writer action tensors that should still be exported.""" + tensors: list[TensorSemantics] = [] + + for term_name, term in action_manager._terms.items(): + # Handle variable impedance (dynamic gains) + osc = getattr(term, "_osc", None) + if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in ["variable", "variable_kp"]: + tensors.append( + TensorSemantics( + name=f"{term_name}_kp_gains", + ref=torch.diagonal(osc._motion_p_gains_task, dim1=-2, dim2=-1), + kind="kp", + ) + ) + tensors.append( + TensorSemantics( + name=f"{term_name}_kd_gains", + ref=torch.diagonal(osc._motion_d_gains_task, dim1=-2, dim2=-1), + kind="kp", + ) + ) + return tensors + + def _collect_action_static_outputs(self, action_manager) -> dict: + """Collect static values from action terms.""" + static_values = {} + for term_name, term in action_manager._terms.items(): + osc = getattr(term, "_osc", None) + if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in ["variable", "variable_kp"]: + continue + asset = getattr(term, "_asset", None) + if asset and hasattr(asset, "data"): + self._collect_static_gains(term_name, asset.data, getattr(term, "_joint_ids", None), static_values) + return static_values + + def _collect_static_gains(self, term_name: str, data, joint_ids, static_values: dict): + """Extract static kp/kd gains from asset data.""" + if hasattr(data, "default_joint_stiffness") and data.default_joint_stiffness is not None: + gains = data.default_joint_stiffness + static_values[f"{term_name}_kp_gains"] = gains[:, joint_ids] if joint_ids else gains + + if hasattr(data, "default_joint_damping") and data.default_joint_damping is not None: + gains = data.default_joint_damping + static_values[f"{term_name}_kd_gains"] = gains[:, joint_ids] if joint_ids else gains + + def _restore_action_manager(self): + """Restore original action manager methods.""" + if self._original_process_action: + self.env.env.unwrapped.action_manager.process_action = self._original_process_action + if self._original_apply_action: + self.env.env.unwrapped.action_manager.apply_action = self._original_apply_action + + # ────────────────────────────────────────────────────────────────── + # Action Write Annotations + # ────────────────────────────────────────────────────────────────── + + def _prepare_action_write_annotations(self): + """Prepare annotating versions of low-level action writer methods.""" + for method_name in self._get_semantic_action_write_methods(): + original_method = getattr(Articulation, method_name, None) + if original_method is None: + continue + + self._action_write_originals[method_name] = original_method + self._action_write_annotating[method_name] = self._make_annotating_action_write_method( + original_method, method_name + ) + + def _make_annotating_action_write_method(self, original_func, method_name: str): + """Create an annotating version of a low-level action writer.""" + signature = inspect.signature(original_func) + + def annotating_method(asset_self, *args, **kwargs): + result = original_func(asset_self, *args, **kwargs) + bound_args = signature.bind_partial(asset_self, *args, **kwargs) + target = bound_args.arguments.get("target") + + if not isinstance(target, torch.Tensor): + return result + + output_name = self._get_action_output_name(method_name) + semantics = getattr(self._action_write_originals[method_name], "_leapp_semantics", None) + joint_ids = bound_args.arguments.get("joint_ids") + tensor_target: torch.Tensor = target + target_snapshot = tensor_target.clone() + self._action_output_cache.append( + TensorSemantics( + name=output_name, + ref=target_snapshot, + kind=semantics.kind if semantics is not None else None, + element_names=resolve_leapp_element_names( + semantics, self._make_joint_name_context(asset_self, joint_ids) + ), + ) + ) + + return result + + annotating_method.__name__ = original_func.__name__ + return annotating_method + + def _apply_action_write_annotations(self): + """Temporarily apply annotating action writer methods.""" + if not self._action_write_annotations_active: + for method_name, method in self._action_write_annotating.items(): + setattr(Articulation, method_name, method) + self._action_write_annotations_active = True + + def _remove_action_write_annotations(self): + """Restore original low-level action writer methods.""" + if self._action_write_annotations_active: + for method_name, method in self._action_write_originals.items(): + setattr(Articulation, method_name, method) + self._action_write_annotations_active = False # ────────────────────────────────────────────────────────────────── - # Public API for accessing mappings + # Helpers # ────────────────────────────────────────────────────────────────── - @property - def get_semantic(self) -> dict[str, Any]: - observations = [] - for k in self.io_descriptor_observations: - obs_name = k["name"] - observation = { - "name": obs_name, - } - # Add the leapp input names this observation maps to (copy list to avoid YAML anchors) - if obs_name in self.observation_to_articulation_map: - observation["leapp_mapping"] = list(self.observation_to_articulation_map[obs_name]) - if "joint_names" in k: - observation["joint_names"] = k["joint_names"] - if "units" in k["extras"]: - observation["units"] = k["extras"]["units"] - observations.append(observation) - - actions = [] - for k in self.io_descriptor_actions: - action_name = k["name"] - action = { - "name": action_name, - } - if action_name in self.action_io_to_term_map: - action["leapp_mapping"] = list(self.action_io_to_term_map[action_name]) - if "joint_names" in k: - action["joint_names"] = k["joint_names"] - if "units" in k["extras"]: - action["units"] = k["extras"]["units"] - actions.append(action) - - scene = self.io_descriptor_scene - - return { - "observations": observations, - "actions": actions, - "scene": scene, - } + def _get_semantic_action_write_methods(self) -> frozenset[str]: + """Collect Articulation methods that advertise LEAPP semantics.""" + methods = set() + for method_name in dir(Articulation): + method = getattr(Articulation, method_name, None) + if callable(method) and hasattr(method, "_leapp_semantics"): + methods.add(method_name) + return frozenset(methods) + + def _get_action_output_name(self, method_name: str) -> str: + """Return a stable output name for the current action write.""" + base_name = self._active_action_term_name or method_name + output_name = base_name + existing_names = {tensor.name for tensor in self._action_output_cache} + if output_name in existing_names: + output_name = f"{base_name}_{method_name}" + suffix = 2 + while output_name in existing_names: + output_name = f"{base_name}_{method_name}_{suffix}" + suffix += 1 + return output_name + + def _make_joint_name_context(self, asset_self: Articulation, joint_ids): + """Create a lightweight context for resolving runtime joint name subsets.""" + return type( + "JointNameContext", + (), + {"joint_names": asset_self.joint_names, "_joint_ids": joint_ids}, + )() + + def _get_semantic_articulation_properties(self) -> frozenset[str]: + """Collect ArticulationData properties that advertise LEAPP semantics.""" + properties = set() + for prop_name in dir(ArticulationData): + prop = getattr(ArticulationData, prop_name, None) + if isinstance(prop, property) and prop.fget is not None and hasattr(prop.fget, "_leapp_semantics"): + properties.add(prop_name) + return frozenset(properties) + + def _make_property_semantics( + self, prop_name: str, data_self: ArticulationData, tensor: torch.Tensor + ) -> TensorSemantics: + """Create semantic metadata for raw ArticulationData inputs.""" + semantics = getattr(self._articulation_originals[prop_name].fget, "_leapp_semantics", None) + return TensorSemantics( + name=prop_name, + ref=tensor, + kind=semantics.kind if semantics is not None else None, + element_names=resolve_leapp_element_names(semantics, data_self), + ) diff --git a/source/isaaclab/isaaclab/assets/articulation/articulation.py b/source/isaaclab/isaaclab/assets/articulation/articulation.py index 552e9b2d92ec..0a7143959534 100644 --- a/source/isaaclab/isaaclab/assets/articulation/articulation.py +++ b/source/isaaclab/isaaclab/assets/articulation/articulation.py @@ -31,6 +31,8 @@ if TYPE_CHECKING: from .articulation_cfg import ArticulationCfg +from isaaclab.utils.leapp_semantics import leapp_tensor_semantics + # import logger logger = logging.getLogger(__name__) @@ -1057,6 +1059,7 @@ def set_external_force_and_torque( if self.uses_external_wrench_positions: self._external_wrench_positions_b.flatten(0, 1)[indices] = 0.0 + @leapp_tensor_semantics(kind="target/joint/position", element_names_source="joint_names") def set_joint_position_target( self, target: torch.Tensor, joint_ids: Sequence[int] | slice | None = None, env_ids: Sequence[int] | None = None ): @@ -1081,6 +1084,7 @@ def set_joint_position_target( # set targets self._data.joint_pos_target[env_ids, joint_ids] = target + @leapp_tensor_semantics(kind="target/joint/velocity", element_names_source="joint_names") def set_joint_velocity_target( self, target: torch.Tensor, joint_ids: Sequence[int] | slice | None = None, env_ids: Sequence[int] | None = None ): @@ -1105,6 +1109,7 @@ def set_joint_velocity_target( # set targets self._data.joint_vel_target[env_ids, joint_ids] = target + @leapp_tensor_semantics(kind="target/joint/effort", element_names_source="joint_names") def set_joint_effort_target( self, target: torch.Tensor, joint_ids: Sequence[int] | slice | None = None, env_ids: Sequence[int] | None = None ): diff --git a/source/isaaclab/isaaclab/assets/articulation/articulation_data.py b/source/isaaclab/isaaclab/assets/articulation/articulation_data.py index f1ab1d05586a..a627e91e9e3f 100644 --- a/source/isaaclab/isaaclab/assets/articulation/articulation_data.py +++ b/source/isaaclab/isaaclab/assets/articulation/articulation_data.py @@ -12,6 +12,7 @@ import isaaclab.utils.math as math_utils from isaaclab.utils.buffers import TimestampedBuffer +from isaaclab.utils.leapp_semantics import leapp_tensor_semantics # import logger logger = logging.getLogger(__name__) @@ -104,16 +105,16 @@ def update(self, dt: float): # Names. ## - body_names: list[str] = None + body_names: list[str] | None = None """Body names in the order parsed by the simulation view.""" - joint_names: list[str] = None + joint_names: list[str] | None = None """Joint names in the order parsed by the simulation view.""" - fixed_tendon_names: list[str] = None + fixed_tendon_names: list[str] | None = None """Fixed tendon names in the order parsed by the simulation view.""" - spatial_tendon_names: list[str] = None + spatial_tendon_names: list[str] | None = None """Spatial tendon names in the order parsed by the simulation view.""" ## @@ -732,7 +733,7 @@ def body_incoming_joint_wrench_b(self) -> torch.Tensor: if self._body_incoming_joint_wrench_b.timestamp < self._sim_timestamp: self._body_incoming_joint_wrench_b.data = self._root_physx_view.get_link_incoming_joint_force() - self._body_incoming_joint_wrench_b.time_stamp = self._sim_timestamp + self._body_incoming_joint_wrench_b.timestamp = self._sim_timestamp return self._body_incoming_joint_wrench_b.data ## @@ -740,6 +741,7 @@ def body_incoming_joint_wrench_b(self) -> torch.Tensor: ## @property + @leapp_tensor_semantics(kind="state/joint/position", element_names_source="joint_names") def joint_pos(self): """Joint positions of all joints. Shape is (num_instances, num_joints).""" if self._joint_pos.timestamp < self._sim_timestamp: @@ -749,6 +751,7 @@ def joint_pos(self): return self._joint_pos.data @property + @leapp_tensor_semantics(kind="state/joint/velocity", element_names_source="joint_names") def joint_vel(self): """Joint velocities of all joints. Shape is (num_instances, num_joints).""" if self._joint_vel.timestamp < self._sim_timestamp: @@ -774,6 +777,7 @@ def joint_acc(self): ## @property + @leapp_tensor_semantics(kind="state/body/projected_gravity", element_names_source="xyz") def projected_gravity_b(self): """Projection of the gravity direction on base frame. Shape is (num_instances, 3).""" return math_utils.quat_apply_inverse(self.root_link_quat_w, self.GRAVITY_VEC_W) @@ -997,16 +1001,19 @@ def body_com_quat_b(self) -> torch.Tensor: ## @property + @leapp_tensor_semantics(kind="state/body/pose", element_names_source="pose7") def root_pose_w(self) -> torch.Tensor: """Same as :attr:`root_link_pose_w`.""" return self.root_link_pose_w @property + @leapp_tensor_semantics(kind="state/body/position", element_names_source="xyz") def root_pos_w(self) -> torch.Tensor: """Same as :attr:`root_link_pos_w`.""" return self.root_link_pos_w @property + @leapp_tensor_semantics(kind="state/body/rotation", element_names_source="quat_wxyz") def root_quat_w(self) -> torch.Tensor: """Same as :attr:`root_link_quat_w`.""" return self.root_link_quat_w @@ -1017,26 +1024,31 @@ def root_vel_w(self) -> torch.Tensor: return self.root_com_vel_w @property + @leapp_tensor_semantics(kind="state/body/linear_velocity", element_names_source="xyz") def root_lin_vel_w(self) -> torch.Tensor: """Same as :attr:`root_com_lin_vel_w`.""" return self.root_com_lin_vel_w @property + @leapp_tensor_semantics(kind="state/body/angular_velocity", element_names_source="xyz") def root_ang_vel_w(self) -> torch.Tensor: """Same as :attr:`root_com_ang_vel_w`.""" return self.root_com_ang_vel_w @property + @leapp_tensor_semantics(kind="state/body/linear_velocity", element_names_source="xyz") def root_lin_vel_b(self) -> torch.Tensor: """Same as :attr:`root_com_lin_vel_b`.""" return self.root_com_lin_vel_b @property + @leapp_tensor_semantics(kind="state/body/angular_velocity", element_names_source="xyz") def root_ang_vel_b(self) -> torch.Tensor: """Same as :attr:`root_com_ang_vel_b`.""" return self.root_com_ang_vel_b @property + @leapp_tensor_semantics(kind="state/body/pose", element_names_source="body_pose") def body_pose_w(self) -> torch.Tensor: """Same as :attr:`body_link_pose_w`.""" return self.body_link_pose_w @@ -1047,6 +1059,7 @@ def body_pos_w(self) -> torch.Tensor: return self.body_link_pos_w @property + @leapp_tensor_semantics(kind="state/body/rotation", element_names_source="body_quat") def body_quat_w(self) -> torch.Tensor: """Same as :attr:`body_link_quat_w`.""" return self.body_link_quat_w diff --git a/source/isaaclab/isaaclab/envs/mdp/commands/pose_2d_command.py b/source/isaaclab/isaaclab/envs/mdp/commands/pose_2d_command.py index 82967fc409d8..13561a352fe6 100644 --- a/source/isaaclab/isaaclab/envs/mdp/commands/pose_2d_command.py +++ b/source/isaaclab/isaaclab/envs/mdp/commands/pose_2d_command.py @@ -60,6 +60,9 @@ def __init__(self, cfg: UniformPose2dCommandCfg, env: ManagerBasedEnv): self.metrics["error_pos"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_heading"] = torch.zeros(self.num_envs, device=self.device) + self.cfg.cmd_hint = self.cfg.cmd_hint or "command/body/pose" + self.cfg.element_names = self.cfg.element_names or ["x", "y", "z", "heading"] + def __str__(self) -> str: msg = "PositionCommand:\n" msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n" diff --git a/source/isaaclab/isaaclab/envs/mdp/commands/pose_command.py b/source/isaaclab/isaaclab/envs/mdp/commands/pose_command.py index 13503845b5a4..52a6d5888b90 100644 --- a/source/isaaclab/isaaclab/envs/mdp/commands/pose_command.py +++ b/source/isaaclab/isaaclab/envs/mdp/commands/pose_command.py @@ -67,6 +67,9 @@ def __init__(self, cfg: UniformPoseCommandCfg, env: ManagerBasedEnv): self.metrics["position_error"] = torch.zeros(self.num_envs, device=self.device) self.metrics["orientation_error"] = torch.zeros(self.num_envs, device=self.device) + self.cfg.cmd_hint = self.cfg.cmd_hint or "command/body/pose" + self.cfg.element_names = self.cfg.element_names or ["x", "y", "z", "qw", "qx", "qy", "qz"] + def __str__(self) -> str: msg = "UniformPoseCommand:\n" msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n" diff --git a/source/isaaclab/isaaclab/envs/mdp/commands/velocity_command.py b/source/isaaclab/isaaclab/envs/mdp/commands/velocity_command.py index e30fc90ef3a7..f30927fc113d 100644 --- a/source/isaaclab/isaaclab/envs/mdp/commands/velocity_command.py +++ b/source/isaaclab/isaaclab/envs/mdp/commands/velocity_command.py @@ -86,6 +86,9 @@ def __init__(self, cfg: UniformVelocityCommandCfg, env: ManagerBasedEnv): self.metrics["error_vel_xy"] = torch.zeros(self.num_envs, device=self.device) self.metrics["error_vel_yaw"] = torch.zeros(self.num_envs, device=self.device) + self.cfg.cmd_hint = self.cfg.cmd_hint or "command/body/velocity" + self.cfg.element_names = self.cfg.element_names or ["lin_vel_x", "lin_vel_y", "ang_vel_z"] + def __str__(self) -> str: """Return a string representation of the command generator.""" msg = "UniformVelocityCommand:\n" diff --git a/source/isaaclab/isaaclab/managers/manager_term_cfg.py b/source/isaaclab/isaaclab/managers/manager_term_cfg.py index 005c448a7c71..cbadb4e075b3 100644 --- a/source/isaaclab/isaaclab/managers/manager_term_cfg.py +++ b/source/isaaclab/isaaclab/managers/manager_term_cfg.py @@ -117,6 +117,9 @@ class CommandTermCfg: debug_vis: bool = False """Whether to visualize debug information. Defaults to False.""" + cmd_hint: str | None = None # type hint for the command for deployment + element_names: list[str] | list[list[str]] | None = None # element names for the command for deployment + ## # Curriculum manager. diff --git a/source/isaaclab/isaaclab/utils/leapp_semantics.py b/source/isaaclab/isaaclab/utils/leapp_semantics.py new file mode 100644 index 000000000000..b101cda1a991 --- /dev/null +++ b/source/isaaclab/isaaclab/utils/leapp_semantics.py @@ -0,0 +1,100 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""LEAPP semantic metadata helpers for raw tensor-producing functions.""" + +from __future__ import annotations + +from collections.abc import Callable +from contextlib import suppress +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True) +class LeappTensorSemantics: + """Semantic metadata attached directly to a raw tensor-producing function.""" + + kind: Any = None + element_names: list[str] | list[list[str]] | None = None + element_names_source: str | None = None + + +XYZ_ELEMENT_NAMES: list[str] = ["x", "y", "z"] +QUAT_WXYZ_ELEMENT_NAMES: list[str] = ["qw", "qx", "qy", "qz"] +POSE7_ELEMENT_NAMES: list[str] = ["x", "y", "z", "qw", "qx", "qy", "qz"] + + +def leapp_tensor_semantics( + *, + kind: Any = None, + element_names: list[str] | list[list[str]] | None = None, + element_names_source: str | None = None, +) -> Callable: + """Attach LEAPP semantic metadata to a raw tensor-producing function.""" + + semantics = LeappTensorSemantics( + kind=kind, + element_names=element_names, + element_names_source=element_names_source, + ) + + def _apply(func: Callable) -> Callable: + func._leapp_semantics = semantics + return func + + return _apply + + +def _select_element_names(names: list[str] | None, indices: Any = None) -> list[str] | None: + """Select element names using optional runtime indices.""" + if names is None: + return None + if indices is None or indices == slice(None): + return list(names) + if isinstance(indices, slice): + return list(names[indices]) + with suppress(AttributeError): + indices = indices.tolist() + if isinstance(indices, (list, tuple)): + return [names[int(index)] for index in indices] + if isinstance(indices, int): + return [names[indices]] + return None + + +def resolve_leapp_element_names(semantics: LeappTensorSemantics | None, data_self) -> list | None: + """Resolve element names from attached semantics and a tensor-producing object.""" + if semantics is None: + return None + if semantics.element_names is not None: + return semantics.element_names + + source = semantics.element_names_source + if source == "joint_names": + return _select_element_names(getattr(data_self, "joint_names", None), getattr(data_self, "_joint_ids", None)) + if source == "body_names": + return _select_element_names(getattr(data_self, "body_names", None), getattr(data_self, "_body_ids", None)) + if source == "body_pose": + body_names = _select_element_names( + getattr(data_self, "body_names", None), getattr(data_self, "_body_ids", None) + ) + if body_names is None: + return None + return [body_names, POSE7_ELEMENT_NAMES] + if source == "body_quat": + body_names = _select_element_names( + getattr(data_self, "body_names", None), getattr(data_self, "_body_ids", None) + ) + if body_names is None: + return None + return [body_names, QUAT_WXYZ_ELEMENT_NAMES] + if source == "pose7": + return POSE7_ELEMENT_NAMES + if source == "xyz": + return XYZ_ELEMENT_NAMES + if source == "quat_wxyz": + return QUAT_WXYZ_ELEMENT_NAMES + return None diff --git a/source/isaaclab_tasks/isaaclab_tasks/manager_based/manipulation/dexsuite/mdp/commands/pose_commands.py b/source/isaaclab_tasks/isaaclab_tasks/manager_based/manipulation/dexsuite/mdp/commands/pose_commands.py index 59ca92be13fd..4b337824186f 100644 --- a/source/isaaclab_tasks/isaaclab_tasks/manager_based/manipulation/dexsuite/mdp/commands/pose_commands.py +++ b/source/isaaclab_tasks/isaaclab_tasks/manager_based/manipulation/dexsuite/mdp/commands/pose_commands.py @@ -78,6 +78,9 @@ def __init__(self, cfg: dex_cmd_cfgs.ObjectUniformPoseCommandCfg, env: ManagerBa self.success_visualizer = VisualizationMarkers(self.cfg.success_visualizer_cfg) self.success_visualizer.set_visibility(True) + self.cfg.cmd_hint = self.cfg.cmd_hint or "command/body/pose" + self.cfg.element_names = self.cfg.element_names or ["x", "y", "z", "qw", "qx", "qy", "qz"] + def __str__(self) -> str: msg = "UniformPoseCommand:\n" msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n" diff --git a/source/isaaclab_tasks/isaaclab_tasks/manager_based/manipulation/inhand/mdp/commands/orientation_command.py b/source/isaaclab_tasks/isaaclab_tasks/manager_based/manipulation/inhand/mdp/commands/orientation_command.py index 73dd68fead0b..83f5561fe0ed 100644 --- a/source/isaaclab_tasks/isaaclab_tasks/manager_based/manipulation/inhand/mdp/commands/orientation_command.py +++ b/source/isaaclab_tasks/isaaclab_tasks/manager_based/manipulation/inhand/mdp/commands/orientation_command.py @@ -73,6 +73,9 @@ def __init__(self, cfg: InHandReOrientationCommandCfg, env: ManagerBasedRLEnv): self.metrics["position_error"] = torch.zeros(self.num_envs, device=self.device) self.metrics["consecutive_success"] = torch.zeros(self.num_envs, device=self.device) + self.cfg.cmd_hint = self.cfg.cmd_hint or "command/body/pose" + self.cfg.element_names = self.cfg.element_names or ["x", "y", "z", "qw", "qx", "qy", "qz"] + def __str__(self) -> str: msg = "InHandManipulationCommandGenerator:\n" msg += f"\tCommand dimension: {tuple(self.command.shape[1:])}\n" From 6d4e3b8129e8d975b4d9f9d61345a98aabbbf468 Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Wed, 11 Mar 2026 01:11:25 -0700 Subject: [PATCH 18/23] updated annotations to use local patching again --- .../reinforcement_learning/rsl_rl/export.py | 5 +- .../rsl_rl/export_annotator.py | 839 ++++++++++-------- 2 files changed, 467 insertions(+), 377 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py index c71aa5fbbad3..4e3420348581 100644 --- a/scripts/reinforcement_learning/rsl_rl/export.py +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -60,7 +60,7 @@ import time import torch -from export_annotator import ExportAnnotator +from export_annotator import patch_env_for_export from rsl_rl.runners import DistillationRunner, OnPolicyRunner from isaaclab.envs import ( @@ -126,8 +126,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen # create isaac environment # Note: observation functions are already patched at module level (before isaaclab_tasks import) env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) - annotator = ExportAnnotator(env, task_name=task_name) - annotator.setup() + patch_env_for_export(env, task_name=task_name) # convert to single-agent instance if required by the RL algorithm if isinstance(env.unwrapped, DirectMARLEnv): diff --git a/scripts/reinforcement_learning/rsl_rl/export_annotator.py b/scripts/reinforcement_learning/rsl_rl/export_annotator.py index 8de5387207b1..beb17e57c31e 100644 --- a/scripts/reinforcement_learning/rsl_rl/export_annotator.py +++ b/scripts/reinforcement_learning/rsl_rl/export_annotator.py @@ -3,14 +3,24 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Export annotations for Isaac Lab policies using instance-level patching.""" +"""Export annotations for Isaac Lab policies using proxy-based patching. +Observation and action annotation is achieved by routing calls through lightweight +proxy objects rather than globally patching class methods. This means: + +- Observation term functions see an _EnvProxy whose scene returns _ArticulationProxy + objects with annotating data getters. The real env / data is shared and unmodified. + +- Action terms have their ``_asset`` attribute replaced with an _ArticulationWriteProxy + that intercepts ``_leapp_semantics``-decorated write methods and records outputs. + The real Articulation class is never patched. +""" from __future__ import annotations import inspect import torch -from dataclasses import dataclass, field +from contextlib import suppress from typing import TYPE_CHECKING from leapp import annotate @@ -24,318 +34,465 @@ from isaaclab.envs import ManagerBasedEnv -# class ObservationPatcher: +# ══════════════════════════════════════════════════════════════════ +# Observation-side proxies +# ══════════════════════════════════════════════════════════════════ -# class ActionPatcher: +class _ArticulationDataProxy: + """Proxy around a real ArticulationData that intercepts annotated property reads. -@dataclass -class ExportAnnotator: - """Encapsulates all leapp annotation logic for exporting Isaac Lab policies. + For properties whose getter carries ``_leapp_semantics``, the proxy calls + the annotating getter (which records the tensor with LEAPP) and caches the + result for deduplication within a single ``compute_group`` call. - Usage: - env = gym.make(...) - annotator = ExportAnnotator(env) - annotator.setup() - # ... run policy ... - annotator.cleanup() + All other attribute access is forwarded transparently to the real object. """ - env: ManagerBasedEnv - task_name: str - - # Original methods for restoration - _original_compute_group: callable = field(default=None, repr=False) - _original_process_action: callable = field(default=None, repr=False) - _original_apply_action: callable = field(default=None, repr=False) - - # ArticulationData patching state - _articulation_originals: dict[str, property] = field(default_factory=dict, repr=False) - _articulation_annotating: dict[str, property] = field(default_factory=dict, repr=False) - _annotations_active: bool = field(default=False, repr=False) - - # Action writer patching state - _action_write_originals: dict[str, callable] = field(default_factory=dict, repr=False) - _action_write_annotating: dict[str, callable] = field(default_factory=dict, repr=False) - _action_write_annotations_active: bool = field(default=False, repr=False) - - # Cache for annotated tensors within a single compute_group call - # Prevents duplicate input tensors when same property is accessed multiple times - _annotated_tensor_cache: dict[str, torch.Tensor] = field(default_factory=dict, repr=False) - - _action_output_cache: list[TensorSemantics] = field(default_factory=list, repr=False) - _active_action_term_name: str | None = field(default=None, repr=False) - _pending_action_output_export: bool = field(default=False, repr=False) - - def setup(self): - """Set up all annotations. Call after env is created.""" - self._setup_observation_annotations() - self._prepare_action_write_annotations() - self._patch_action_manager() - - def cleanup(self): - """Restore all original methods and properties.""" - self._restore_observation_annotations() - self._restore_action_manager() - self._remove_action_write_annotations() - - # ────────────────────────────────────────────────────────────────── - # Observation Annotations - # ────────────────────────────────────────────────────────────────── - - def _setup_observation_annotations(self): - """Set up all observation-side annotations.""" - self._disable_observation_noise() - self._prepare_articulation_annotations() - self._patch_observation_functions() - self._patch_observation_manager() - - def _restore_observation_annotations(self): - """Restore all observation-side patches and temporary annotations.""" - self._restore_observation_functions() - self._restore_observation_manager() - self._remove_articulation_annotations() - - def _disable_observation_noise(self): - """Disable noise/corruption for deterministic export. - - Since we patch after env creation, we need to set term_cfg.noise = None - directly on each term config (not just the group config). - """ - obs_manager = self.env.env.unwrapped.observation_manager + def __init__(self, real_data: ArticulationData, annotating_getters: dict[str, callable], cache: dict): + object.__setattr__(self, "_real_data", real_data) + object.__setattr__(self, "_annotating_getters", annotating_getters) + object.__setattr__(self, "_cache", cache) + + def __getattr__(self, name): + """Intercept annotated properties; forward everything else.""" + getters = object.__getattribute__(self, "_annotating_getters") + if name in getters: + cache = object.__getattribute__(self, "_cache") + if name in cache: + return cache[name].clone() + real_data = object.__getattribute__(self, "_real_data") + result = getters[name](real_data) + cache[name] = result + return result + return getattr(object.__getattribute__(self, "_real_data"), name) - # Disable noise on each individual term config - for _, term_cfgs in obs_manager._group_obs_term_cfgs.items(): - for term_cfg in term_cfgs: - term_cfg.noise = None - def _patch_observation_functions(self): - """Patch observation functions inside the observation manager's term configs. +class _ArticulationProxy: + """Proxy around a real Articulation that returns _ArticulationDataProxy for ``.data``. + + All other attribute access is forwarded transparently to the real asset. + """ + + def __init__(self, real_asset: Articulation, data_proxy: _ArticulationDataProxy): + object.__setattr__(self, "_real_asset", real_asset) + object.__setattr__(self, "_data_proxy", data_proxy) + + @property + def data(self): + """Return the annotating data proxy instead of the real ArticulationData.""" + return object.__getattribute__(self, "_data_proxy") + + def __getattr__(self, name): + """Forward all non-data attribute access to the real asset.""" + return getattr(object.__getattribute__(self, "_real_asset"), name) + + +class _SceneProxy: + """Proxy around the real InteractiveScene. + + When an observation term looks up an asset by name, this proxy lazily wraps + Articulation entities in _ArticulationProxy so their data getters annotate. + Non-Articulation entities are returned as-is. + """ + + def __init__(self, real_scene, annotating_getters: dict[str, callable], cache: dict): + object.__setattr__(self, "_real_scene", real_scene) + object.__setattr__(self, "_annotating_getters", annotating_getters) + object.__setattr__(self, "_cache", cache) + object.__setattr__(self, "_proxied", {}) + + def __getitem__(self, key): + """Return an ArticulationProxy for Articulation entities, real entity otherwise.""" + proxied = object.__getattribute__(self, "_proxied") + if key in proxied: + return proxied[key] + real_scene = object.__getattribute__(self, "_real_scene") + entity = real_scene[key] + if isinstance(entity, Articulation): + getters = object.__getattribute__(self, "_annotating_getters") + cache = object.__getattribute__(self, "_cache") + data_proxy = _ArticulationDataProxy(entity.data, getters, cache) + proxy = _ArticulationProxy(entity, data_proxy) + proxied[key] = proxy + return proxy + return entity + + def __getattr__(self, name): + """Forward all other scene access to the real scene.""" + return getattr(object.__getattribute__(self, "_real_scene"), name) + + +class _EnvProxy: + """Proxy around the real env that returns a _SceneProxy for ``.scene``. + + All other attribute access (``num_envs``, ``command_manager``, etc.) + is forwarded transparently to the real env. + """ + + def __init__(self, real_env, scene_proxy: _SceneProxy): + object.__setattr__(self, "_real_env", real_env) + object.__setattr__(self, "_scene_proxy", scene_proxy) + + @property + def scene(self): + """Return the scene proxy instead of the real scene.""" + return object.__getattribute__(self, "_scene_proxy") + + def __getattr__(self, name): + """Forward all non-scene attribute access to the real env.""" + return getattr(object.__getattribute__(self, "_real_env"), name) + + +# ══════════════════════════════════════════════════════════════════ +# Action-side proxy +# ══════════════════════════════════════════════════════════════════ + + +class _ArticulationWriteProxy: + """Proxy around a real Articulation that intercepts ``_leapp_semantics`` write methods. + + When an action term calls e.g. ``self._asset.set_joint_position_target(target, joint_ids)``, + this proxy: + + 1. Calls the real method on the real asset (so the simulation sees the write). + 2. Snapshots the ``target`` tensor and records a ``TensorSemantics`` entry in the + shared output cache. + + All other attribute access is forwarded transparently to the real asset. + """ + + def __init__( + self, + real_asset: Articulation, + term_name: str, + output_cache: list[TensorSemantics], + annotating_methods: dict[str, callable], + ): + object.__setattr__(self, "_real_asset", real_asset) + object.__setattr__(self, "_term_name", term_name) + object.__setattr__(self, "_output_cache", output_cache) + object.__setattr__(self, "_annotating_methods", annotating_methods) + + def __getattr__(self, name): + """Return an annotating wrapper for _leapp_semantics methods; forward everything else.""" + methods = object.__getattribute__(self, "_annotating_methods") + if name in methods: + real_asset = object.__getattribute__(self, "_real_asset") + term_name = object.__getattribute__(self, "_term_name") + output_cache = object.__getattribute__(self, "_output_cache") + original_method = getattr(real_asset, name) + return methods[name](real_asset, original_method, term_name, output_cache) + return getattr(object.__getattribute__(self, "_real_asset"), name) + + +# ══════════════════════════════════════════════════════════════════ +# ObservationPatcher +# ══════════════════════════════════════════════════════════════════ + + +class ObservationPatcher: + """Permanently patches observation term functions to annotate their inputs via proxies. + + Instead of globally patching ArticulationData properties and toggling them on/off, + this scans for ``_leapp_semantics``-decorated properties once, builds annotating + getters, and routes observation term calls through lightweight proxy objects that + share the same underlying env and tensor state. + """ + + def __init__(self, task_name: str): + self.task_name = task_name + self._annotated_tensor_cache: dict[str, torch.Tensor] = {} - These functions (last_action, generated_commands) don't access ArticulationData - properties, so they need separate patching to record their mappings and annotate - their outputs. + def setup(self, obs_manager): + """Patch all observation terms to use annotating proxies. - We patch the term_cfg.func directly because the observation manager stores - references to these functions at creation time. + For each term in the observation manager: + - Normal terms get their ``env`` argument swapped for a proxy env. + - ``last_action`` and ``generated_commands`` get dedicated wrappers. + - Noise is disabled on every term for deterministic export. + + A thin wrapper on ``compute_group`` clears the dedup cache between calls. """ - obs_manager = self.env.env.unwrapped.observation_manager + real_env = obs_manager._env - # Store original functions for restoration: (group_name, term_idx) -> original_func - self._original_obs_funcs: dict[tuple[str, int], callable] = {} + annotating_getters = self._build_annotating_getters() + scene_proxy = _SceneProxy(real_env.scene, annotating_getters, self._annotated_tensor_cache) + proxy_env = _EnvProxy(real_env, scene_proxy) - # find and patch all other known non-articulation data properties for group_name, term_cfgs in obs_manager._group_obs_term_cfgs.items(): - for term_idx, term_cfg in enumerate(term_cfgs): + for term_cfg in term_cfgs: original_func = term_cfg.func func_name = getattr(original_func, "__name__", None) if func_name == "last_action": - self._original_obs_funcs[(group_name, term_idx)] = original_func - term_cfg.func = self._make_patched_last_action(original_func) - + term_cfg.func = self._wrap_last_action(original_func) elif func_name == "generated_commands": - self._original_obs_funcs[(group_name, term_idx)] = original_func - term_cfg.func = self._make_patched_generated_commands(original_func, term_cfg) + term_cfg.func = self._wrap_generated_commands(original_func, term_cfg) + else: + term_cfg.func = self._wrap_with_proxy(original_func, proxy_env) + + term_cfg.noise = None + + original_compute_group = obs_manager.compute_group + cache = self._annotated_tensor_cache + + def patched_compute_group(*args, **kwargs): + """Clear the tensor dedup cache, then run the real compute_group.""" + cache.clear() + return original_compute_group(*args, **kwargs) + + obs_manager.compute_group = patched_compute_group + + # ── Scanning ────────────────────────────────────────────────── + + def _build_annotating_getters(self) -> dict[str, callable]: + """Scan ArticulationData for ``_leapp_semantics`` properties and build annotating getters. + + Returns a dict mapping property name to a callable ``(data_self) -> annotated_tensor``. + """ + getters: dict[str, callable] = {} + for prop_name in dir(ArticulationData): + prop = getattr(ArticulationData, prop_name, None) + if isinstance(prop, property) and prop.fget and hasattr(prop.fget, "_leapp_semantics"): + getters[prop_name] = self._make_annotating_getter(prop.fget, prop_name) + return getters + + def _make_annotating_getter(self, original_fget, prop_name: str): + """Create an annotating getter callable for a single ArticulationData property. + + The returned callable invokes the real getter, then registers the result + as a LEAPP input tensor with the property's semantic metadata. + """ + task_name = self.task_name + + def getter(data_self): + result = original_fget(data_self) + if not isinstance(result, torch.Tensor): + return result + semantics_meta = getattr(original_fget, "_leapp_semantics", None) + sem = TensorSemantics( + name=prop_name, + ref=result, + kind=semantics_meta.kind if semantics_meta else None, + element_names=resolve_leapp_element_names(semantics_meta, data_self), + ) + return annotate.input_tensors(task_name, sem) + + return getter + + # ── Term wrappers ───────────────────────────────────────────── + + @staticmethod + def _wrap_with_proxy(original_func, proxy_env): + """Wrap a term function so it receives the proxy env instead of the real env. + + This is the generic wrapper for observation terms that read ArticulationData + properties. By substituting the env, the entire downstream chain + (env.scene[name].data.property) goes through the proxy. + """ + + def wrapped(env, **kwargs): + return original_func(proxy_env, **kwargs) - def _make_patched_last_action(self, original_func): - """Create a patched version of last_action for LEAPP tracing.""" + wrapped.__name__ = getattr(original_func, "__name__", "unknown") + return wrapped - def patched_last_action(env, action_name=None, **kwargs): + def _wrap_last_action(self, original_func): + """Wrap the ``last_action`` observation term to annotate its output as a LEAPP input.""" + task_name = self.task_name + + def wrapped(env, action_name=None, **kwargs): result = original_func(env, action_name, **kwargs) - result = annotate.input_tensors(self.task_name, {"last_actions": result}) - return result + return annotate.input_tensors(task_name, {"last_actions": result}) - patched_last_action.__name__ = original_func.__name__ - return patched_last_action + wrapped.__name__ = original_func.__name__ + return wrapped - def _make_patched_generated_commands(self, original_func, term_cfg): - """Create a patched version of generated_commands for LEAPP tracing.""" - # Get the command_name from term_cfg.params if available + def _wrap_generated_commands(self, original_func, term_cfg): + """Wrap the ``generated_commands`` observation term to annotate its output as a LEAPP input. + + Resolves command semantics (kind, element_names) from the command manager + configuration when available. + """ + task_name = self.task_name command_name_from_cfg = term_cfg.params.get("command_name") - def patched_generated_commands(env, command_name=None, **kwargs): + def wrapped(env, command_name=None, **kwargs): result = original_func(env, command_name, **kwargs) - # Use command_name parameter, or fall back to config, or default leapp_input_name = command_name or command_name_from_cfg or "commands" command_cfg = None - try: + with suppress(AttributeError, KeyError): command_cfg = env.command_manager.get_term(leapp_input_name).cfg - except (AttributeError, KeyError): - # Keep export working even if the observation term doesn't point to a registered command term. - command_cfg = None - - semantics = TensorSemantics( + sem = TensorSemantics( name=leapp_input_name, ref=result, kind=getattr(command_cfg, "cmd_hint", None), element_names=getattr(command_cfg, "element_names", None), ) - result = annotate.input_tensors(self.task_name, semantics) - return result + return annotate.input_tensors(task_name, sem) - patched_generated_commands.__name__ = original_func.__name__ - return patched_generated_commands + wrapped.__name__ = original_func.__name__ + return wrapped - def _restore_observation_functions(self): - """Restore original observation functions in term configs.""" - if not hasattr(self, "_original_obs_funcs"): - return - obs_manager = self.env.env.unwrapped.observation_manager +# ══════════════════════════════════════════════════════════════════ +# ActionPatcher +# ══════════════════════════════════════════════════════════════════ - for (group_name, term_idx), original_func in self._original_obs_funcs.items(): - obs_manager._group_obs_term_cfgs[group_name][term_idx].func = original_func - def _patch_observation_manager(self): - """Patch the observation manager instance's compute_group method.""" - obs_manager = self.env.env.unwrapped.observation_manager - self._original_compute_group = obs_manager.compute_group +class ActionPatcher: + """Permanently patches action terms to annotate their outputs via proxies. - def patched_compute_group(*args, **kwargs): - self._apply_articulation_annotations() - try: - return self._original_compute_group(*args, **kwargs) - finally: - self._remove_articulation_annotations() + 1. Scans Articulation for ``_leapp_semantics``-decorated methods once. + 2. Replaces each action term's ``_asset`` with an ``_ArticulationWriteProxy`` that + intercepts those methods and records output semantics. + 3. Patches ``process_action`` and ``apply_action`` on the action manager instance + to coordinate buffer registration and the single ``annotate.output_tensors`` call. - obs_manager.compute_group = patched_compute_group + No Articulation class methods are ever modified. + """ - def _restore_observation_manager(self): - """Restore original compute_group method.""" - if self._original_compute_group: - self.env.env.unwrapped.observation_manager.compute_group = self._original_compute_group + def __init__(self, task_name: str): + self.task_name = task_name + self._action_output_cache: list[TensorSemantics] = [] + self._pending_action_output_export: bool = False - # ────────────────────────────────────────────────────────────────── - # ArticulationData Property Annotations - # ────────────────────────────────────────────────────────────────── + def setup(self, action_manager): + """Patch all action terms and the action manager for LEAPP annotation. - def _prepare_articulation_annotations(self): - """Prepare annotating versions of ArticulationData properties.""" - for prop_name in self._get_semantic_articulation_properties(): - original_prop = getattr(ArticulationData, prop_name, None) - if not isinstance(original_prop, property) or original_prop.fget is None: - continue + For each action term with an Articulation asset, replaces ``term._asset`` + with an ``_ArticulationWriteProxy``. Then patches ``process_action`` and + ``apply_action`` on the manager instance. + """ + annotating_methods = self._build_annotating_write_methods() - self._articulation_originals[prop_name] = original_prop - self._articulation_annotating[prop_name] = self._make_annotating_property(original_prop, prop_name) + for term_name, term in action_manager._terms.items(): + asset = getattr(term, "_asset", None) + if isinstance(asset, Articulation): + term._asset = _ArticulationWriteProxy( + real_asset=asset, + term_name=term_name, + output_cache=self._action_output_cache, + annotating_methods=annotating_methods, + ) - def _make_annotating_property(self, original: property, prop_name: str) -> property: - """Create an annotating version of an ArticulationData property.""" - original_fget = original.fget - assert original_fget is not None # Checked in _prepare_articulation_annotations + self._patch_manager_methods(action_manager) - def annotating_fget(data_self): - result = original_fget(data_self) + # ── Scanning ────────────────────────────────────────────────── - if isinstance(result, torch.Tensor): - # Check if this property was already annotated in this compute_group call - if prop_name in self._annotated_tensor_cache: - # Return a clone of the cached tensor to avoid duplicate input annotations - return self._annotated_tensor_cache[prop_name].clone() + def _build_annotating_write_methods(self) -> dict[str, callable]: + """Scan Articulation for ``_leapp_semantics`` methods and build interceptors. - sem = self._make_property_semantics(prop_name, data_self, result) + Returns a dict mapping method name to a factory callable. The factory takes + ``(real_asset, original_bound_method, term_name, output_cache)`` and returns + a callable that the proxy returns in ``__getattr__``. + """ + methods: dict[str, callable] = {} + for method_name in dir(Articulation): + method = getattr(Articulation, method_name, None) + if callable(method) and hasattr(method, "_leapp_semantics"): + methods[method_name] = self._make_write_interceptor_factory(method, method_name) + return methods - # First access - annotate and cache - result = annotate.input_tensors(self.task_name, sem) - self._annotated_tensor_cache[prop_name] = result + def _make_write_interceptor_factory(self, original_unbound, method_name: str): + """Create a factory that produces bound annotating wrappers for a single write method. - return result + The factory is called by ``_ArticulationWriteProxy.__getattr__`` each time the + method is accessed. It returns a callable that: + + 1. Calls the real method on the real asset. + 2. Inspects the ``target`` argument. + 3. Records a ``TensorSemantics`` entry in the shared output cache. + """ + signature = inspect.signature(original_unbound) + semantics = getattr(original_unbound, "_leapp_semantics", None) + + def factory(real_asset: Articulation, original_bound, term_name: str, output_cache: list): + + def interceptor(*args, **kwargs): + result = original_bound(*args, **kwargs) + bound_args = signature.bind_partial(real_asset, *args, **kwargs) + target = bound_args.arguments.get("target") + + if isinstance(target, torch.Tensor): + tensor_target: torch.Tensor = target + output_name = _unique_output_name(term_name, method_name, output_cache) + joint_ids = bound_args.arguments.get("joint_ids") + output_cache.append( + TensorSemantics( + name=output_name, + ref=tensor_target.clone(), + kind=semantics.kind if semantics is not None else None, + element_names=resolve_leapp_element_names( + semantics, + _JointNameContext(real_asset.joint_names, joint_ids), + ), + ) + ) + + return result + + return interceptor + + return factory + + # ── Manager patches ─────────────────────────────────────────── + + def _patch_manager_methods(self, action_manager): + """Patch ``process_action`` and ``apply_action`` on the action manager instance. + + ``process_action`` registers raw_action buffers for LEAPP tracing and + preserves the action tensor clone. - return property(fget=annotating_fget, fset=original.fset, fdel=original.fdel, doc=original.__doc__) - - def _apply_articulation_annotations(self): - """Temporarily apply annotating properties.""" - if not self._annotations_active: - # Clear the tensor cache at the start of each compute_group call - self._annotated_tensor_cache.clear() - for prop_name, prop in self._articulation_annotating.items(): - setattr(ArticulationData, prop_name, prop) - self._annotations_active = True - - def _remove_articulation_annotations(self): - """Restore original properties.""" - if self._annotations_active: - for prop_name, prop in self._articulation_originals.items(): - setattr(ArticulationData, prop_name, prop) - self._annotations_active = False - # Clear the tensor cache when done - self._annotated_tensor_cache.clear() - - # ────────────────────────────────────────────────────────────────── - # Action Manager - # ────────────────────────────────────────────────────────────────── - - def _patch_action_manager(self): - """Patch the action manager instance's action processing methods.""" - action_manager = self.env.env.unwrapped.action_manager - self._original_process_action = action_manager.process_action - self._original_apply_action = action_manager.apply_action + ``apply_action`` coordinates the output cache lifecycle: clear before, + collect and export after. + """ + original_process = action_manager.process_action + original_apply = action_manager.apply_action + task_name = self.task_name def patched_process_action(action: torch.Tensor): - # Register raw_actions buffers for tracing + """Register raw_action buffers, call real process_action, preserve action clone.""" for term_name, term in action_manager._terms.items(): if hasattr(term, "_raw_actions") and term._raw_actions is not None: - term._raw_actions = annotate.register_buffer(self.task_name, {"raw_actions": term._raw_actions}) + term._raw_actions = annotate.register_buffer(task_name, {"raw_actions": term._raw_actions}) - self._original_process_action(action) - # this is stored differently inside the original process action method that would loose tracing. this step preserves it. + original_process(action) action_manager._action = action.clone() self._pending_action_output_export = True def patched_apply_action(): + """Clear cache, call real apply_action, collect outputs, call annotate.output_tensors.""" if not self._pending_action_output_export: - return self._original_apply_action() + return original_apply() - original_term_apply_actions: dict[str, callable] = {} self._action_output_cache.clear() - self._apply_action_write_annotations() - - try: - for term_name, term in action_manager._terms.items(): - original_term_apply_actions[term_name] = term.apply_actions - term.apply_actions = self._make_patched_term_apply_actions(term.apply_actions, term_name) - - self._original_apply_action() - - self._action_output_cache.extend(self._collect_action_outputs(action_manager)) - self._action_output_cache.append(TensorSemantics(name="last_action", ref=action_manager._action)) - static_values = self._collect_action_static_outputs(action_manager) - annotate.output_tensors( - self.task_name, - self._action_output_cache, - static_outputs=static_values, - export_with="onnx", - ) - self._pending_action_output_export = False - finally: - for term_name, original_apply_actions in original_term_apply_actions.items(): - action_manager._terms[term_name].apply_actions = original_apply_actions - self._active_action_term_name = None - self._remove_action_write_annotations() - self._action_output_cache.clear() + original_apply() + + self._action_output_cache.extend(self._collect_action_outputs(action_manager)) + self._action_output_cache.append(TensorSemantics(name="last_action", ref=action_manager._action)) + static_values = self._collect_action_static_outputs(action_manager) + annotate.output_tensors( + task_name, + self._action_output_cache, + static_outputs=static_values, + export_with="onnx", + ) + self._pending_action_output_export = False + self._action_output_cache.clear() + return None action_manager.process_action = patched_process_action action_manager.apply_action = patched_apply_action - def _make_patched_term_apply_actions(self, original_func, term_name: str): - """Wrap an action term's apply call to keep the current term context.""" + # ── Output collection ───────────────────────────────────────── - def patched_apply_actions(): - self._active_action_term_name = term_name - try: - return original_func() - finally: - self._active_action_term_name = None - - return patched_apply_actions - - def _collect_action_outputs(self, action_manager) -> list[TensorSemantics]: - """Collect non-writer action tensors that should still be exported.""" + @staticmethod + def _collect_action_outputs(action_manager) -> list[TensorSemantics]: + """Collect non-writer action tensors that should be exported (e.g. OSC dynamic gains).""" tensors: list[TensorSemantics] = [] - for term_name, term in action_manager._terms.items(): - # Handle variable impedance (dynamic gains) osc = getattr(term, "_osc", None) if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in ["variable", "variable_kp"]: tensors.append( @@ -354,149 +511,83 @@ def _collect_action_outputs(self, action_manager) -> list[TensorSemantics]: ) return tensors - def _collect_action_static_outputs(self, action_manager) -> dict: - """Collect static values from action terms.""" - static_values = {} + @staticmethod + def _collect_action_static_outputs(action_manager) -> dict: + """Collect static kp/kd gain values from action terms for export metadata.""" + static_values: dict = {} for term_name, term in action_manager._terms.items(): osc = getattr(term, "_osc", None) if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in ["variable", "variable_kp"]: continue asset = getattr(term, "_asset", None) - if asset and hasattr(asset, "data"): - self._collect_static_gains(term_name, asset.data, getattr(term, "_joint_ids", None), static_values) + real_asset = getattr(asset, "_real_asset", asset) + if real_asset and hasattr(real_asset, "data"): + data = real_asset.data + joint_ids = getattr(term, "_joint_ids", None) + if hasattr(data, "default_joint_stiffness") and data.default_joint_stiffness is not None: + gains = data.default_joint_stiffness + static_values[f"{term_name}_kp_gains"] = gains[:, joint_ids] if joint_ids else gains + if hasattr(data, "default_joint_damping") and data.default_joint_damping is not None: + gains = data.default_joint_damping + static_values[f"{term_name}_kd_gains"] = gains[:, joint_ids] if joint_ids else gains return static_values - def _collect_static_gains(self, term_name: str, data, joint_ids, static_values: dict): - """Extract static kp/kd gains from asset data.""" - if hasattr(data, "default_joint_stiffness") and data.default_joint_stiffness is not None: - gains = data.default_joint_stiffness - static_values[f"{term_name}_kp_gains"] = gains[:, joint_ids] if joint_ids else gains - - if hasattr(data, "default_joint_damping") and data.default_joint_damping is not None: - gains = data.default_joint_damping - static_values[f"{term_name}_kd_gains"] = gains[:, joint_ids] if joint_ids else gains - - def _restore_action_manager(self): - """Restore original action manager methods.""" - if self._original_process_action: - self.env.env.unwrapped.action_manager.process_action = self._original_process_action - if self._original_apply_action: - self.env.env.unwrapped.action_manager.apply_action = self._original_apply_action - - # ────────────────────────────────────────────────────────────────── - # Action Write Annotations - # ────────────────────────────────────────────────────────────────── - - def _prepare_action_write_annotations(self): - """Prepare annotating versions of low-level action writer methods.""" - for method_name in self._get_semantic_action_write_methods(): - original_method = getattr(Articulation, method_name, None) - if original_method is None: - continue - self._action_write_originals[method_name] = original_method - self._action_write_annotating[method_name] = self._make_annotating_action_write_method( - original_method, method_name - ) +# ══════════════════════════════════════════════════════════════════ +# Helpers +# ══════════════════════════════════════════════════════════════════ - def _make_annotating_action_write_method(self, original_func, method_name: str): - """Create an annotating version of a low-level action writer.""" - signature = inspect.signature(original_func) - def annotating_method(asset_self, *args, **kwargs): - result = original_func(asset_self, *args, **kwargs) - bound_args = signature.bind_partial(asset_self, *args, **kwargs) - target = bound_args.arguments.get("target") +class _JointNameContext: + """Lightweight stand-in for resolving runtime joint name subsets in ``resolve_leapp_element_names``.""" - if not isinstance(target, torch.Tensor): - return result + __slots__ = ("joint_names", "_joint_ids") - output_name = self._get_action_output_name(method_name) - semantics = getattr(self._action_write_originals[method_name], "_leapp_semantics", None) - joint_ids = bound_args.arguments.get("joint_ids") - tensor_target: torch.Tensor = target - target_snapshot = tensor_target.clone() - self._action_output_cache.append( - TensorSemantics( - name=output_name, - ref=target_snapshot, - kind=semantics.kind if semantics is not None else None, - element_names=resolve_leapp_element_names( - semantics, self._make_joint_name_context(asset_self, joint_ids) - ), - ) - ) + def __init__(self, joint_names: list[str], joint_ids): + self.joint_names = joint_names + self._joint_ids = joint_ids - return result - annotating_method.__name__ = original_func.__name__ - return annotating_method - - def _apply_action_write_annotations(self): - """Temporarily apply annotating action writer methods.""" - if not self._action_write_annotations_active: - for method_name, method in self._action_write_annotating.items(): - setattr(Articulation, method_name, method) - self._action_write_annotations_active = True - - def _remove_action_write_annotations(self): - """Restore original low-level action writer methods.""" - if self._action_write_annotations_active: - for method_name, method in self._action_write_originals.items(): - setattr(Articulation, method_name, method) - self._action_write_annotations_active = False - - # ────────────────────────────────────────────────────────────────── - # Helpers - # ────────────────────────────────────────────────────────────────── - - def _get_semantic_action_write_methods(self) -> frozenset[str]: - """Collect Articulation methods that advertise LEAPP semantics.""" - methods = set() - for method_name in dir(Articulation): - method = getattr(Articulation, method_name, None) - if callable(method) and hasattr(method, "_leapp_semantics"): - methods.add(method_name) - return frozenset(methods) - - def _get_action_output_name(self, method_name: str) -> str: - """Return a stable output name for the current action write.""" - base_name = self._active_action_term_name or method_name - output_name = base_name - existing_names = {tensor.name for tensor in self._action_output_cache} - if output_name in existing_names: - output_name = f"{base_name}_{method_name}" - suffix = 2 - while output_name in existing_names: - output_name = f"{base_name}_{method_name}_{suffix}" - suffix += 1 - return output_name - - def _make_joint_name_context(self, asset_self: Articulation, joint_ids): - """Create a lightweight context for resolving runtime joint name subsets.""" - return type( - "JointNameContext", - (), - {"joint_names": asset_self.joint_names, "_joint_ids": joint_ids}, - )() - - def _get_semantic_articulation_properties(self) -> frozenset[str]: - """Collect ArticulationData properties that advertise LEAPP semantics.""" - properties = set() - for prop_name in dir(ArticulationData): - prop = getattr(ArticulationData, prop_name, None) - if isinstance(prop, property) and prop.fget is not None and hasattr(prop.fget, "_leapp_semantics"): - properties.add(prop_name) - return frozenset(properties) - - def _make_property_semantics( - self, prop_name: str, data_self: ArticulationData, tensor: torch.Tensor - ) -> TensorSemantics: - """Create semantic metadata for raw ArticulationData inputs.""" - semantics = getattr(self._articulation_originals[prop_name].fget, "_leapp_semantics", None) - return TensorSemantics( - name=prop_name, - ref=tensor, - kind=semantics.kind if semantics is not None else None, - element_names=resolve_leapp_element_names(semantics, data_self), - ) +def _unique_output_name(term_name: str, method_name: str, output_cache: list[TensorSemantics]) -> str: + """Return a stable, unique output name for an action write entry. + + Prefers ``term_name``, falls back to ``term_name_method_name``, and appends a + numeric suffix if even that collides. + """ + existing = {t.name for t in output_cache} + candidate = term_name + if candidate in existing: + candidate = f"{term_name}_{method_name}" + suffix = 2 + while candidate in existing: + candidate = f"{term_name}_{method_name}_{suffix}" + suffix += 1 + return candidate + + +# ══════════════════════════════════════════════════════════════════ +# Public entry point +# ══════════════════════════════════════════════════════════════════ + + +def patch_env_for_export(env: ManagerBasedEnv, task_name: str) -> None: + """Patch the env's observation and action managers for LEAPP export. + + This is a thin public entry point around ``ObservationPatcher`` and + ``ActionPatcher``. It mutates the provided env instance in-place so that: + + - Observation terms route through proxy objects that annotate + ``ArticulationData`` reads. + - Action terms route through proxy objects that annotate + ``Articulation`` write methods. + + The underlying env, scene, assets, and tensors remain shared with the rest + of the pipeline; only the manager call paths are redirected. + """ + unwrapped = env.env.unwrapped + + obs_patcher = ObservationPatcher(task_name) + obs_patcher.setup(unwrapped.observation_manager) + + action_patcher = ActionPatcher(task_name) + action_patcher.setup(unwrapped.action_manager) From e6cc033b753364dc80eab1b1dc9ed2433674b940 Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Thu, 12 Mar 2026 09:36:37 -0700 Subject: [PATCH 19/23] refactor for unified patching class. new class also does action manager side observation patching --- .../reinforcement_learning/rsl_rl/export.py | 37 +- .../rsl_rl/export_annotator.py | 391 ++++++++++-------- 2 files changed, 215 insertions(+), 213 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py index 4e3420348581..54040d882053 100644 --- a/scripts/reinforcement_learning/rsl_rl/export.py +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -72,21 +72,13 @@ ) from isaaclab.utils.assets import retrieve_file_path -from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx +from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper from isaaclab_rl.utils.pretrained_checkpoint import get_published_pretrained_checkpoint import isaaclab_tasks # noqa: F401 from isaaclab_tasks.utils import get_checkpoint_path from isaaclab_tasks.utils.hydra import hydra_task_config -# IMPORTANT: Add leapp annotations BEFORE importing isaaclab_tasks -# This ensures the patched functions are captured when configs are created -# from annotate_functions_for_export import ( -# add_leapp_annotations, -# get_action_io_to_term_map, -# get_observation_to_articulation_map, -# ) - @hydra_task_config(args_cli.task, args_cli.agent) def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg): @@ -148,34 +140,9 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen # obtain the trained policy for inference policy = runner.get_inference_policy(device=env.unwrapped.device) - # extract the neural network module - # we do this in a try-except to maintain backwards compatibility. - try: - # version 2.3 onwards - policy_nn = runner.alg.policy - except AttributeError: - # version 2.2 and below - policy_nn = runner.alg.actor_critic - - # extract the normalizer - if hasattr(policy_nn, "actor_obs_normalizer"): - normalizer = policy_nn.actor_obs_normalizer - elif hasattr(policy_nn, "student_obs_normalizer"): - normalizer = policy_nn.student_obs_normalizer - else: - normalizer = None - - # export policy to onnx/jit - export_model_dir = os.path.join(os.path.dirname(resume_path), "exported") - export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt") - export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx") - jit_path = os.path.join(export_model_dir, "policy.pt") - onnx_path = os.path.join(export_model_dir, "policy.onnx") - print(f"[INFO]: Exported policy to: jit {jit_path}, onnx {onnx_path}") - # start annotation tracing # Note: all patching is done at module/class level before isaaclab_tasks import - leapp.start(task_name, save_path=export_model_dir) + leapp.start(task_name, save_path=log_dir) obs = env.get_observations() # simulate environment while not simulation_app.is_running(): diff --git a/scripts/reinforcement_learning/rsl_rl/export_annotator.py b/scripts/reinforcement_learning/rsl_rl/export_annotator.py index beb17e57c31e..2d09ec123e3e 100644 --- a/scripts/reinforcement_learning/rsl_rl/export_annotator.py +++ b/scripts/reinforcement_learning/rsl_rl/export_annotator.py @@ -5,15 +5,30 @@ """Export annotations for Isaac Lab policies using proxy-based patching. -Observation and action annotation is achieved by routing calls through lightweight -proxy objects rather than globally patching class methods. This means: - -- Observation term functions see an _EnvProxy whose scene returns _ArticulationProxy - objects with annotating data getters. The real env / data is shared and unmodified. - -- Action terms have their ``_asset`` attribute replaced with an _ArticulationWriteProxy - that intercepts ``_leapp_semantics``-decorated write methods and records outputs. - The real Articulation class is never patched. +Observation and action annotation share a single set of annotating getters +and a unified dedup cache so that a state property (e.g. ``joint_pos``) +read by both an observation term and an action term resolves to one LEAPP +input edge. + +- Observation term functions see an _EnvProxy whose scene returns + _ArticulationProxy objects with annotating data getters. + +- Action terms have their ``_asset`` attribute replaced with an + _ArticulationWriteProxy that intercepts ``_leapp_semantics``-decorated + write methods **and** routes ``.data`` reads through the same annotating + data proxy used by observations. + +Cache lifecycle (assuming single-env play-mode export): + + compute_group() clear cache → obs terms populate cache + policy inference TracedTensors propagate through NN + process_action() register_buffer for raw_actions + apply_action() [tracing] reuse cached TracedTensors for state reads, + capture write outputs, call output_tensors(), + then clear cache + apply_action() [decim.] clear cache → fresh reads for simulation + ... + compute_group() clear cache → fresh reads for next obs """ from __future__ import annotations @@ -35,7 +50,7 @@ # ══════════════════════════════════════════════════════════════════ -# Observation-side proxies +# Shared data proxy # ══════════════════════════════════════════════════════════════════ @@ -44,7 +59,8 @@ class _ArticulationDataProxy: For properties whose getter carries ``_leapp_semantics``, the proxy calls the annotating getter (which records the tensor with LEAPP) and caches the - result for deduplication within a single ``compute_group`` call. + result for deduplication. Consumers within the same annotation pass + (observation terms **and** action terms) receive the same TracedTensor. All other attribute access is forwarded transparently to the real object. """ @@ -68,6 +84,11 @@ def __getattr__(self, name): return getattr(object.__getattribute__(self, "_real_data"), name) +# ══════════════════════════════════════════════════════════════════ +# Observation-side proxies +# ══════════════════════════════════════════════════════════════════ + + class _ArticulationProxy: """Proxy around a real Articulation that returns _ArticulationDataProxy for ``.data``. @@ -150,14 +171,13 @@ def __getattr__(self, name): class _ArticulationWriteProxy: - """Proxy around a real Articulation that intercepts ``_leapp_semantics`` write methods. - - When an action term calls e.g. ``self._asset.set_joint_position_target(target, joint_ids)``, - this proxy: + """Proxy around a real Articulation for action terms. - 1. Calls the real method on the real asset (so the simulation sees the write). - 2. Snapshots the ``target`` tensor and records a ``TensorSemantics`` entry in the - shared output cache. + Intercepts ``_leapp_semantics``-decorated write methods **and** routes + ``.data`` reads through a shared ``_ArticulationDataProxy`` so that + action-side state reads (e.g. ``self._asset.data.joint_pos`` inside + ``RelativeJointPositionAction``) participate in LEAPP annotation and + share the dedup cache with observation-side reads. All other attribute access is forwarded transparently to the real asset. """ @@ -168,11 +188,18 @@ def __init__( term_name: str, output_cache: list[TensorSemantics], annotating_methods: dict[str, callable], + data_proxy: _ArticulationDataProxy, ): object.__setattr__(self, "_real_asset", real_asset) object.__setattr__(self, "_term_name", term_name) object.__setattr__(self, "_output_cache", output_cache) object.__setattr__(self, "_annotating_methods", annotating_methods) + object.__setattr__(self, "_data_proxy", data_proxy) + + @property + def data(self): + """Return the shared annotating data proxy.""" + return object.__getattribute__(self, "_data_proxy") def __getattr__(self, name): """Return an annotating wrapper for _leapp_semantics methods; forward everything else.""" @@ -187,62 +214,54 @@ def __getattr__(self, name): # ══════════════════════════════════════════════════════════════════ -# ObservationPatcher +# ExportPatcher # ══════════════════════════════════════════════════════════════════ -class ObservationPatcher: - """Permanently patches observation term functions to annotate their inputs via proxies. +class ExportPatcher: + """Unified patcher that annotates observation inputs and action outputs for LEAPP export. + + Builds a single set of annotating getters from ``ArticulationData`` and a + shared dedup cache, then wires them into both: - Instead of globally patching ArticulationData properties and toggling them on/off, - this scans for ``_leapp_semantics``-decorated properties once, builds annotating - getters, and routes observation term calls through lightweight proxy objects that - share the same underlying env and tensor state. + - The observation proxy chain (``_EnvProxy`` → ``_SceneProxy`` → + ``_ArticulationProxy`` → ``_ArticulationDataProxy``) for state reads + by observation term functions. + - The ``_ArticulationWriteProxy`` on each action term, which intercepts + target writes **and** routes ``.data`` reads through the same + ``_ArticulationDataProxy`` / cache. + + This ensures that a property like ``joint_pos`` read by both an + observation term and ``RelativeJointPositionAction.apply_actions()`` + resolves to a single LEAPP input edge rather than being silently baked + in as a constant. """ def __init__(self, task_name: str): self.task_name = task_name self._annotated_tensor_cache: dict[str, torch.Tensor] = {} + self._action_output_cache: list[TensorSemantics] = [] + self._pending_action_output_export: bool = False + self._uses_last_action_state: bool = False - def setup(self, obs_manager): - """Patch all observation terms to use annotating proxies. - - For each term in the observation manager: - - Normal terms get their ``env`` argument swapped for a proxy env. - - ``last_action`` and ``generated_commands`` get dedicated wrappers. - - Noise is disabled on every term for deterministic export. - - A thin wrapper on ``compute_group`` clears the dedup cache between calls. - """ - real_env = obs_manager._env + def setup(self, env): + """Patch observation and action managers on the unwrapped env.""" + unwrapped = env.env.unwrapped annotating_getters = self._build_annotating_getters() - scene_proxy = _SceneProxy(real_env.scene, annotating_getters, self._annotated_tensor_cache) - proxy_env = _EnvProxy(real_env, scene_proxy) - - for group_name, term_cfgs in obs_manager._group_obs_term_cfgs.items(): - for term_cfg in term_cfgs: - original_func = term_cfg.func - func_name = getattr(original_func, "__name__", None) - - if func_name == "last_action": - term_cfg.func = self._wrap_last_action(original_func) - elif func_name == "generated_commands": - term_cfg.func = self._wrap_generated_commands(original_func, term_cfg) - else: - term_cfg.func = self._wrap_with_proxy(original_func, proxy_env) - - term_cfg.noise = None - - original_compute_group = obs_manager.compute_group + annotating_write_methods = self._build_annotating_write_methods() cache = self._annotated_tensor_cache - def patched_compute_group(*args, **kwargs): - """Clear the tensor dedup cache, then run the real compute_group.""" - cache.clear() - return original_compute_group(*args, **kwargs) + scene_proxy = _SceneProxy(unwrapped.scene, annotating_getters, cache) + proxy_env = _EnvProxy(unwrapped, scene_proxy) - obs_manager.compute_group = patched_compute_group + self._patch_observation_manager(unwrapped.observation_manager, proxy_env) + self._patch_action_manager( + unwrapped.action_manager, + annotating_getters, + cache, + annotating_write_methods, + ) # ── Scanning ────────────────────────────────────────────────── @@ -281,110 +300,10 @@ def getter(data_self): return getter - # ── Term wrappers ───────────────────────────────────────────── - - @staticmethod - def _wrap_with_proxy(original_func, proxy_env): - """Wrap a term function so it receives the proxy env instead of the real env. - - This is the generic wrapper for observation terms that read ArticulationData - properties. By substituting the env, the entire downstream chain - (env.scene[name].data.property) goes through the proxy. - """ - - def wrapped(env, **kwargs): - return original_func(proxy_env, **kwargs) - - wrapped.__name__ = getattr(original_func, "__name__", "unknown") - return wrapped - - def _wrap_last_action(self, original_func): - """Wrap the ``last_action`` observation term to annotate its output as a LEAPP input.""" - task_name = self.task_name - - def wrapped(env, action_name=None, **kwargs): - result = original_func(env, action_name, **kwargs) - return annotate.input_tensors(task_name, {"last_actions": result}) - - wrapped.__name__ = original_func.__name__ - return wrapped - - def _wrap_generated_commands(self, original_func, term_cfg): - """Wrap the ``generated_commands`` observation term to annotate its output as a LEAPP input. - - Resolves command semantics (kind, element_names) from the command manager - configuration when available. - """ - task_name = self.task_name - command_name_from_cfg = term_cfg.params.get("command_name") - - def wrapped(env, command_name=None, **kwargs): - result = original_func(env, command_name, **kwargs) - leapp_input_name = command_name or command_name_from_cfg or "commands" - command_cfg = None - with suppress(AttributeError, KeyError): - command_cfg = env.command_manager.get_term(leapp_input_name).cfg - sem = TensorSemantics( - name=leapp_input_name, - ref=result, - kind=getattr(command_cfg, "cmd_hint", None), - element_names=getattr(command_cfg, "element_names", None), - ) - return annotate.input_tensors(task_name, sem) - - wrapped.__name__ = original_func.__name__ - return wrapped - - -# ══════════════════════════════════════════════════════════════════ -# ActionPatcher -# ══════════════════════════════════════════════════════════════════ - - -class ActionPatcher: - """Permanently patches action terms to annotate their outputs via proxies. - - 1. Scans Articulation for ``_leapp_semantics``-decorated methods once. - 2. Replaces each action term's ``_asset`` with an ``_ArticulationWriteProxy`` that - intercepts those methods and records output semantics. - 3. Patches ``process_action`` and ``apply_action`` on the action manager instance - to coordinate buffer registration and the single ``annotate.output_tensors`` call. - - No Articulation class methods are ever modified. - """ - - def __init__(self, task_name: str): - self.task_name = task_name - self._action_output_cache: list[TensorSemantics] = [] - self._pending_action_output_export: bool = False - - def setup(self, action_manager): - """Patch all action terms and the action manager for LEAPP annotation. - - For each action term with an Articulation asset, replaces ``term._asset`` - with an ``_ArticulationWriteProxy``. Then patches ``process_action`` and - ``apply_action`` on the manager instance. - """ - annotating_methods = self._build_annotating_write_methods() - - for term_name, term in action_manager._terms.items(): - asset = getattr(term, "_asset", None) - if isinstance(asset, Articulation): - term._asset = _ArticulationWriteProxy( - real_asset=asset, - term_name=term_name, - output_cache=self._action_output_cache, - annotating_methods=annotating_methods, - ) - - self._patch_manager_methods(action_manager) - - # ── Scanning ────────────────────────────────────────────────── - def _build_annotating_write_methods(self) -> dict[str, callable]: """Scan Articulation for ``_leapp_semantics`` methods and build interceptors. - Returns a dict mapping method name to a factory callable. The factory takes + Returns a dict mapping method name to a factory callable. The factory takes ``(real_asset, original_bound_method, term_name, output_cache)`` and returns a callable that the proxy returns in ``__getattr__``. """ @@ -399,7 +318,7 @@ def _make_write_interceptor_factory(self, original_unbound, method_name: str): """Create a factory that produces bound annotating wrappers for a single write method. The factory is called by ``_ArticulationWriteProxy.__getattr__`` each time the - method is accessed. It returns a callable that: + method is accessed. It returns a callable that: 1. Calls the real method on the real asset. 2. Inspects the ``target`` argument. @@ -437,20 +356,77 @@ def interceptor(*args, **kwargs): return factory - # ── Manager patches ─────────────────────────────────────────── + # ── Observation manager patches ─────────────────────────────── + + def _patch_observation_manager(self, obs_manager, proxy_env): + """Patch observation terms to use annotating proxies and disable noise.""" + for group_name, term_cfgs in obs_manager._group_obs_term_cfgs.items(): + for term_cfg in term_cfgs: + original_func = term_cfg.func + func_name = getattr(original_func, "__name__", None) + + if func_name == "last_action": + self._uses_last_action_state = True + term_cfg.func = self._wrap_last_action(original_func) + elif func_name == "generated_commands": + term_cfg.func = self._wrap_generated_commands(original_func, term_cfg) + else: + term_cfg.func = self._wrap_with_proxy(original_func, proxy_env) + + term_cfg.noise = None + + original_compute_group = obs_manager.compute_group + cache = self._annotated_tensor_cache + + def patched_compute_group(*args, **kwargs): + """Clear the tensor dedup cache, then run the real compute_group.""" + cache.clear() + return original_compute_group(*args, **kwargs) + + obs_manager.compute_group = patched_compute_group + + # ── Action manager patches ──────────────────────────────────── - def _patch_manager_methods(self, action_manager): + def _patch_action_manager(self, action_manager, annotating_getters, cache, annotating_write_methods): + """Patch action terms with write+read proxies and patch manager methods.""" + for term_name, term in action_manager._terms.items(): + asset = getattr(term, "_asset", None) + if isinstance(asset, Articulation): + data_proxy = _ArticulationDataProxy(asset.data, annotating_getters, cache) + term._asset = _ArticulationWriteProxy( + real_asset=asset, + term_name=term_name, + output_cache=self._action_output_cache, + annotating_methods=annotating_write_methods, + data_proxy=data_proxy, + ) + + self._patch_action_manager_methods(action_manager) + + def _patch_action_manager_methods(self, action_manager): """Patch ``process_action`` and ``apply_action`` on the action manager instance. ``process_action`` registers raw_action buffers for LEAPP tracing and preserves the action tensor clone. - ``apply_action`` coordinates the output cache lifecycle: clear before, - collect and export after. + ``apply_action`` coordinates the cache and output lifecycle: + + - **Tracing pass** (first ``apply_action`` after ``process_action``): + The cache still holds TracedTensors populated by ``compute_group``. + Action terms that read state (e.g. ``RelativeJointPositionAction`` + reading ``joint_pos``) get those TracedTensors from the cache, + keeping the LEAPP graph connected. After ``output_tensors()`` the + cache is cleared so subsequent decimation sub-steps read fresh values. + + - **Non-tracing passes** (remaining decimation sub-steps and all + subsequent iterations): The cache is cleared **before** running + action terms so every ``.data`` read returns the current simulator + value, preserving simulation correctness. """ original_process = action_manager.process_action original_apply = action_manager.apply_action task_name = self.task_name + cache = self._annotated_tensor_cache def patched_process_action(action: torch.Tensor): """Register raw_action buffers, call real process_action, preserve action clone.""" @@ -463,15 +439,18 @@ def patched_process_action(action: torch.Tensor): self._pending_action_output_export = True def patched_apply_action(): - """Clear cache, call real apply_action, collect outputs, call annotate.output_tensors.""" + """Coordinate cache lifecycle and LEAPP output annotation.""" if not self._pending_action_output_export: + cache.clear() return original_apply() + # Tracing pass: cache still holds TracedTensors from compute_group. self._action_output_cache.clear() original_apply() self._action_output_cache.extend(self._collect_action_outputs(action_manager)) - self._action_output_cache.append(TensorSemantics(name="last_action", ref=action_manager._action)) + if self._uses_last_action_state: + annotate.update_state(task_name, {"last_action": action_manager._action}) static_values = self._collect_action_static_outputs(action_manager) annotate.output_tensors( task_name, @@ -481,11 +460,67 @@ def patched_apply_action(): ) self._pending_action_output_export = False self._action_output_cache.clear() + cache.clear() return None action_manager.process_action = patched_process_action action_manager.apply_action = patched_apply_action + # ── Observation term wrappers ───────────────────────────────── + + @staticmethod + def _wrap_with_proxy(original_func, proxy_env): + """Wrap a term function so it receives the proxy env instead of the real env.""" + + def wrapped(env, **kwargs): + return original_func(proxy_env, **kwargs) + + wrapped.__name__ = getattr(original_func, "__name__", "unknown") + return wrapped + + def _wrap_last_action(self, original_func): + """Wrap ``last_action`` as a LEAPP state tensor. + + ``last_action`` is feedback state, not a regular dangling input. We + therefore register it through ``annotate.state_tensors(...)`` on the + observation side and update it through ``annotate.update_state(...)`` + after the traced action pass. + """ + task_name = self.task_name + + def wrapped(env, action_name=None, **kwargs): + result = original_func(env, action_name, **kwargs) + return annotate.state_tensors(task_name, {"last_action": result}) + + wrapped.__name__ = original_func.__name__ + return wrapped + + def _wrap_generated_commands(self, original_func, term_cfg): + """Wrap the ``generated_commands`` observation term to annotate its output as a LEAPP input. + + Resolves command semantics (kind, element_names) from the command manager + configuration when available. + """ + task_name = self.task_name + command_name_from_cfg = term_cfg.params.get("command_name") + + def wrapped(env, command_name=None, **kwargs): + result = original_func(env, command_name, **kwargs) + leapp_input_name = command_name or command_name_from_cfg or "commands" + command_cfg = None + with suppress(AttributeError, KeyError): + command_cfg = env.command_manager.get_term(leapp_input_name).cfg + sem = TensorSemantics( + name=leapp_input_name, + ref=result, + kind=getattr(command_cfg, "cmd_hint", None), + element_names=getattr(command_cfg, "element_names", None), + ) + return annotate.input_tensors(task_name, sem) + + wrapped.__name__ = original_func.__name__ + return wrapped + # ── Output collection ───────────────────────────────────────── @staticmethod @@ -573,21 +608,21 @@ def _unique_output_name(term_name: str, method_name: str, output_cache: list[Ten def patch_env_for_export(env: ManagerBasedEnv, task_name: str) -> None: """Patch the env's observation and action managers for LEAPP export. - This is a thin public entry point around ``ObservationPatcher`` and - ``ActionPatcher``. It mutates the provided env instance in-place so that: + This is a thin public entry point around ``ExportPatcher``. It mutates + the provided env instance in-place so that: - Observation terms route through proxy objects that annotate ``ArticulationData`` reads. - - Action terms route through proxy objects that annotate - ``Articulation`` write methods. + - Action terms route through proxy objects that annotate both + ``ArticulationData`` reads **and** ``Articulation`` write methods. + + State reads are deduplicated across observation and action paths via a + shared cache, so a property like ``joint_pos`` that is read by both an + observation term and a relative-position action term appears as a single + LEAPP input edge. The underlying env, scene, assets, and tensors remain shared with the rest of the pipeline; only the manager call paths are redirected. """ - unwrapped = env.env.unwrapped - - obs_patcher = ObservationPatcher(task_name) - obs_patcher.setup(unwrapped.observation_manager) - - action_patcher = ActionPatcher(task_name) - action_patcher.setup(unwrapped.action_manager) + patcher = ExportPatcher(task_name) + patcher.setup(env) From e0e97af60da16e3a0d2a8732b73914df760d8e5b Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Fri, 13 Mar 2026 16:51:50 -0700 Subject: [PATCH 20/23] new setup adds kp and kd semantics --- .../rsl_rl/export_annotator.py | 77 +++++++++++++++++-- 1 file changed, 70 insertions(+), 7 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/export_annotator.py b/scripts/reinforcement_learning/rsl_rl/export_annotator.py index 2d09ec123e3e..5a15c798b015 100644 --- a/scripts/reinforcement_learning/rsl_rl/export_annotator.py +++ b/scripts/reinforcement_learning/rsl_rl/export_annotator.py @@ -49,6 +49,15 @@ from isaaclab.envs import ManagerBasedEnv +# Reuse the generic joint-name resolver for kp/kd outputs by providing the +# same ``element_names_source`` contract as articulation getters/writers. +_GAIN_JOINT_SEMANTICS = type( + "GainJointSemantics", + (), + {"element_names": None, "element_names_source": "joint_names"}, +)() + + # ══════════════════════════════════════════════════════════════════ # Shared data proxy # ══════════════════════════════════════════════════════════════════ @@ -392,9 +401,10 @@ def _patch_action_manager(self, action_manager, annotating_getters, cache, annot for term_name, term in action_manager._terms.items(): asset = getattr(term, "_asset", None) if isinstance(asset, Articulation): - data_proxy = _ArticulationDataProxy(asset.data, annotating_getters, cache) + real_asset: Articulation = asset + data_proxy = _ArticulationDataProxy(real_asset.data, annotating_getters, cache) term._asset = _ArticulationWriteProxy( - real_asset=asset, + real_asset=real_asset, term_name=term_name, output_cache=self._action_output_cache, annotating_methods=annotating_write_methods, @@ -530,26 +540,48 @@ def _collect_action_outputs(action_manager) -> list[TensorSemantics]: for term_name, term in action_manager._terms.items(): osc = getattr(term, "_osc", None) if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in ["variable", "variable_kp"]: + asset = getattr(term, "_asset", None) + real_asset = getattr(asset, "_real_asset", asset) + joint_ids = getattr(term, "_joint_ids", None) + joint_name_context = None + if real_asset is not None and hasattr(real_asset, "joint_names"): + joint_name_context = _JointNameContext(real_asset.joint_names, joint_ids) tensors.append( TensorSemantics( name=f"{term_name}_kp_gains", ref=torch.diagonal(osc._motion_p_gains_task, dim1=-2, dim2=-1), kind="kp", + element_names=( + resolve_leapp_element_names( + _GAIN_JOINT_SEMANTICS, + joint_name_context, + ) + if joint_name_context is not None + else None + ), ) ) tensors.append( TensorSemantics( name=f"{term_name}_kd_gains", ref=torch.diagonal(osc._motion_d_gains_task, dim1=-2, dim2=-1), - kind="kp", + kind="kd", + element_names=( + resolve_leapp_element_names( + _GAIN_JOINT_SEMANTICS, + joint_name_context, + ) + if joint_name_context is not None + else None + ), ) ) return tensors @staticmethod - def _collect_action_static_outputs(action_manager) -> dict: + def _collect_action_static_outputs(action_manager) -> list[TensorSemantics]: """Collect static kp/kd gain values from action terms for export metadata.""" - static_values: dict = {} + static_values: list[TensorSemantics] = [] for term_name, term in action_manager._terms.items(): osc = getattr(term, "_osc", None) if osc and hasattr(osc, "cfg") and osc.cfg.impedance_mode in ["variable", "variable_kp"]: @@ -559,12 +591,43 @@ def _collect_action_static_outputs(action_manager) -> dict: if real_asset and hasattr(real_asset, "data"): data = real_asset.data joint_ids = getattr(term, "_joint_ids", None) + joint_name_context = None + if hasattr(real_asset, "joint_names"): + joint_name_context = _JointNameContext(real_asset.joint_names, joint_ids) if hasattr(data, "default_joint_stiffness") and data.default_joint_stiffness is not None: gains = data.default_joint_stiffness - static_values[f"{term_name}_kp_gains"] = gains[:, joint_ids] if joint_ids else gains + static_values.append( + TensorSemantics( + name=f"{term_name}_kp_gains", + ref=gains[:, joint_ids] if joint_ids else gains, + kind="kp", + element_names=( + resolve_leapp_element_names( + _GAIN_JOINT_SEMANTICS, + joint_name_context, + ) + if joint_name_context is not None + else None + ), + ) + ) if hasattr(data, "default_joint_damping") and data.default_joint_damping is not None: gains = data.default_joint_damping - static_values[f"{term_name}_kd_gains"] = gains[:, joint_ids] if joint_ids else gains + static_values.append( + TensorSemantics( + name=f"{term_name}_kd_gains", + ref=gains[:, joint_ids] if joint_ids else gains, + kind="kd", + element_names=( + resolve_leapp_element_names( + _GAIN_JOINT_SEMANTICS, + joint_name_context, + ) + if joint_name_context is not None + else None + ), + ) + ) return static_values From 971e38b0344fb712ec2394e0b9a6037670d5945e Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Fri, 13 Mar 2026 19:28:29 -0700 Subject: [PATCH 21/23] implemented deployment environment. need to validate against more models --- scripts/reinforcement_learning/deploy.py | 75 ++++ source/isaaclab/isaaclab/envs/__init__.py | 1 + .../isaaclab/envs/direct_deployment_env.py | 419 ++++++++++++++++++ 3 files changed, 495 insertions(+) create mode 100644 scripts/reinforcement_learning/deploy.py create mode 100644 source/isaaclab/isaaclab/envs/direct_deployment_env.py diff --git a/scripts/reinforcement_learning/deploy.py b/scripts/reinforcement_learning/deploy.py new file mode 100644 index 000000000000..a669d7d06a52 --- /dev/null +++ b/scripts/reinforcement_learning/deploy.py @@ -0,0 +1,75 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Deploy a LEAPP-exported policy in an Isaac Lab simulation. + +Usage:: + + ./isaaclab.sh -p scripts/reinforcement_learning/deploy.py \ + --task Isaac-Velocity-Flat-Anymal-B-v0 \ + --leapp_model .pretrained_checkpoints/rsl_rl/Isaac-Velocity-Flat-Anymal-B-v0/Isaac-Velocity-Flat-Anymal-B-v0/Isaac-Velocity-Flat-Anymal-B-v0.yaml \ + --headless +""" + +"""Launch Isaac Sim Simulator first.""" + +import argparse +import sys + +from isaaclab.app import AppLauncher + +parser = argparse.ArgumentParser(description="Deploy a LEAPP-exported policy in simulation.") +parser.add_argument("--task", type=str, required=True, help="Name of the registered Isaac Lab task.") +parser.add_argument("--leapp_model", type=str, required=True, help="Path to the LEAPP .yaml pipeline description.") +parser.add_argument("--seed", type=int, default=None, help="Seed for the environment.") +parser.add_argument( + "--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations." +) +AppLauncher.add_app_launcher_args(parser) +args_cli, hydra_args = parser.parse_known_args() + +sys.argv = [sys.argv[0]] + hydra_args + +app_launcher = AppLauncher(args_cli) +simulation_app = app_launcher.app + +"""Rest everything follows.""" + +import torch + +from isaaclab.envs.direct_deployment_env import DirectDeploymentEnv + +import isaaclab_tasks # noqa: F401 +from isaaclab_tasks.utils.parse_cfg import load_cfg_from_registry + + +def main(): + # ── Load env config from gym registry ───────────────────────── + task_name = args_cli.task.split(":")[-1] + env_cfg = load_cfg_from_registry(task_name, "env_cfg_entry_point") + + if args_cli.seed is not None: + env_cfg.seed = args_cli.seed + if args_cli.device is not None: + env_cfg.sim.device = args_cli.device + + # ── Create deploy env ───────────────────────────────────────── + env = DirectDeploymentEnv(env_cfg, args_cli.leapp_model) + + print(f"[INFO]: Deploying task '{task_name}' with LEAPP model: {args_cli.leapp_model}") + print(f"[INFO]: Num envs: {env.num_envs}, decimation: {env.cfg.decimation}, step_dt: {env.step_dt:.4f}s") + + # ── Run loop ────────────────────────────────────────────────── + env.reset() + with torch.inference_mode(): + while simulation_app.is_running(): + env.step() + + env.close() + + +if __name__ == "__main__": + main() + simulation_app.close() diff --git a/source/isaaclab/isaaclab/envs/__init__.py b/source/isaaclab/isaaclab/envs/__init__.py index 543ff2ad4bac..7f0eb41dffce 100644 --- a/source/isaaclab/isaaclab/envs/__init__.py +++ b/source/isaaclab/isaaclab/envs/__init__.py @@ -44,6 +44,7 @@ from . import mdp, ui from .common import VecEnvObs, VecEnvStepReturn, ViewerCfg +from .direct_deployment_env import DirectDeploymentEnv from .direct_marl_env import DirectMARLEnv from .direct_marl_env_cfg import DirectMARLEnvCfg from .direct_rl_env import DirectRLEnv diff --git a/source/isaaclab/isaaclab/envs/direct_deployment_env.py b/source/isaaclab/isaaclab/envs/direct_deployment_env.py new file mode 100644 index 000000000000..f1e2dd63d444 --- /dev/null +++ b/source/isaaclab/isaaclab/envs/direct_deployment_env.py @@ -0,0 +1,419 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Deployment environment that runs LEAPP-exported policies in simulation. + +This environment bypasses all Isaac Lab managers (observation, action, reward, etc.) +and instead wires raw ``ArticulationData`` properties and ``CommandManager`` outputs +directly to a LEAPP ``InferenceManager``, then writes the model outputs back to the +articulation. All I/O resolution is driven by the ``kind`` field in the LEAPP YAML. +""" + +from __future__ import annotations + +import logging +import torch +import yaml +from dataclasses import dataclass +from typing import Any + +from isaaclab.assets.articulation.articulation import Articulation +from isaaclab.assets.articulation.articulation_data import ArticulationData +from isaaclab.managers import CommandManager, EventManager +from isaaclab.scene import InteractiveScene +from isaaclab.sim import SimulationContext +from isaaclab.sim.utils.stage import attach_stage_to_usd_context, use_stage + +logger = logging.getLogger(__name__) + + +# ══════════════════════════════════════════════════════════════════ +# I/O spec dataclasses +# ══════════════════════════════════════════════════════════════════ + + +@dataclass +class StateInputSpec: + """Read a property from ``ArticulationData``, optionally sliced by joint.""" + + property_name: str + joint_ids: list[int] | None = None + + +@dataclass +class CommandInputSpec: + """Read a command tensor from ``CommandManager``.""" + + command_term_name: str + + +@dataclass +class OutputSpec: + """Write a tensor to an ``Articulation`` method, optionally indexed by joint.""" + + method_name: str + joint_ids: list[int] | None = None + + +# ══════════════════════════════════════════════════════════════════ +# Kind → source/target resolution helpers +# ══════════════════════════════════════════════════════════════════ + +_JOINT_LEVEL_KIND_PREFIXES = ("state/joint/", "target/joint/") +_JOINT_LEVEL_GAIN_KINDS = ("kp", "kd") + + +def _build_kind_to_property_map() -> dict[str, list[str]]: + """Scan ``ArticulationData`` for ``_leapp_semantics`` properties. + + Returns a mapping from ``kind`` string to a list of property names that + carry that kind (there can be more than one, e.g. ``root_lin_vel_b`` and + ``root_lin_vel_w`` both have ``state/body/linear_velocity``). + """ + kind_to_props: dict[str, list[str]] = {} + for prop_name in dir(ArticulationData): + prop = getattr(ArticulationData, prop_name, None) + if isinstance(prop, property) and prop.fget and hasattr(prop.fget, "_leapp_semantics"): + kind = prop.fget._leapp_semantics.kind + if kind is not None: + kind_to_props.setdefault(kind, []).append(prop_name) + return kind_to_props + + +def _build_kind_to_write_method_map() -> dict[str, str]: + """Scan ``Articulation`` for ``_leapp_semantics`` methods + hardcoded kp/kd. + + Returns a mapping from output ``kind`` to the method name on ``Articulation``. + """ + kind_to_method: dict[str, str] = {} + for method_name in dir(Articulation): + method = getattr(Articulation, method_name, None) + if callable(method) and hasattr(method, "_leapp_semantics"): + kind = method._leapp_semantics.kind + if kind is not None: + kind_to_method[kind] = method_name + kind_to_method["kp"] = "write_joint_stiffness_to_sim" + kind_to_method["kd"] = "write_joint_damping_to_sim" + return kind_to_method + + +def _disambiguate_property(kind: str, leapp_name: str, kind_to_props: dict[str, list[str]]) -> str: + """Pick the right ``ArticulationData`` property when multiple share a ``kind``. + + The export path uses the property name as the LEAPP input name, so we strip + the ``_in`` / ``_out`` suffix that LEAPP adds for collision avoidance and match. + """ + candidates = kind_to_props.get(kind) + if candidates is None: + raise ValueError(f"No ArticulationData property found for kind='{kind}'") + if len(candidates) == 1: + return candidates[0] + base_name = leapp_name.removesuffix("_in").removesuffix("_out") + for prop in candidates: + if prop == base_name: + return prop + return candidates[0] + + +def _resolve_joint_ids(element_names: list | None, asset: Articulation) -> list[int] | None: + """Convert ``element_names[0]`` joint names to integer joint indices. + + Returns ``None`` when no slicing is needed (all joints or non-joint tensor). + """ + if element_names is None: + return None + joint_names = element_names[0] + if not isinstance(joint_names, list) or not joint_names: + return None + if joint_names == list(asset.joint_names): + return None + joint_ids, _ = asset.find_joints(joint_names, preserve_order=True) + return joint_ids + + +def _find_command_term_by_hint(kind: str, command_manager: CommandManager) -> str: + """Find the ``CommandTerm`` name whose ``cfg.cmd_hint`` matches ``kind``.""" + for name, term in command_manager._terms.items(): + if getattr(term.cfg, "cmd_hint", None) == kind: + return name + raise ValueError(f"No command term with cmd_hint='{kind}'. Available terms: {list(command_manager._terms.keys())}") + + +def _find_robot_asset(scene: InteractiveScene) -> Articulation: + """Return the first ``Articulation`` in the scene (assumed to be the robot).""" + for entity_name in scene.articulations: + entity = scene[entity_name] + if isinstance(entity, Articulation): + return entity + raise RuntimeError("No Articulation found in scene") + + +# ══════════════════════════════════════════════════════════════════ +# DirectDeploymentEnv +# ══════════════════════════════════════════════════════════════════ + + +class DirectDeploymentEnv: + """Runs a LEAPP-exported policy in an Isaac Lab scene. + + The environment sets up the simulation scene and physics from a standard + Isaac Lab config, then wires raw sensor/command data to a LEAPP + ``InferenceManager`` and writes the model outputs back to the articulation. + + No observation, action, reward, termination, or curriculum managers are used. + The LEAPP model already contains all pre/post-processing. + """ + + def __init__(self, cfg: Any, leapp_yaml_path: str): + """Initialize the deployment environment. + + Args: + cfg: A ``ManagerBasedRLEnvCfg`` (or compatible) task config. + leapp_yaml_path: Path to the LEAPP ``.yaml`` pipeline description. + """ + from leapp import InferenceManager + + cfg.scene.num_envs = 1 + cfg.validate() + self.cfg = cfg + self._is_closed = False + self._leapp_yaml_path = leapp_yaml_path + self._step_count = 0 + + # ── Simulation + scene ──────────────────────────────────── + self.sim = SimulationContext(cfg.sim) + if "cuda" in self.sim.device: + torch.cuda.set_device(self.sim.device) + + with use_stage(self.sim.get_initial_stage()): + self.scene = InteractiveScene(cfg.scene) + attach_stage_to_usd_context() + self.sim.reset() + self.scene.update(dt=self.physics_dt) + + # ── Robot asset ─────────────────────────────────────────── + self._asset = _find_robot_asset(self.scene) + + # ── EventManager (optional, for resets) ─────────────────── + self.event_manager: EventManager | None = None + if hasattr(cfg, "events") and cfg.events is not None: + self.event_manager = EventManager(cfg.events, self) + + # ── CommandManager (optional, for command/* inputs) ─────── + self.command_manager: CommandManager | None = None + if hasattr(cfg, "commands") and cfg.commands is not None: + self.command_manager = CommandManager(cfg.commands, self) + + # ── LEAPP InferenceManager ──────────────────────────────── + self.inference = InferenceManager(leapp_yaml_path) + + # ── Parse YAML and resolve I/O mappings ─────────────────── + with open(leapp_yaml_path) as f: + self._leapp_desc = yaml.safe_load(f) + self._input_mapping: dict[str, StateInputSpec | CommandInputSpec] = {} + self._output_mapping: dict[str, OutputSpec] = {} + self._resolve_io() + + # ── Cache feedback initial values for cheap reset ───────── + self._feedback_initial_values: dict[str, torch.Tensor] = {} + self._cache_feedback_initial_values() + + logger.info( + "DirectDeploymentEnv ready — %d inputs, %d outputs mapped", + len(self._input_mapping), + len(self._output_mapping), + ) + + # ── Properties ──────────────────────────────────────────────── + + @property + def num_envs(self) -> int: + return 1 + + @property + def physics_dt(self) -> float: + return self.cfg.sim.dt + + @property + def step_dt(self) -> float: + return self.cfg.sim.dt * self.cfg.decimation + + @property + def device(self) -> str: + return self.sim.device + + # ── I/O Resolution ──────────────────────────────────────────── + + def _resolve_io(self): + """Build ``_input_mapping`` and ``_output_mapping`` from LEAPP YAML ``kind`` fields.""" + kind_to_props = _build_kind_to_property_map() + kind_to_write = _build_kind_to_write_method_map() + pipeline = self._leapp_desc["pipeline"] + + # --- Inputs --- + for node_name, input_names in pipeline["inputs"].items(): + node = self.inference.nodes[node_name] + desc_by_name = {d["name"]: d for d in node.input_descriptions} + for input_name in input_names: + desc = desc_by_name[input_name] + kind = desc.get("kind") + key = f"{node_name}/{input_name}" + if kind is None: + continue + if kind.startswith("state/"): + prop = _disambiguate_property(kind, input_name, kind_to_props) + needs_joint_slice = kind.startswith("state/joint/") + jids = _resolve_joint_ids(desc.get("element_names"), self._asset) if needs_joint_slice else None + self._input_mapping[key] = StateInputSpec(property_name=prop, joint_ids=jids) + elif kind.startswith("command/"): + if self.command_manager is None: + raise RuntimeError( + f"LEAPP input '{key}' has kind='{kind}' but no CommandManager " + "is available (cfg.commands is None)." + ) + term_name = _find_command_term_by_hint(kind, self.command_manager) + self._input_mapping[key] = CommandInputSpec(command_term_name=term_name) + else: + logger.warning("Unknown input kind '%s' for '%s' — skipping", kind, key) + + # --- Outputs --- + for node_name, output_names in pipeline["outputs"].items(): + node = self.inference.nodes[node_name] + desc_by_name = {d["name"]: d for d in node.output_descriptions} + for output_name in output_names: + desc = desc_by_name[output_name] + kind = desc.get("kind") + key = f"{node_name}/{output_name}" + if kind is None: + continue + if kind not in kind_to_write: + logger.warning("Unknown output kind '%s' for '%s' — skipping", kind, key) + continue + method_name = kind_to_write[kind] + needs_joint_ids = kind.startswith("target/joint/") or kind in _JOINT_LEVEL_GAIN_KINDS + jids = _resolve_joint_ids(desc.get("element_names"), self._asset) if needs_joint_ids else None + self._output_mapping[key] = OutputSpec(method_name=method_name, joint_ids=jids) + + # ── Feedback state caching ──────────────────────────────────── + + def _cache_feedback_initial_values(self): + """Snapshot the feedback input buffers right after InferenceManager init. + + This allows ``reset()`` to cheaply restore feedback state without + re-instantiating the entire InferenceManager. + """ + for fb_key in self.inference.feedback_inputs: + node_name, input_name = fb_key.split("/") + tensor = self.inference.value_dict[node_name][input_name] + self._feedback_initial_values[fb_key] = tensor.clone() + + def _restore_feedback_initial_values(self): + """Restore feedback input buffers to their initial values.""" + for fb_key, initial in self._feedback_initial_values.items(): + node_name, input_name = fb_key.split("/") + self.inference.value_dict[node_name][input_name] = initial.clone() + + # ── Read / Write ────────────────────────────────────────────── + + def _read_inputs(self) -> dict[str, torch.Tensor]: + """Read all mapped inputs from the scene and command manager.""" + inputs: dict[str, torch.Tensor] = {} + for key, spec in self._input_mapping.items(): + if isinstance(spec, StateInputSpec): + value = getattr(self._asset.data, spec.property_name) + if spec.joint_ids is not None: + value = value[:, spec.joint_ids] + inputs[key] = value + elif isinstance(spec, CommandInputSpec): + inputs[key] = self.command_manager.get_command(spec.command_term_name) + return inputs + + def _write_outputs(self, outputs: dict[str, torch.Tensor]): + """Write model outputs to the articulation.""" + for key, tensor in outputs.items(): + spec = self._output_mapping.get(key) + if spec is None: + continue + method = getattr(self._asset, spec.method_name) + if spec.joint_ids is not None: + method(tensor, joint_ids=spec.joint_ids) + else: + method(tensor) + + # ── Public API ──────────────────────────────────────────────── + + def reset(self) -> dict[str, torch.Tensor]: + """Reset the scene and feedback state. + + Returns: + The initial input tensors (for logging / debugging). + """ + env_ids = torch.tensor([0], device=self.device, dtype=torch.long) + + self.scene.reset(env_ids) + + if self.event_manager is not None and "reset" in self.event_manager.available_modes: + self.event_manager.apply(mode="reset", env_ids=env_ids, global_env_step_count=self._step_count) + if self.command_manager is not None: + self.command_manager.reset(env_ids) + + self.scene.write_data_to_sim() + self.sim.forward() + self.scene.update(dt=self.physics_dt) + + self._restore_feedback_initial_values() + + return self._read_inputs() + + def step(self, external_inputs: dict[str, torch.Tensor] | None = None) -> dict[str, torch.Tensor]: + """Run one environment step: read → infer → write → physics. + + Args: + external_inputs: Optional overrides keyed by ``"ModelName/input_name"``. + Takes precedence over auto-resolved state/command values. + + Returns: + The dict of pipeline outputs from ``InferenceManager.run_policy()``. + """ + self._step_count += 1 + + # 1. Update commands + if self.command_manager is not None: + self.command_manager.compute(dt=self.step_dt) + + # 2. Read inputs + inputs = self._read_inputs() + + # 3. Merge external overrides + if external_inputs is not None: + inputs.update(external_inputs) + + # 4. Infer + with torch.inference_mode(): + outputs = self.inference.run_policy(inputs) + + # 5. Write outputs to asset + self._write_outputs(outputs) + + # 6. Decimation loop + is_rendering = self.sim.has_gui() or self.sim.has_rtx_sensors() + for _ in range(self.cfg.decimation): + self.scene.write_data_to_sim() + self.sim.step(render=False) + if is_rendering: + self.sim.render() + self.scene.update(dt=self.physics_dt) + + return outputs + + def close(self): + """Clean up the environment.""" + if not self._is_closed: + if self.command_manager is not None: + del self.command_manager + if self.event_manager is not None: + del self.event_manager + del self.scene + self._is_closed = True From b6f5a612a0450952b5de9cfa7dfc372f0a12631f Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Mon, 16 Mar 2026 17:22:07 -0700 Subject: [PATCH 22/23] added a test for export. expanded annotations to include sensor primitives --- scripts/reinforcement_learning/deploy.py | 3 - .../reinforcement_learning/rsl_rl/export.py | 8 +- .../rsl_rl/export_annotator.py | 259 ++++++++++++++---- .../assets/articulation/articulation_data.py | 1 + .../isaaclab/envs/direct_deployment_env.py | 30 +- .../sensors/contact_sensor/contact_sensor.py | 95 +++---- .../contact_sensor/contact_sensor_data.py | 99 ++++++- .../frame_transformer/frame_transformer.py | 26 +- .../frame_transformer_data.py | 70 ++++- source/isaaclab/isaaclab/sensors/imu/imu.py | 58 ++-- .../isaaclab/isaaclab/sensors/imu/imu_data.py | 58 +++- .../multi_mesh_ray_caster_camera.py | 4 +- .../isaaclab/sensors/ray_caster/ray_caster.py | 14 +- .../sensors/ray_caster/ray_caster_data.py | 26 +- .../isaaclab/utils/leapp_semantics.py | 33 ++- .../test/test_rsl_rl_export_flow.py | 143 ++++++++++ 16 files changed, 710 insertions(+), 217 deletions(-) create mode 100644 source/isaaclab_rl/test/test_rsl_rl_export_flow.py diff --git a/scripts/reinforcement_learning/deploy.py b/scripts/reinforcement_learning/deploy.py index a669d7d06a52..cfa05bd549c8 100644 --- a/scripts/reinforcement_learning/deploy.py +++ b/scripts/reinforcement_learning/deploy.py @@ -24,9 +24,6 @@ parser.add_argument("--task", type=str, required=True, help="Name of the registered Isaac Lab task.") parser.add_argument("--leapp_model", type=str, required=True, help="Path to the LEAPP .yaml pipeline description.") parser.add_argument("--seed", type=int, default=None, help="Seed for the environment.") -parser.add_argument( - "--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations." -) AppLauncher.add_app_launcher_args(parser) args_cli, hydra_args = parser.parse_known_args() diff --git a/scripts/reinforcement_learning/rsl_rl/export.py b/scripts/reinforcement_learning/rsl_rl/export.py index 54040d882053..094d14489b7b 100644 --- a/scripts/reinforcement_learning/rsl_rl/export.py +++ b/scripts/reinforcement_learning/rsl_rl/export.py @@ -36,6 +36,12 @@ help="Use the pre-trained checkpoint from Nucleus.", ) parser.add_argument("--real-time", action="store_true", default=False, help="Run in real-time, if possible.") +parser.add_argument( + "--disable_graph_visualization", + action="store_true", + default=False, + help="Disable LEAPP graph visualization during compile_graph().", +) # append RSL-RL cli arguments cli_args.add_rsl_rl_args(parser) # append AppLauncher cli args @@ -156,7 +162,7 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen obs, _, _, _ = env.step(actions) leapp.stop() - leapp.compile_graph() + leapp.compile_graph(visualize=not args_cli.disable_graph_visualization) # close the simulator env.close() diff --git a/scripts/reinforcement_learning/rsl_rl/export_annotator.py b/scripts/reinforcement_learning/rsl_rl/export_annotator.py index 5a15c798b015..dff93b27c2b4 100644 --- a/scripts/reinforcement_learning/rsl_rl/export_annotator.py +++ b/scripts/reinforcement_learning/rsl_rl/export_annotator.py @@ -36,13 +36,21 @@ import inspect import torch from contextlib import suppress -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from leapp import annotate from leapp.utils.tensor_description import TensorSemantics from isaaclab.assets.articulation.articulation import Articulation from isaaclab.assets.articulation.articulation_data import ArticulationData +from isaaclab.sensors.camera.camera_data import CameraData +from isaaclab.sensors.contact_sensor.contact_sensor_data import ContactSensorData +from isaaclab.sensors.frame_transformer.frame_transformer_data import FrameTransformerData +from isaaclab.sensors.imu.imu_data import ImuData +from isaaclab.sensors.ray_caster.multi_mesh_ray_caster_camera_data import MultiMeshRayCasterCameraData +from isaaclab.sensors.ray_caster.multi_mesh_ray_caster_data import MultiMeshRayCasterData +from isaaclab.sensors.ray_caster.ray_caster_data import RayCasterData +from isaaclab.sensors.tacsl_sensor.visuotactile_sensor_data import VisuoTactileSensorData from isaaclab.utils.leapp_semantics import resolve_leapp_element_names if TYPE_CHECKING: @@ -57,40 +65,81 @@ {"element_names": None, "element_names_source": "joint_names"}, )() +_ANNOTATED_DATA_CLASSES = ( + ArticulationData, + CameraData, + ContactSensorData, + FrameTransformerData, + ImuData, + MultiMeshRayCasterCameraData, + MultiMeshRayCasterData, + RayCasterData, + VisuoTactileSensorData, +) + # ══════════════════════════════════════════════════════════════════ # Shared data proxy # ══════════════════════════════════════════════════════════════════ -class _ArticulationDataProxy: - """Proxy around a real ArticulationData that intercepts annotated property reads. +def _lookup_annotating_getter( + annotating_getters_by_type: dict[type, dict[str, callable]], real_data: Any, name: str +) -> callable | None: + """Return the annotating getter for a property on the given data object, if any.""" + for data_cls in type(real_data).__mro__: + getter = annotating_getters_by_type.get(data_cls, {}).get(name) + if getter is not None: + return getter + return None + + +def _has_annotated_getters(annotating_getters_by_type: dict[type, dict[str, callable]], real_data: Any) -> bool: + """Return True when the data object's class hierarchy exposes any annotated getters.""" + return any(annotating_getters_by_type.get(data_cls) for data_cls in type(real_data).__mro__) + + +class _DataProxy: + """Proxy around a real data object that intercepts annotated property reads. - For properties whose getter carries ``_leapp_semantics``, the proxy calls - the annotating getter (which records the tensor with LEAPP) and caches the - result for deduplication. Consumers within the same annotation pass - (observation terms **and** action terms) receive the same TracedTensor. + The real data object may be an ``ArticulationData`` instance or any sensor + data class whose properties carry ``_leapp_semantics``. The proxy calls the + annotating getter (which records the tensor with LEAPP) and caches the + result for deduplication. Consumers within the same annotation pass + (observation terms **and** action terms) receive the same TracedTensor for + repeated reads from the same underlying data object. All other attribute access is forwarded transparently to the real object. """ - def __init__(self, real_data: ArticulationData, annotating_getters: dict[str, callable], cache: dict): + def __init__( + self, + real_data: Any, + annotating_getters_by_type: dict[type, dict[str, callable]], + cache: dict, + input_name_resolver: callable, + ): object.__setattr__(self, "_real_data", real_data) - object.__setattr__(self, "_annotating_getters", annotating_getters) + object.__setattr__(self, "_annotating_getters_by_type", annotating_getters_by_type) object.__setattr__(self, "_cache", cache) + object.__setattr__(self, "_input_name_resolver", input_name_resolver) def __getattr__(self, name): """Intercept annotated properties; forward everything else.""" - getters = object.__getattribute__(self, "_annotating_getters") - if name in getters: + real_data = object.__getattribute__(self, "_real_data") + getter = _lookup_annotating_getter( + object.__getattribute__(self, "_annotating_getters_by_type"), real_data, name + ) + if getter is not None: cache = object.__getattribute__(self, "_cache") - if name in cache: - return cache[name].clone() - real_data = object.__getattribute__(self, "_real_data") - result = getters[name](real_data) - cache[name] = result + cache_key = (id(real_data), name) + if cache_key in cache: + return cache[cache_key].clone() + input_name = object.__getattribute__(self, "_input_name_resolver")(name) + result = getter(real_data, input_name) + cache[cache_key] = result return result - return getattr(object.__getattribute__(self, "_real_data"), name) + return getattr(real_data, name) # ══════════════════════════════════════════════════════════════════ @@ -98,55 +147,140 @@ def __getattr__(self, name): # ══════════════════════════════════════════════════════════════════ -class _ArticulationProxy: - """Proxy around a real Articulation that returns _ArticulationDataProxy for ``.data``. +class _EntityProxy: + """Proxy around a real scene entity that returns a ``_DataProxy`` for ``.data``. All other attribute access is forwarded transparently to the real asset. """ - def __init__(self, real_asset: Articulation, data_proxy: _ArticulationDataProxy): - object.__setattr__(self, "_real_asset", real_asset) + def __init__(self, real_entity: Any, data_proxy: _DataProxy): + object.__setattr__(self, "_real_entity", real_entity) object.__setattr__(self, "_data_proxy", data_proxy) @property def data(self): - """Return the annotating data proxy instead of the real ArticulationData.""" + """Return the annotating data proxy instead of the real data object.""" return object.__getattribute__(self, "_data_proxy") def __getattr__(self, name): - """Forward all non-data attribute access to the real asset.""" - return getattr(object.__getattribute__(self, "_real_asset"), name) + """Forward all non-data attribute access to the real scene entity.""" + return getattr(object.__getattribute__(self, "_real_entity"), name) + + +class _EntityMappingProxy: + """Proxy around a mapping of scene entities that lazily wraps data-producing entries.""" + + def __init__(self, real_mapping, annotating_getters_by_type: dict[type, dict[str, callable]], cache: dict): + object.__setattr__(self, "_real_mapping", real_mapping) + object.__setattr__(self, "_annotating_getters_by_type", annotating_getters_by_type) + object.__setattr__(self, "_cache", cache) + object.__setattr__(self, "_proxied", {}) + + def __getitem__(self, key): + """Return a proxied entity when it exposes annotated data properties.""" + proxied = object.__getattribute__(self, "_proxied") + if key in proxied: + return proxied[key] + real_mapping = object.__getattribute__(self, "_real_mapping") + entity = real_mapping[key] + data = getattr(entity, "data", None) + if data is None: + return entity + annotating_getters_by_type = object.__getattribute__(self, "_annotating_getters_by_type") + if not _has_annotated_getters(annotating_getters_by_type, data): + return entity + data_proxy = _DataProxy( + data, + annotating_getters_by_type, + object.__getattribute__(self, "_cache"), + input_name_resolver=lambda prop_name: f"{key}_{prop_name}", + ) + proxy = _EntityProxy(entity, data_proxy) + proxied[key] = proxy + return proxy + + def get(self, key, default=None): + """Return a proxied entity when present, default otherwise.""" + real_mapping = object.__getattribute__(self, "_real_mapping") + if key not in real_mapping: + return default + return self[key] + + def __iter__(self): + return iter(object.__getattribute__(self, "_real_mapping")) + + def __len__(self): + return len(object.__getattribute__(self, "_real_mapping")) + + def __getattr__(self, name): + """Forward all other mapping access to the real mapping.""" + return getattr(object.__getattribute__(self, "_real_mapping"), name) class _SceneProxy: """Proxy around the real InteractiveScene. - When an observation term looks up an asset by name, this proxy lazily wraps - Articulation entities in _ArticulationProxy so their data getters annotate. - Non-Articulation entities are returned as-is. + When an observation term looks up a scene entity by name, this proxy lazily + wraps entities whose ``.data`` object exposes ``_leapp_semantics``-decorated + properties. This covers articulations and sensors through both + ``scene["name"]`` and ``scene.sensors["name"]`` access paths. """ - def __init__(self, real_scene, annotating_getters: dict[str, callable], cache: dict): + def __init__(self, real_scene, annotating_getters_by_type: dict[type, dict[str, callable]], cache: dict): object.__setattr__(self, "_real_scene", real_scene) - object.__setattr__(self, "_annotating_getters", annotating_getters) + object.__setattr__(self, "_annotating_getters_by_type", annotating_getters_by_type) object.__setattr__(self, "_cache", cache) object.__setattr__(self, "_proxied", {}) + object.__setattr__(self, "_sensor_mapping_proxy", None) - def __getitem__(self, key): - """Return an ArticulationProxy for Articulation entities, real entity otherwise.""" + def _maybe_proxy_entity(self, key: str, entity: Any): + """Return a proxy for entities whose data object has annotated getters.""" proxied = object.__getattribute__(self, "_proxied") if key in proxied: return proxied[key] + + data = getattr(entity, "data", None) + if data is None: + return entity + + annotating_getters_by_type = object.__getattribute__(self, "_annotating_getters_by_type") + if not _has_annotated_getters(annotating_getters_by_type, data): + return entity + + cache = object.__getattribute__(self, "_cache") + data_proxy = _DataProxy( + data, + annotating_getters_by_type, + cache, + input_name_resolver=( + (lambda prop_name: f"ego_{prop_name}") + if isinstance(entity, Articulation) + else (lambda prop_name: f"{key}_{prop_name}") + ), + ) + proxy = _EntityProxy(entity, data_proxy) + proxied[key] = proxy + return proxy + + def __getitem__(self, key): + """Return a proxied entity when it exposes annotated data getters.""" real_scene = object.__getattribute__(self, "_real_scene") entity = real_scene[key] - if isinstance(entity, Articulation): - getters = object.__getattribute__(self, "_annotating_getters") - cache = object.__getattribute__(self, "_cache") - data_proxy = _ArticulationDataProxy(entity.data, getters, cache) - proxy = _ArticulationProxy(entity, data_proxy) - proxied[key] = proxy - return proxy - return entity + return self._maybe_proxy_entity(key, entity) + + @property + def sensors(self): + """Return a mapping proxy for scene sensors.""" + sensor_mapping_proxy = object.__getattribute__(self, "_sensor_mapping_proxy") + if sensor_mapping_proxy is None: + real_scene = object.__getattribute__(self, "_real_scene") + sensor_mapping_proxy = _EntityMappingProxy( + real_scene.sensors, + object.__getattribute__(self, "_annotating_getters_by_type"), + object.__getattribute__(self, "_cache"), + ) + object.__setattr__(self, "_sensor_mapping_proxy", sensor_mapping_proxy) + return sensor_mapping_proxy def __getattr__(self, name): """Forward all other scene access to the real scene.""" @@ -183,7 +317,7 @@ class _ArticulationWriteProxy: """Proxy around a real Articulation for action terms. Intercepts ``_leapp_semantics``-decorated write methods **and** routes - ``.data`` reads through a shared ``_ArticulationDataProxy`` so that + ``.data`` reads through a shared ``_DataProxy`` so that action-side state reads (e.g. ``self._asset.data.joint_pos`` inside ``RelativeJointPositionAction``) participate in LEAPP annotation and share the dedup cache with observation-side reads. @@ -197,7 +331,7 @@ def __init__( term_name: str, output_cache: list[TensorSemantics], annotating_methods: dict[str, callable], - data_proxy: _ArticulationDataProxy, + data_proxy: _DataProxy, ): object.__setattr__(self, "_real_asset", real_asset) object.__setattr__(self, "_term_name", term_name) @@ -234,11 +368,11 @@ class ExportPatcher: shared dedup cache, then wires them into both: - The observation proxy chain (``_EnvProxy`` → ``_SceneProxy`` → - ``_ArticulationProxy`` → ``_ArticulationDataProxy``) for state reads + ``_EntityProxy`` → ``_DataProxy``) for state reads by observation term functions. - The ``_ArticulationWriteProxy`` on each action term, which intercepts target writes **and** routes ``.data`` reads through the same - ``_ArticulationDataProxy`` / cache. + ``_DataProxy`` / cache. This ensures that a property like ``joint_pos`` read by both an observation term and ``RelativeJointPositionAction.apply_actions()`` @@ -248,7 +382,7 @@ class ExportPatcher: def __init__(self, task_name: str): self.task_name = task_name - self._annotated_tensor_cache: dict[str, torch.Tensor] = {} + self._annotated_tensor_cache: dict[tuple[int, str], torch.Tensor] = {} self._action_output_cache: list[TensorSemantics] = [] self._pending_action_output_export: bool = False self._uses_last_action_state: bool = False @@ -274,33 +408,39 @@ def setup(self, env): # ── Scanning ────────────────────────────────────────────────── - def _build_annotating_getters(self) -> dict[str, callable]: - """Scan ArticulationData for ``_leapp_semantics`` properties and build annotating getters. + def _build_annotating_getters(self) -> dict[type, dict[str, callable]]: + """Scan articulation and sensor data classes for ``_leapp_semantics`` properties. - Returns a dict mapping property name to a callable ``(data_self) -> annotated_tensor``. + Returns a dict mapping data class type to a dict of + ``property_name -> callable(data_self, input_name) -> annotated_tensor``. """ - getters: dict[str, callable] = {} - for prop_name in dir(ArticulationData): - prop = getattr(ArticulationData, prop_name, None) - if isinstance(prop, property) and prop.fget and hasattr(prop.fget, "_leapp_semantics"): - getters[prop_name] = self._make_annotating_getter(prop.fget, prop_name) + getters: dict[type, dict[str, callable]] = {} + for data_cls in _ANNOTATED_DATA_CLASSES: + class_getters: dict[str, callable] = {} + for prop_name in dir(data_cls): + prop = getattr(data_cls, prop_name, None) + if isinstance(prop, property) and prop.fget and hasattr(prop.fget, "_leapp_semantics"): + class_getters[prop_name] = self._make_annotating_getter(prop.fget, prop_name) + if class_getters: + getters[data_cls] = class_getters return getters def _make_annotating_getter(self, original_fget, prop_name: str): - """Create an annotating getter callable for a single ArticulationData property. + """Create an annotating getter callable for a single annotated data property. The returned callable invokes the real getter, then registers the result - as a LEAPP input tensor with the property's semantic metadata. + as a LEAPP input tensor with the property's semantic metadata and the + caller-supplied public input name. """ task_name = self.task_name - def getter(data_self): + def getter(data_self, input_name: str): result = original_fget(data_self) if not isinstance(result, torch.Tensor): return result semantics_meta = getattr(original_fget, "_leapp_semantics", None) sem = TensorSemantics( - name=prop_name, + name=input_name, ref=result, kind=semantics_meta.kind if semantics_meta else None, element_names=resolve_leapp_element_names(semantics_meta, data_self), @@ -402,7 +542,12 @@ def _patch_action_manager(self, action_manager, annotating_getters, cache, annot asset = getattr(term, "_asset", None) if isinstance(asset, Articulation): real_asset: Articulation = asset - data_proxy = _ArticulationDataProxy(real_asset.data, annotating_getters, cache) + data_proxy = _DataProxy( + real_asset.data, + annotating_getters, + cache, + input_name_resolver=lambda prop_name: f"ego_{prop_name}", + ) term._asset = _ArticulationWriteProxy( real_asset=real_asset, term_name=term_name, diff --git a/source/isaaclab/isaaclab/assets/articulation/articulation_data.py b/source/isaaclab/isaaclab/assets/articulation/articulation_data.py index a627e91e9e3f..2f4a88546efc 100644 --- a/source/isaaclab/isaaclab/assets/articulation/articulation_data.py +++ b/source/isaaclab/isaaclab/assets/articulation/articulation_data.py @@ -721,6 +721,7 @@ def body_com_pose_b(self) -> torch.Tensor: return self._body_com_pose_b.data @property + @leapp_tensor_semantics(kind="state/body/incoming_joint_wrench", element_names_source="body_wrench") def body_incoming_joint_wrench_b(self) -> torch.Tensor: """Joint reaction wrench applied from body parent to child body in parent body frame. diff --git a/source/isaaclab/isaaclab/envs/direct_deployment_env.py b/source/isaaclab/isaaclab/envs/direct_deployment_env.py index f1e2dd63d444..7841282432b3 100644 --- a/source/isaaclab/isaaclab/envs/direct_deployment_env.py +++ b/source/isaaclab/isaaclab/envs/direct_deployment_env.py @@ -19,6 +19,8 @@ from dataclasses import dataclass from typing import Any +from leapp import InferenceManager + from isaaclab.assets.articulation.articulation import Articulation from isaaclab.assets.articulation.articulation_data import ArticulationData from isaaclab.managers import CommandManager, EventManager @@ -173,7 +175,6 @@ def __init__(self, cfg: Any, leapp_yaml_path: str): cfg: A ``ManagerBasedRLEnvCfg`` (or compatible) task config. leapp_yaml_path: Path to the LEAPP ``.yaml`` pipeline description. """ - from leapp import InferenceManager cfg.scene.num_envs = 1 cfg.validate() @@ -216,10 +217,6 @@ def __init__(self, cfg: Any, leapp_yaml_path: str): self._output_mapping: dict[str, OutputSpec] = {} self._resolve_io() - # ── Cache feedback initial values for cheap reset ───────── - self._feedback_initial_values: dict[str, torch.Tensor] = {} - self._cache_feedback_initial_values() - logger.info( "DirectDeploymentEnv ready — %d inputs, %d outputs mapped", len(self._input_mapping), @@ -296,25 +293,6 @@ def _resolve_io(self): jids = _resolve_joint_ids(desc.get("element_names"), self._asset) if needs_joint_ids else None self._output_mapping[key] = OutputSpec(method_name=method_name, joint_ids=jids) - # ── Feedback state caching ──────────────────────────────────── - - def _cache_feedback_initial_values(self): - """Snapshot the feedback input buffers right after InferenceManager init. - - This allows ``reset()`` to cheaply restore feedback state without - re-instantiating the entire InferenceManager. - """ - for fb_key in self.inference.feedback_inputs: - node_name, input_name = fb_key.split("/") - tensor = self.inference.value_dict[node_name][input_name] - self._feedback_initial_values[fb_key] = tensor.clone() - - def _restore_feedback_initial_values(self): - """Restore feedback input buffers to their initial values.""" - for fb_key, initial in self._feedback_initial_values.items(): - node_name, input_name = fb_key.split("/") - self.inference.value_dict[node_name][input_name] = initial.clone() - # ── Read / Write ────────────────────────────────────────────── def _read_inputs(self) -> dict[str, torch.Tensor]: @@ -345,7 +323,7 @@ def _write_outputs(self, outputs: dict[str, torch.Tensor]): # ── Public API ──────────────────────────────────────────────── def reset(self) -> dict[str, torch.Tensor]: - """Reset the scene and feedback state. + """Reset the scene and inference state. Returns: The initial input tensors (for logging / debugging). @@ -363,7 +341,7 @@ def reset(self) -> dict[str, torch.Tensor]: self.sim.forward() self.scene.update(dt=self.physics_dt) - self._restore_feedback_initial_values() + self.inference.reset() return self._read_inputs() diff --git a/source/isaaclab/isaaclab/sensors/contact_sensor/contact_sensor.py b/source/isaaclab/isaaclab/sensors/contact_sensor/contact_sensor.py index 3a3f4d5c2e9b..b5ed1f9b3488 100644 --- a/source/isaaclab/isaaclab/sensors/contact_sensor/contact_sensor.py +++ b/source/isaaclab/isaaclab/sensors/contact_sensor/contact_sensor.py @@ -147,24 +147,24 @@ def reset(self, env_ids: Sequence[int] | None = None): if env_ids is None: env_ids = slice(None) # reset accumulative data buffers - self._data.net_forces_w[env_ids] = 0.0 - self._data.net_forces_w_history[env_ids] = 0.0 + self._data._net_forces_w[env_ids] = 0.0 + self._data._net_forces_w_history[env_ids] = 0.0 # reset force matrix if len(self.cfg.filter_prim_paths_expr) != 0: - self._data.force_matrix_w[env_ids] = 0.0 - self._data.force_matrix_w_history[env_ids] = 0.0 + self._data._force_matrix_w[env_ids] = 0.0 + self._data._force_matrix_w_history[env_ids] = 0.0 # reset the current air time if self.cfg.track_air_time: - self._data.current_air_time[env_ids] = 0.0 - self._data.last_air_time[env_ids] = 0.0 - self._data.current_contact_time[env_ids] = 0.0 - self._data.last_contact_time[env_ids] = 0.0 + self._data._current_air_time[env_ids] = 0.0 + self._data._last_air_time[env_ids] = 0.0 + self._data._current_contact_time[env_ids] = 0.0 + self._data._last_contact_time[env_ids] = 0.0 # reset contact positions if self.cfg.track_contact_points: - self._data.contact_pos_w[env_ids, :] = torch.nan + self._data._contact_pos_w[env_ids, :] = torch.nan # reset friction forces if self.cfg.track_friction_forces: - self._data.friction_forces_w[env_ids, :] = 0.0 + self._data._friction_forces_w[env_ids, :] = 0.0 def find_bodies(self, name_keys: str | Sequence[str], preserve_order: bool = False) -> tuple[list[int], list[str]]: """Find bodies in the articulation based on the name keys. @@ -298,19 +298,20 @@ def _initialize_impl(self): ) # prepare data buffers - self._data.net_forces_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device) + self._data._body_names = list(self.body_names) + self._data._net_forces_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device) # optional buffers # -- history of net forces if self.cfg.history_length > 0: - self._data.net_forces_w_history = torch.zeros( + self._data._net_forces_w_history = torch.zeros( self._num_envs, self.cfg.history_length, self._num_bodies, 3, device=self._device ) else: - self._data.net_forces_w_history = self._data.net_forces_w.unsqueeze(1) + self._data._net_forces_w_history = self._data._net_forces_w.unsqueeze(1) # -- pose of sensor origins if self.cfg.track_pose: - self._data.pos_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device) - self._data.quat_w = torch.zeros(self._num_envs, self._num_bodies, 4, device=self._device) + self._data._pos_w = torch.zeros(self._num_envs, self._num_bodies, 3, device=self._device) + self._data._quat_w = torch.zeros(self._num_envs, self._num_bodies, 4, device=self._device) # check if filter paths are valid if self.cfg.track_contact_points or self.cfg.track_friction_forces: @@ -328,37 +329,37 @@ def _initialize_impl(self): # -- position of contact points if self.cfg.track_contact_points: - self._data.contact_pos_w = torch.full( + self._data._contact_pos_w = torch.full( (self._num_envs, self._num_bodies, self.contact_physx_view.filter_count, 3), torch.nan, device=self._device, ) # -- friction forces at contact points if self.cfg.track_friction_forces: - self._data.friction_forces_w = torch.full( + self._data._friction_forces_w = torch.full( (self._num_envs, self._num_bodies, self.contact_physx_view.filter_count, 3), 0.0, device=self._device, ) # -- air/contact time between contacts if self.cfg.track_air_time: - self._data.last_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device) - self._data.current_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device) - self._data.last_contact_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device) - self._data.current_contact_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device) + self._data._last_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device) + self._data._current_air_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device) + self._data._last_contact_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device) + self._data._current_contact_time = torch.zeros(self._num_envs, self._num_bodies, device=self._device) # force matrix: (num_envs, num_bodies, num_filter_shapes, 3) # force matrix history: (num_envs, history_length, num_bodies, num_filter_shapes, 3) if len(self.cfg.filter_prim_paths_expr) != 0: num_filters = self.contact_physx_view.filter_count - self._data.force_matrix_w = torch.zeros( + self._data._force_matrix_w = torch.zeros( self._num_envs, self._num_bodies, num_filters, 3, device=self._device ) if self.cfg.history_length > 0: - self._data.force_matrix_w_history = torch.zeros( + self._data._force_matrix_w_history = torch.zeros( self._num_envs, self.cfg.history_length, self._num_bodies, num_filters, 3, device=self._device ) else: - self._data.force_matrix_w_history = self._data.force_matrix_w.unsqueeze(1) + self._data._force_matrix_w_history = self._data._force_matrix_w.unsqueeze(1) def _update_buffers_impl(self, env_ids: Sequence[int]): """Fills the buffers of the sensor data.""" @@ -370,11 +371,11 @@ def _update_buffers_impl(self, env_ids: Sequence[int]): # TODO: We are handling the indexing ourself because of the shape; (N, B) vs expected (N * B). # This isn't the most efficient way to do this, but it's the easiest to implement. net_forces_w = self.contact_physx_view.get_net_contact_forces(dt=self._sim_physics_dt) - self._data.net_forces_w[env_ids, :, :] = net_forces_w.view(-1, self._num_bodies, 3)[env_ids] + self._data._net_forces_w[env_ids, :, :] = net_forces_w.view(-1, self._num_bodies, 3)[env_ids] # update contact force history if self.cfg.history_length > 0: - self._data.net_forces_w_history[env_ids] = self._data.net_forces_w_history[env_ids].roll(1, dims=1) - self._data.net_forces_w_history[env_ids, 0] = self._data.net_forces_w[env_ids] + self._data._net_forces_w_history[env_ids] = self._data._net_forces_w_history[env_ids].roll(1, dims=1) + self._data._net_forces_w_history[env_ids, 0] = self._data._net_forces_w[env_ids] # obtain the contact force matrix if len(self.cfg.filter_prim_paths_expr) != 0: @@ -383,23 +384,25 @@ def _update_buffers_impl(self, env_ids: Sequence[int]): # acquire and shape the force matrix force_matrix_w = self.contact_physx_view.get_contact_force_matrix(dt=self._sim_physics_dt) force_matrix_w = force_matrix_w.view(-1, self._num_bodies, num_filters, 3) - self._data.force_matrix_w[env_ids] = force_matrix_w[env_ids] + self._data._force_matrix_w[env_ids] = force_matrix_w[env_ids] if self.cfg.history_length > 0: - self._data.force_matrix_w_history[env_ids] = self._data.force_matrix_w_history[env_ids].roll(1, dims=1) - self._data.force_matrix_w_history[env_ids, 0] = self._data.force_matrix_w[env_ids] + self._data._force_matrix_w_history[env_ids] = self._data._force_matrix_w_history[env_ids].roll( + 1, dims=1 + ) + self._data._force_matrix_w_history[env_ids, 0] = self._data._force_matrix_w[env_ids] # obtain the pose of the sensor origin if self.cfg.track_pose: pose = self.body_physx_view.get_transforms().view(-1, self._num_bodies, 7)[env_ids] pose[..., 3:] = convert_quat(pose[..., 3:], to="wxyz") - self._data.pos_w[env_ids], self._data.quat_w[env_ids] = pose.split([3, 4], dim=-1) + self._data._pos_w[env_ids], self._data._quat_w[env_ids] = pose.split([3, 4], dim=-1) # obtain contact points if self.cfg.track_contact_points: _, buffer_contact_points, _, _, buffer_count, buffer_start_indices = ( self.contact_physx_view.get_contact_data(dt=self._sim_physics_dt) ) - self._data.contact_pos_w[env_ids] = self._unpack_contact_buffer_data( + self._data._contact_pos_w[env_ids] = self._unpack_contact_buffer_data( buffer_contact_points, buffer_count, buffer_start_indices )[env_ids] @@ -408,7 +411,7 @@ def _update_buffers_impl(self, env_ids: Sequence[int]): friction_forces, _, buffer_count, buffer_start_indices = self.contact_physx_view.get_friction_data( dt=self._sim_physics_dt ) - self._data.friction_forces_w[env_ids] = self._unpack_contact_buffer_data( + self._data._friction_forces_w[env_ids] = self._unpack_contact_buffer_data( friction_forces, buffer_count, buffer_start_indices, avg=False, default=0.0 )[env_ids] @@ -418,28 +421,28 @@ def _update_buffers_impl(self, env_ids: Sequence[int]): # since this function is called every frame, we can use the difference to get the elapsed time elapsed_time = self._timestamp[env_ids] - self._timestamp_last_update[env_ids] # -- check contact state of bodies - is_contact = torch.norm(self._data.net_forces_w[env_ids, :, :], dim=-1) > self.cfg.force_threshold - is_first_contact = (self._data.current_air_time[env_ids] > 0) * is_contact - is_first_detached = (self._data.current_contact_time[env_ids] > 0) * ~is_contact + is_contact = torch.norm(self._data._net_forces_w[env_ids, :, :], dim=-1) > self.cfg.force_threshold + is_first_contact = (self._data._current_air_time[env_ids] > 0) * is_contact + is_first_detached = (self._data._current_contact_time[env_ids] > 0) * ~is_contact # -- update the last contact time if body has just become in contact - self._data.last_air_time[env_ids] = torch.where( + self._data._last_air_time[env_ids] = torch.where( is_first_contact, - self._data.current_air_time[env_ids] + elapsed_time.unsqueeze(-1), - self._data.last_air_time[env_ids], + self._data._current_air_time[env_ids] + elapsed_time.unsqueeze(-1), + self._data._last_air_time[env_ids], ) # -- increment time for bodies that are not in contact - self._data.current_air_time[env_ids] = torch.where( - ~is_contact, self._data.current_air_time[env_ids] + elapsed_time.unsqueeze(-1), 0.0 + self._data._current_air_time[env_ids] = torch.where( + ~is_contact, self._data._current_air_time[env_ids] + elapsed_time.unsqueeze(-1), 0.0 ) # -- update the last contact time if body has just detached - self._data.last_contact_time[env_ids] = torch.where( + self._data._last_contact_time[env_ids] = torch.where( is_first_detached, - self._data.current_contact_time[env_ids] + elapsed_time.unsqueeze(-1), - self._data.last_contact_time[env_ids], + self._data._current_contact_time[env_ids] + elapsed_time.unsqueeze(-1), + self._data._last_contact_time[env_ids], ) # -- increment time for bodies that are in contact - self._data.current_contact_time[env_ids] = torch.where( - is_contact, self._data.current_contact_time[env_ids] + elapsed_time.unsqueeze(-1), 0.0 + self._data._current_contact_time[env_ids] = torch.where( + is_contact, self._data._current_contact_time[env_ids] + elapsed_time.unsqueeze(-1), 0.0 ) def _unpack_contact_buffer_data( diff --git a/source/isaaclab/isaaclab/sensors/contact_sensor/contact_sensor_data.py b/source/isaaclab/isaaclab/sensors/contact_sensor/contact_sensor_data.py index c959792f77fc..ced3535187a7 100644 --- a/source/isaaclab/isaaclab/sensors/contact_sensor/contact_sensor_data.py +++ b/source/isaaclab/isaaclab/sensors/contact_sensor/contact_sensor_data.py @@ -9,12 +9,15 @@ import torch from dataclasses import dataclass +from isaaclab.utils.leapp_semantics import leapp_tensor_semantics + @dataclass class ContactSensorData: """Data container for the contact reporting sensor.""" - pos_w: torch.Tensor | None = None + _body_names: list[str] | None = None + _pos_w: torch.Tensor | None = None """Position of the sensor origin in world frame. Shape is (N, 3), where N is the number of sensors. @@ -24,7 +27,7 @@ class ContactSensorData: """ - contact_pos_w: torch.Tensor | None = None + _contact_pos_w: torch.Tensor | None = None """Average of the positions of contact points between sensor body and filter prim in world frame. Shape is (N, B, M, 3), where N is the number of sensors, B is number of bodies in each sensor @@ -43,7 +46,7 @@ class ContactSensorData: """ - friction_forces_w: torch.Tensor | None = None + _friction_forces_w: torch.Tensor | None = None """Sum of the friction forces between sensor body and filter prim in world frame. Shape is (N, B, M, 3), where N is the number of sensors, B is number of bodies in each sensor @@ -61,7 +64,7 @@ class ContactSensorData: """ - quat_w: torch.Tensor | None = None + _quat_w: torch.Tensor | None = None """Orientation of the sensor origin in quaternion (w, x, y, z) in world frame. Shape is (N, 4), where N is the number of sensors. @@ -70,7 +73,7 @@ class ContactSensorData: If the :attr:`ContactSensorCfg.track_pose` is False, then this quantity is None. """ - net_forces_w: torch.Tensor | None = None + _net_forces_w: torch.Tensor | None = None """The net normal contact forces in world frame. Shape is (N, B, 3), where N is the number of sensors and B is the number of bodies in each sensor. @@ -80,7 +83,7 @@ class ContactSensorData: with the total contact forces acting on the sensor bodies (which also includes the tangential forces). """ - net_forces_w_history: torch.Tensor | None = None + _net_forces_w_history: torch.Tensor | None = None """The net normal contact forces in world frame. Shape is (N, T, B, 3), where N is the number of sensors, T is the configured history length @@ -93,7 +96,7 @@ class ContactSensorData: with the total contact forces acting on the sensor bodies (which also includes the tangential forces). """ - force_matrix_w: torch.Tensor | None = None + _force_matrix_w: torch.Tensor | None = None """The normal contact forces filtered between the sensor bodies and filtered bodies in world frame. Shape is (N, B, M, 3), where N is the number of sensors, B is number of bodies in each sensor @@ -103,7 +106,7 @@ class ContactSensorData: If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this quantity is None. """ - force_matrix_w_history: torch.Tensor | None = None + _force_matrix_w_history: torch.Tensor | None = None """The normal contact forces filtered between the sensor bodies and filtered bodies in world frame. Shape is (N, T, B, M, 3), where N is the number of sensors, T is the configured history length, @@ -115,7 +118,7 @@ class ContactSensorData: If the :attr:`ContactSensorCfg.filter_prim_paths_expr` is empty, then this quantity is None. """ - last_air_time: torch.Tensor | None = None + _last_air_time: torch.Tensor | None = None """Time spent (in s) in the air before the last contact. Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor. @@ -124,7 +127,7 @@ class ContactSensorData: If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None. """ - current_air_time: torch.Tensor | None = None + _current_air_time: torch.Tensor | None = None """Time spent (in s) in the air since the last detach. Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor. @@ -133,7 +136,7 @@ class ContactSensorData: If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None. """ - last_contact_time: torch.Tensor | None = None + _last_contact_time: torch.Tensor | None = None """Time spent (in s) in contact before the last detach. Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor. @@ -142,7 +145,7 @@ class ContactSensorData: If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None. """ - current_contact_time: torch.Tensor | None = None + _current_contact_time: torch.Tensor | None = None """Time spent (in s) in contact since the last contact. Shape is (N, B), where N is the number of sensors and B is the number of bodies in each sensor. @@ -150,3 +153,75 @@ class ContactSensorData: Note: If the :attr:`ContactSensorCfg.track_air_time` is False, then this quantity is None. """ + + @property + @leapp_tensor_semantics(kind="state/sensor/position", element_names_source="body_xyz") + def pos_w(self) -> torch.Tensor | None: + """Position of the sensor origin in world frame.""" + return self._pos_w + + @property + @leapp_tensor_semantics(kind="state/sensor/rotation", element_names_source="body_quat") + def quat_w(self) -> torch.Tensor | None: + """Orientation of the sensor origin in quaternion (w, x, y, z) in world frame.""" + return self._quat_w + + @property + @leapp_tensor_semantics(kind="state/contact_sensor/contact_position") + def contact_pos_w(self) -> torch.Tensor | None: + """Average contact positions in world frame.""" + return self._contact_pos_w + + @property + @leapp_tensor_semantics(kind="state/contact_sensor/friction_force") + def friction_forces_w(self) -> torch.Tensor | None: + """Friction forces between sensor body and filter prim in world frame.""" + return self._friction_forces_w + + @property + @leapp_tensor_semantics(kind="state/contact_sensor/net_force", element_names_source="body_xyz") + def net_forces_w(self) -> torch.Tensor | None: + """Net normal contact forces in world frame.""" + return self._net_forces_w + + @property + @leapp_tensor_semantics(kind="state/contact_sensor/net_force_history") + def net_forces_w_history(self) -> torch.Tensor | None: + """History of net normal contact forces in world frame.""" + return self._net_forces_w_history + + @property + @leapp_tensor_semantics(kind="state/contact_sensor/force_matrix") + def force_matrix_w(self) -> torch.Tensor | None: + """Filtered contact force matrix in world frame.""" + return self._force_matrix_w + + @property + @leapp_tensor_semantics(kind="state/contact_sensor/force_matrix_history") + def force_matrix_w_history(self) -> torch.Tensor | None: + """History of filtered contact force matrices in world frame.""" + return self._force_matrix_w_history + + @property + @leapp_tensor_semantics(kind="state/contact_sensor/last_air_time", element_names_source="body_names") + def last_air_time(self) -> torch.Tensor | None: + """Time spent in the air before the last contact.""" + return self._last_air_time + + @property + @leapp_tensor_semantics(kind="state/contact_sensor/current_air_time", element_names_source="body_names") + def current_air_time(self) -> torch.Tensor | None: + """Time spent in the air since the last detach.""" + return self._current_air_time + + @property + @leapp_tensor_semantics(kind="state/contact_sensor/last_contact_time", element_names_source="body_names") + def last_contact_time(self) -> torch.Tensor | None: + """Time spent in contact before the last detach.""" + return self._last_contact_time + + @property + @leapp_tensor_semantics(kind="state/contact_sensor/current_contact_time", element_names_source="body_names") + def current_contact_time(self) -> torch.Tensor | None: + """Time spent in contact since the last contact.""" + return self._current_contact_time diff --git a/source/isaaclab/isaaclab/sensors/frame_transformer/frame_transformer.py b/source/isaaclab/isaaclab/sensors/frame_transformer/frame_transformer.py index 50f75b565e17..cd631dd9d176 100644 --- a/source/isaaclab/isaaclab/sensors/frame_transformer/frame_transformer.py +++ b/source/isaaclab/isaaclab/sensors/frame_transformer/frame_transformer.py @@ -353,13 +353,13 @@ def extract_env_num_and_prim_path(item: str) -> tuple[int, str]: self._target_frame_offset_quat = torch.stack(target_frame_offset_quat).repeat(self._num_envs, 1) # fill the data buffer - self._data.target_frame_names = self._target_frame_names - self._data.source_pos_w = torch.zeros(self._num_envs, 3, device=self._device) - self._data.source_quat_w = torch.zeros(self._num_envs, 4, device=self._device) - self._data.target_pos_w = torch.zeros(self._num_envs, len(duplicate_frame_indices), 3, device=self._device) - self._data.target_quat_w = torch.zeros(self._num_envs, len(duplicate_frame_indices), 4, device=self._device) - self._data.target_pos_source = torch.zeros_like(self._data.target_pos_w) - self._data.target_quat_source = torch.zeros_like(self._data.target_quat_w) + self._data._target_frame_names = self._target_frame_names + self._data._source_pos_w = torch.zeros(self._num_envs, 3, device=self._device) + self._data._source_quat_w = torch.zeros(self._num_envs, 4, device=self._device) + self._data._target_pos_w = torch.zeros(self._num_envs, len(duplicate_frame_indices), 3, device=self._device) + self._data._target_quat_w = torch.zeros(self._num_envs, len(duplicate_frame_indices), 4, device=self._device) + self._data._target_pos_source = torch.zeros_like(self._data._target_pos_w) + self._data._target_quat_source = torch.zeros_like(self._data._target_quat_w) def _update_buffers_impl(self, env_ids: Sequence[int]): """Fills the buffers of the sensor data.""" @@ -419,12 +419,12 @@ def _update_buffers_impl(self, env_ids: Sequence[int]): # Update buffers # note: The frame names / ordering don't change so no need to update them after initialization - self._data.source_pos_w[:] = source_pos_w.view(-1, 3) - self._data.source_quat_w[:] = source_quat_w.view(-1, 4) - self._data.target_pos_w[:] = target_pos_w.view(-1, total_num_frames, 3) - self._data.target_quat_w[:] = target_quat_w.view(-1, total_num_frames, 4) - self._data.target_pos_source[:] = target_pos_source.view(-1, total_num_frames, 3) - self._data.target_quat_source[:] = target_quat_source.view(-1, total_num_frames, 4) + self._data._source_pos_w[:] = source_pos_w.view(-1, 3) + self._data._source_quat_w[:] = source_quat_w.view(-1, 4) + self._data._target_pos_w[:] = target_pos_w.view(-1, total_num_frames, 3) + self._data._target_quat_w[:] = target_quat_w.view(-1, total_num_frames, 4) + self._data._target_pos_source[:] = target_pos_source.view(-1, total_num_frames, 3) + self._data._target_quat_source[:] = target_quat_source.view(-1, total_num_frames, 4) def _set_debug_vis_impl(self, debug_vis: bool): # set visibility of markers diff --git a/source/isaaclab/isaaclab/sensors/frame_transformer/frame_transformer_data.py b/source/isaaclab/isaaclab/sensors/frame_transformer/frame_transformer_data.py index 7ce9b0f436d6..42c2b5c962cf 100644 --- a/source/isaaclab/isaaclab/sensors/frame_transformer/frame_transformer_data.py +++ b/source/isaaclab/isaaclab/sensors/frame_transformer/frame_transformer_data.py @@ -6,12 +6,14 @@ import torch from dataclasses import dataclass +from isaaclab.utils.leapp_semantics import leapp_tensor_semantics + @dataclass class FrameTransformerData: """Data container for the frame transformer sensor.""" - target_frame_names: list[str] = None + _target_frame_names: list[str] = None """Target frame names (this denotes the order in which that frame data is ordered). The frame names are resolved from the :attr:`FrameTransformerCfg.FrameCfg.name` field. @@ -19,38 +21,92 @@ class FrameTransformerData: the regex matching. """ - target_pos_source: torch.Tensor = None + _target_pos_source: torch.Tensor = None """Position of the target frame(s) relative to source frame. Shape is (N, M, 3), where N is the number of environments, and M is the number of target frames. """ - target_quat_source: torch.Tensor = None + _target_quat_source: torch.Tensor = None """Orientation of the target frame(s) relative to source frame quaternion (w, x, y, z). Shape is (N, M, 4), where N is the number of environments, and M is the number of target frames. """ - target_pos_w: torch.Tensor = None + _target_pos_w: torch.Tensor = None """Position of the target frame(s) after offset (in world frame). Shape is (N, M, 3), where N is the number of environments, and M is the number of target frames. """ - target_quat_w: torch.Tensor = None + _target_quat_w: torch.Tensor = None """Orientation of the target frame(s) after offset (in world frame) quaternion (w, x, y, z). Shape is (N, M, 4), where N is the number of environments, and M is the number of target frames. """ - source_pos_w: torch.Tensor = None + _source_pos_w: torch.Tensor = None """Position of the source frame after offset (in world frame). Shape is (N, 3), where N is the number of environments. """ - source_quat_w: torch.Tensor = None + _source_quat_w: torch.Tensor = None """Orientation of the source frame after offset (in world frame) quaternion (w, x, y, z). Shape is (N, 4), where N is the number of environments. """ + + @property + def target_frame_names(self) -> list[str] | None: + """Target frame names in the same order as the target-frame tensors.""" + return self._target_frame_names + + @property + @leapp_tensor_semantics( + kind="state/frame_transformer/target_position_source", element_names=[None, None, ["x", "y", "z"]] + ) + def target_pos_source(self) -> torch.Tensor: + """Position of the target frame(s) relative to source frame.""" + return self._target_pos_source + + @property + @leapp_tensor_semantics( + kind="state/frame_transformer/target_rotation_source", + element_names=[None, None, ["qw", "qx", "qy", "qz"]], + ) + def target_quat_source(self) -> torch.Tensor: + """Orientation of the target frame(s) relative to source frame quaternion (w, x, y, z).""" + return self._target_quat_source + + @property + @leapp_tensor_semantics( + kind="state/frame_transformer/target_position_world", element_names=[None, None, ["x", "y", "z"]] + ) + def target_pos_w(self) -> torch.Tensor: + """Position of the target frame(s) after offset in world frame.""" + return self._target_pos_w + + @property + @leapp_tensor_semantics( + kind="state/frame_transformer/target_rotation_world", + element_names=[None, None, ["qw", "qx", "qy", "qz"]], + ) + def target_quat_w(self) -> torch.Tensor: + """Orientation of the target frame(s) after offset in world frame quaternion (w, x, y, z).""" + return self._target_quat_w + + @property + @leapp_tensor_semantics(kind="state/frame_transformer/source_position_world", element_names=[None, ["x", "y", "z"]]) + def source_pos_w(self) -> torch.Tensor: + """Position of the source frame after offset in world frame.""" + return self._source_pos_w + + @property + @leapp_tensor_semantics( + kind="state/frame_transformer/source_rotation_world", + element_names=[None, ["qw", "qx", "qy", "qz"]], + ) + def source_quat_w(self) -> torch.Tensor: + """Orientation of the source frame after offset in world frame quaternion (w, x, y, z).""" + return self._source_quat_w diff --git a/source/isaaclab/isaaclab/sensors/imu/imu.py b/source/isaaclab/isaaclab/sensors/imu/imu.py index 1cf0dda12b14..dd10a6e8a0ce 100644 --- a/source/isaaclab/isaaclab/sensors/imu/imu.py +++ b/source/isaaclab/isaaclab/sensors/imu/imu.py @@ -103,15 +103,15 @@ def reset(self, env_ids: Sequence[int] | None = None): if env_ids is None: env_ids = slice(None) # reset accumulative data buffers - self._data.pos_w[env_ids] = 0.0 - self._data.quat_w[env_ids] = 0.0 - self._data.quat_w[env_ids, 0] = 1.0 - self._data.projected_gravity_b[env_ids] = 0.0 - self._data.projected_gravity_b[env_ids, 2] = -1.0 - self._data.lin_vel_b[env_ids] = 0.0 - self._data.ang_vel_b[env_ids] = 0.0 - self._data.lin_acc_b[env_ids] = 0.0 - self._data.ang_acc_b[env_ids] = 0.0 + self._data._pos_w[env_ids] = 0.0 + self._data._quat_w[env_ids] = 0.0 + self._data._quat_w[env_ids, 0] = 1.0 + self._data._projected_gravity_b[env_ids] = 0.0 + self._data._projected_gravity_b[env_ids, 2] = -1.0 + self._data._lin_vel_b[env_ids] = 0.0 + self._data._ang_vel_b[env_ids] = 0.0 + self._data._lin_acc_b[env_ids] = 0.0 + self._data._ang_acc_b[env_ids] = 0.0 self._prev_lin_vel_w[env_ids] = 0.0 self._prev_ang_vel_w[env_ids] = 0.0 @@ -198,8 +198,8 @@ def _update_buffers_impl(self, env_ids: Sequence[int]): quat_w = quat_w.roll(1, dims=-1) # sensor pose in world: apply composed offset - self._data.pos_w[env_ids] = pos_w + math_utils.quat_apply(quat_w, self._offset_pos_b[env_ids]) - self._data.quat_w[env_ids] = math_utils.quat_mul(quat_w, self._offset_quat_b[env_ids]) + self._data._pos_w[env_ids] = pos_w + math_utils.quat_apply(quat_w, self._offset_pos_b[env_ids]) + self._data._quat_w[env_ids] = math_utils.quat_mul(quat_w, self._offset_quat_b[env_ids]) # COM of rigid source (body frame) com_pos_b = self._view.get_coms().to(self.device).split([3, 4], dim=-1)[0] @@ -218,17 +218,17 @@ def _update_buffers_impl(self, env_ids: Sequence[int]): # batch rotate world->body using current sensor orientation dynamics_data = torch.stack((lin_vel_w, ang_vel_w, lin_acc_w, ang_acc_w, self.GRAVITY_VEC_W[env_ids]), dim=0) - dynamics_data_rot = math_utils.quat_apply_inverse(self._data.quat_w[env_ids].repeat(5, 1), dynamics_data).chunk( - 5, dim=0 - ) + dynamics_data_rot = math_utils.quat_apply_inverse( + self._data._quat_w[env_ids].repeat(5, 1), dynamics_data + ).chunk(5, dim=0) # store the velocities. - self._data.lin_vel_b[env_ids] = dynamics_data_rot[0] - self._data.ang_vel_b[env_ids] = dynamics_data_rot[1] + self._data._lin_vel_b[env_ids] = dynamics_data_rot[0] + self._data._ang_vel_b[env_ids] = dynamics_data_rot[1] # store the accelerations - self._data.lin_acc_b[env_ids] = dynamics_data_rot[2] - self._data.ang_acc_b[env_ids] = dynamics_data_rot[3] + self._data._lin_acc_b[env_ids] = dynamics_data_rot[2] + self._data._ang_acc_b[env_ids] = dynamics_data_rot[3] # store projected gravity - self._data.projected_gravity_b[env_ids] = dynamics_data_rot[4] + self._data._projected_gravity_b[env_ids] = dynamics_data_rot[4] self._prev_lin_vel_w[env_ids] = lin_vel_w self._prev_ang_vel_w[env_ids] = ang_vel_w @@ -236,16 +236,16 @@ def _update_buffers_impl(self, env_ids: Sequence[int]): def _initialize_buffers_impl(self): """Create buffers for storing data.""" # data buffers - self._data.pos_w = torch.zeros(self._view.count, 3, device=self._device) - self._data.quat_w = torch.zeros(self._view.count, 4, device=self._device) - self._data.quat_w[:, 0] = 1.0 - self._data.projected_gravity_b = torch.zeros(self._view.count, 3, device=self._device) - self._data.lin_vel_b = torch.zeros_like(self._data.pos_w) - self._data.ang_vel_b = torch.zeros_like(self._data.pos_w) - self._data.lin_acc_b = torch.zeros_like(self._data.pos_w) - self._data.ang_acc_b = torch.zeros_like(self._data.pos_w) - self._prev_lin_vel_w = torch.zeros_like(self._data.pos_w) - self._prev_ang_vel_w = torch.zeros_like(self._data.pos_w) + self._data._pos_w = torch.zeros(self._view.count, 3, device=self._device) + self._data._quat_w = torch.zeros(self._view.count, 4, device=self._device) + self._data._quat_w[:, 0] = 1.0 + self._data._projected_gravity_b = torch.zeros(self._view.count, 3, device=self._device) + self._data._lin_vel_b = torch.zeros_like(self._data._pos_w) + self._data._ang_vel_b = torch.zeros_like(self._data._pos_w) + self._data._lin_acc_b = torch.zeros_like(self._data._pos_w) + self._data._ang_acc_b = torch.zeros_like(self._data._pos_w) + self._prev_lin_vel_w = torch.zeros_like(self._data._pos_w) + self._prev_ang_vel_w = torch.zeros_like(self._data._pos_w) # store sensor offset (applied relative to rigid source). This may be composed later with a fixed ancestor->target transform. self._offset_pos_b = torch.tensor(list(self.cfg.offset.pos), device=self._device).repeat(self._view.count, 1) diff --git a/source/isaaclab/isaaclab/sensors/imu/imu_data.py b/source/isaaclab/isaaclab/sensors/imu/imu_data.py index ee365f191468..e41cc44900f2 100644 --- a/source/isaaclab/isaaclab/sensors/imu/imu_data.py +++ b/source/isaaclab/isaaclab/sensors/imu/imu_data.py @@ -8,49 +8,93 @@ import torch from dataclasses import dataclass +from isaaclab.utils.leapp_semantics import leapp_tensor_semantics + @dataclass class ImuData: """Data container for the Imu sensor.""" - pos_w: torch.Tensor = None + _pos_w: torch.Tensor = None """Position of the sensor origin in world frame. Shape is (N, 3), where ``N`` is the number of environments. """ - quat_w: torch.Tensor = None + _quat_w: torch.Tensor = None """Orientation of the sensor origin in quaternion ``(w, x, y, z)`` in world frame. Shape is (N, 4), where ``N`` is the number of environments. """ - projected_gravity_b: torch.Tensor = None + _projected_gravity_b: torch.Tensor = None """Gravity direction unit vector projected on the imu frame. Shape is (N,3), where ``N`` is the number of environments. """ - lin_vel_b: torch.Tensor = None + _lin_vel_b: torch.Tensor = None """IMU frame angular velocity relative to the world expressed in IMU frame. Shape is (N, 3), where ``N`` is the number of environments. """ - ang_vel_b: torch.Tensor = None + _ang_vel_b: torch.Tensor = None """IMU frame angular velocity relative to the world expressed in IMU frame. Shape is (N, 3), where ``N`` is the number of environments. """ - lin_acc_b: torch.Tensor = None + _lin_acc_b: torch.Tensor = None """IMU frame linear acceleration relative to the world expressed in IMU frame. Shape is (N, 3), where ``N`` is the number of environments. """ - ang_acc_b: torch.Tensor = None + _ang_acc_b: torch.Tensor = None """IMU frame angular acceleration relative to the world expressed in IMU frame. Shape is (N, 3), where ``N`` is the number of environments. """ + + @property + @leapp_tensor_semantics(kind="state/sensor/position", element_names_source="xyz") + def pos_w(self) -> torch.Tensor: + """Position of the sensor origin in world frame.""" + return self._pos_w + + @property + @leapp_tensor_semantics(kind="state/sensor/rotation", element_names_source="quat_wxyz") + def quat_w(self) -> torch.Tensor: + """Orientation of the sensor origin in quaternion ``(w, x, y, z)`` in world frame.""" + return self._quat_w + + @property + @leapp_tensor_semantics(kind="state/imu/projected_gravity", element_names_source="xyz") + def projected_gravity_b(self) -> torch.Tensor: + """Gravity direction unit vector projected on the imu frame.""" + return self._projected_gravity_b + + @property + @leapp_tensor_semantics(kind="state/imu/linear_velocity", element_names_source="xyz") + def lin_vel_b(self) -> torch.Tensor: + """IMU frame linear velocity relative to the world expressed in IMU frame.""" + return self._lin_vel_b + + @property + @leapp_tensor_semantics(kind="state/imu/angular_velocity", element_names_source="xyz") + def ang_vel_b(self) -> torch.Tensor: + """IMU frame angular velocity relative to the world expressed in IMU frame.""" + return self._ang_vel_b + + @property + @leapp_tensor_semantics(kind="state/imu/linear_acceleration", element_names_source="xyz") + def lin_acc_b(self) -> torch.Tensor: + """IMU frame linear acceleration relative to the world expressed in IMU frame.""" + return self._lin_acc_b + + @property + @leapp_tensor_semantics(kind="state/imu/angular_acceleration", element_names_source="xyz") + def ang_acc_b(self) -> torch.Tensor: + """IMU frame angular acceleration relative to the world expressed in IMU frame.""" + return self._ang_acc_b diff --git a/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster_camera.py b/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster_camera.py index 6e57c4b04600..136149e6e991 100644 --- a/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster_camera.py +++ b/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster_camera.py @@ -105,8 +105,8 @@ def _initialize_rays_impl(self): self._offset_quat = quat_w.repeat(self._view.count, 1) self._offset_pos = torch.tensor(list(self.cfg.offset.pos), device=self._device).repeat(self._view.count, 1) - self._data.quat_w = torch.zeros(self._view.count, 4, device=self.device) - self._data.pos_w = torch.zeros(self._view.count, 3, device=self.device) + self._data._quat_w = torch.zeros(self._view.count, 4, device=self.device) + self._data._pos_w = torch.zeros(self._view.count, 3, device=self.device) self._ray_starts_w = torch.zeros(self._view.count, self.num_rays, 3, device=self.device) self._ray_directions_w = torch.zeros(self._view.count, self.num_rays, 3, device=self.device) diff --git a/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster.py b/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster.py index d7aa07419d4b..8804670c2655 100644 --- a/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster.py +++ b/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster.py @@ -228,9 +228,9 @@ def _initialize_rays_impl(self): self.drift = torch.zeros(self._view.count, 3, device=self.device) self.ray_cast_drift = torch.zeros(self._view.count, 3, device=self.device) # fill the data buffer - self._data.pos_w = torch.zeros(self._view.count, 3, device=self.device) - self._data.quat_w = torch.zeros(self._view.count, 4, device=self.device) - self._data.ray_hits_w = torch.zeros(self._view.count, self.num_rays, 3, device=self.device) + self._data._pos_w = torch.zeros(self._view.count, 3, device=self.device) + self._data._quat_w = torch.zeros(self._view.count, 4, device=self.device) + self._data._ray_hits_w = torch.zeros(self._view.count, self.num_rays, 3, device=self.device) self._ray_starts_w = torch.zeros(self._view.count, self.num_rays, 3, device=self.device) self._ray_directions_w = torch.zeros(self._view.count, self.num_rays, 3, device=self.device) @@ -244,8 +244,8 @@ def _update_ray_infos(self, env_ids: Sequence[int]): # apply drift to ray starting position in world frame pos_w += self.drift[env_ids] # store the poses - self._data.pos_w[env_ids] = pos_w - self._data.quat_w[env_ids] = quat_w + self._data._pos_w[env_ids] = pos_w + self._data._quat_w[env_ids] = quat_w # check if user provided attach_yaw_only flag if self.cfg.attach_yaw_only is not None: @@ -296,7 +296,7 @@ def _update_buffers_impl(self, env_ids: Sequence[int]): # ray cast and store the hits # TODO: Make this work for multiple meshes? - self._data.ray_hits_w[env_ids] = raycast_mesh( + self._data._ray_hits_w[env_ids] = raycast_mesh( self._ray_starts_w[env_ids], self._ray_directions_w[env_ids], max_dist=self.cfg.max_distance, @@ -304,7 +304,7 @@ def _update_buffers_impl(self, env_ids: Sequence[int]): )[0] # apply vertical drift to ray starting position in ray caster frame - self._data.ray_hits_w[env_ids, :, 2] += self.ray_cast_drift[env_ids, 2].unsqueeze(-1) + self._data._ray_hits_w[env_ids, :, 2] += self.ray_cast_drift[env_ids, 2].unsqueeze(-1) def _set_debug_vis_impl(self, debug_vis: bool): # set visibility of markers diff --git a/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster_data.py b/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster_data.py index e4961a60603d..6dd2e6a6f753 100644 --- a/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster_data.py +++ b/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster_data.py @@ -6,24 +6,44 @@ import torch from dataclasses import dataclass +from isaaclab.utils.leapp_semantics import leapp_tensor_semantics + @dataclass class RayCasterData: """Data container for the ray-cast sensor.""" - pos_w: torch.Tensor = None + _pos_w: torch.Tensor = None """Position of the sensor origin in world frame. Shape is (N, 3), where N is the number of sensors. """ - quat_w: torch.Tensor = None + _quat_w: torch.Tensor = None """Orientation of the sensor origin in quaternion (w, x, y, z) in world frame. Shape is (N, 4), where N is the number of sensors. """ - ray_hits_w: torch.Tensor = None + _ray_hits_w: torch.Tensor = None """The ray hit positions in the world frame. Shape is (N, B, 3), where N is the number of sensors, B is the number of rays in the scan pattern per sensor. """ + + @property + @leapp_tensor_semantics(kind="state/sensor/position", element_names_source="xyz") + def pos_w(self) -> torch.Tensor: + """Position of the sensor origin in world frame.""" + return self._pos_w + + @property + @leapp_tensor_semantics(kind="state/sensor/rotation", element_names_source="quat_wxyz") + def quat_w(self) -> torch.Tensor: + """Orientation of the sensor origin in quaternion (w, x, y, z) in world frame.""" + return self._quat_w + + @property + @leapp_tensor_semantics(kind="state/sensor/ray_hit_position") + def ray_hits_w(self) -> torch.Tensor: + """The ray hit positions in the world frame.""" + return self._ray_hits_w diff --git a/source/isaaclab/isaaclab/utils/leapp_semantics.py b/source/isaaclab/isaaclab/utils/leapp_semantics.py index b101cda1a991..3a9c53c8c7e8 100644 --- a/source/isaaclab/isaaclab/utils/leapp_semantics.py +++ b/source/isaaclab/isaaclab/utils/leapp_semantics.py @@ -25,6 +25,7 @@ class LeappTensorSemantics: XYZ_ELEMENT_NAMES: list[str] = ["x", "y", "z"] QUAT_WXYZ_ELEMENT_NAMES: list[str] = ["qw", "qx", "qy", "qz"] POSE7_ELEMENT_NAMES: list[str] = ["x", "y", "z", "qw", "qx", "qy", "qz"] +WRENCH6_ELEMENT_NAMES: list[str] = ["fx", "fy", "fz", "tx", "ty", "tz"] def leapp_tensor_semantics( @@ -74,23 +75,47 @@ def resolve_leapp_element_names(semantics: LeappTensorSemantics | None, data_sel source = semantics.element_names_source if source == "joint_names": - return _select_element_names(getattr(data_self, "joint_names", None), getattr(data_self, "_joint_ids", None)) + return _select_element_names( + getattr(data_self, "joint_names", getattr(data_self, "_joint_names", None)), + getattr(data_self, "_joint_ids", None), + ) if source == "body_names": - return _select_element_names(getattr(data_self, "body_names", None), getattr(data_self, "_body_ids", None)) + return _select_element_names( + getattr(data_self, "body_names", getattr(data_self, "_body_names", None)), + getattr(data_self, "_body_ids", None), + ) + if source == "body_xyz": + body_names = _select_element_names( + getattr(data_self, "body_names", getattr(data_self, "_body_names", None)), + getattr(data_self, "_body_ids", None), + ) + if body_names is None: + return None + return [body_names, XYZ_ELEMENT_NAMES] if source == "body_pose": body_names = _select_element_names( - getattr(data_self, "body_names", None), getattr(data_self, "_body_ids", None) + getattr(data_self, "body_names", getattr(data_self, "_body_names", None)), + getattr(data_self, "_body_ids", None), ) if body_names is None: return None return [body_names, POSE7_ELEMENT_NAMES] if source == "body_quat": body_names = _select_element_names( - getattr(data_self, "body_names", None), getattr(data_self, "_body_ids", None) + getattr(data_self, "body_names", getattr(data_self, "_body_names", None)), + getattr(data_self, "_body_ids", None), ) if body_names is None: return None return [body_names, QUAT_WXYZ_ELEMENT_NAMES] + if source == "body_wrench": + body_names = _select_element_names( + getattr(data_self, "body_names", getattr(data_self, "_body_names", None)), + getattr(data_self, "_body_ids", None), + ) + if body_names is None: + return None + return [body_names, WRENCH6_ELEMENT_NAMES] if source == "pose7": return POSE7_ELEMENT_NAMES if source == "xyz": diff --git a/source/isaaclab_rl/test/test_rsl_rl_export_flow.py b/source/isaaclab_rl/test/test_rsl_rl_export_flow.py new file mode 100644 index 000000000000..2f656aee89b2 --- /dev/null +++ b/source/isaaclab_rl/test/test_rsl_rl_export_flow.py @@ -0,0 +1,143 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Export pipeline integration tests. + +Each test calls ``export.py`` as a subprocess so that Isaac Sim's AppLauncher +is fully isolated per task and the export logic is not duplicated here. +The export artifacts land in the default checkpoint directory; only the +per-task export subdirectory is removed after each test. +""" + +import os +import pytest +import shutil +import subprocess + +# Root of the repository (three levels up from this file). +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +_EXPORT_SCRIPT = os.path.join("scripts", "reinforcement_learning", "rsl_rl", "export.py") + + +# Tasks with confirmed pretrained checkpoints (Direct and no-checkpoint tasks excluded). +# See FRANK_TESTING/no-pretrain.txt for the excluded set. +TASKS = [ + # Classic + "Isaac-Ant-v0", + "Isaac-Cartpole-v0", + "Isaac-Humanoid-v0", + # Navigation + "Isaac-Navigation-Flat-Anymal-C-v0", + "Isaac-Navigation-Flat-Anymal-C-Play-v0", + # Locomotion Velocity + "Isaac-Velocity-Flat-Anymal-B-v0", + "Isaac-Velocity-Flat-Anymal-B-Play-v0", + "Isaac-Velocity-Rough-Anymal-B-v0", + "Isaac-Velocity-Rough-Anymal-B-Play-v0", + "Isaac-Velocity-Flat-Anymal-C-v0", + "Isaac-Velocity-Flat-Anymal-C-Play-v0", + "Isaac-Velocity-Rough-Anymal-C-v0", + "Isaac-Velocity-Rough-Anymal-C-Play-v0", + "Isaac-Velocity-Flat-Anymal-D-v0", + "Isaac-Velocity-Flat-Anymal-D-Play-v0", + "Isaac-Velocity-Rough-Anymal-D-v0", + "Isaac-Velocity-Rough-Anymal-D-Play-v0", + "Isaac-Velocity-Flat-Cassie-v0", + "Isaac-Velocity-Flat-Cassie-Play-v0", + "Isaac-Velocity-Rough-Cassie-v0", + "Isaac-Velocity-Rough-Cassie-Play-v0", + "Isaac-Velocity-Flat-G1-v0", + "Isaac-Velocity-Flat-G1-Play-v0", + "Isaac-Velocity-Rough-G1-v0", + "Isaac-Velocity-Rough-G1-Play-v0", + "Isaac-Velocity-Flat-H1-v0", + "Isaac-Velocity-Flat-H1-Play-v0", + "Isaac-Velocity-Rough-H1-v0", + "Isaac-Velocity-Rough-H1-Play-v0", + "Isaac-Velocity-Flat-Spot-v0", + "Isaac-Velocity-Flat-Spot-Play-v0", + "Isaac-Velocity-Flat-Unitree-A1-v0", + "Isaac-Velocity-Flat-Unitree-A1-Play-v0", + "Isaac-Velocity-Rough-Unitree-A1-v0", + "Isaac-Velocity-Rough-Unitree-A1-Play-v0", + "Isaac-Velocity-Flat-Unitree-Go1-v0", + "Isaac-Velocity-Flat-Unitree-Go1-Play-v0", + "Isaac-Velocity-Rough-Unitree-Go1-v0", + "Isaac-Velocity-Rough-Unitree-Go1-Play-v0", + "Isaac-Velocity-Flat-Unitree-Go2-v0", + "Isaac-Velocity-Flat-Unitree-Go2-Play-v0", + "Isaac-Velocity-Rough-Unitree-Go2-v0", + "Isaac-Velocity-Rough-Unitree-Go2-Play-v0", + # Manipulation Reach + "Isaac-Reach-Franka-v0", + "Isaac-Reach-Franka-Play-v0", + "Isaac-Reach-UR10-v0", + "Isaac-Reach-UR10-Play-v0", + # Manipulation Lift + "Isaac-Lift-Cube-Franka-v0", + "Isaac-Lift-Cube-Franka-Play-v0", + # Manipulation Cabinet + "Isaac-Open-Drawer-Franka-v0", + "Isaac-Open-Drawer-Franka-Play-v0", + # Manipulation In-Hand + "Isaac-Repose-Cube-Allegro-v0", + "Isaac-Repose-Cube-Allegro-Play-v0", + "Isaac-Repose-Cube-Allegro-NoVelObs-v0", + "Isaac-Repose-Cube-Allegro-NoVelObs-Play-v0", + # Dexsuite + "Isaac-Dexsuite-Kuka-Allegro-Reorient-v0", + "Isaac-Dexsuite-Kuka-Allegro-Reorient-Play-v0", + "Isaac-Dexsuite-Kuka-Allegro-Lift-v0", + "Isaac-Dexsuite-Kuka-Allegro-Lift-Play-v0", +] + + +def _export_dir(task_name: str) -> str: + """Return the directory where export.py writes artifacts for *task_name*.""" + train_task = task_name.replace("-Play", "") + return os.path.join(_REPO_ROOT, ".pretrained_checkpoints", "rsl_rl", train_task, task_name) + + +@pytest.mark.parametrize("task_name", TASKS) +def test_export_flow(task_name): + """Run export.py for *task_name* and assert the expected artifacts are created.""" + export_dir = _export_dir(task_name) + + try: + result = subprocess.run( + [ + "./isaaclab.sh", + "-p", + _EXPORT_SCRIPT, + "--task", + task_name, + "--use_pretrained_checkpoint", + "--disable_graph_visualization", + "--headless", + ], + cwd=_REPO_ROOT, + capture_output=True, + text=True, + timeout=600, + ) + + # Gracefully skip tasks whose checkpoint isn't published yet + if "pre-trained checkpoint is currently unavailable" in result.stdout: + pytest.skip(f"No pretrained checkpoint available for {task_name.replace('-Play', '')}") + + # Surface stdout/stderr on failure for easier debugging + if result.returncode != 0: + pytest.fail( + f"export.py exited with code {result.returncode}.\n" + f"--- stdout ---\n{result.stdout[-3000:]}\n" + f"--- stderr ---\n{result.stderr[-3000:]}" + ) + + assert os.path.isfile(os.path.join(export_dir, f"{task_name}.onnx")), "Missing .onnx export" + assert os.path.isfile(os.path.join(export_dir, f"{task_name}.yaml")), "Missing .yaml export" + assert os.path.isfile(os.path.join(export_dir, "log.txt")), "Missing log.txt" + + finally: + shutil.rmtree(export_dir, ignore_errors=True) From 93025a707f5a09589afc7862104bcbee99fb7baa Mon Sep 17 00:00:00 2001 From: Frank Lai Date: Tue, 17 Mar 2026 12:03:10 -0700 Subject: [PATCH 23/23] fixed issue with callables in observation terms --- .../rsl_rl/export_annotator.py | 52 ++++++++++++++++++- .../test/test_rsl_rl_export_flow.py | 7 +++ 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/scripts/reinforcement_learning/rsl_rl/export_annotator.py b/scripts/reinforcement_learning/rsl_rl/export_annotator.py index dff93b27c2b4..9f3409a1fd11 100644 --- a/scripts/reinforcement_learning/rsl_rl/export_annotator.py +++ b/scripts/reinforcement_learning/rsl_rl/export_annotator.py @@ -43,6 +43,7 @@ from isaaclab.assets.articulation.articulation import Articulation from isaaclab.assets.articulation.articulation_data import ArticulationData +from isaaclab.managers import ManagerTermBase from isaaclab.sensors.camera.camera_data import CameraData from isaaclab.sensors.contact_sensor.contact_sensor_data import ContactSensorData from isaaclab.sensors.frame_transformer.frame_transformer_data import FrameTransformerData @@ -308,6 +309,46 @@ def __getattr__(self, name): return getattr(object.__getattribute__(self, "_real_env"), name) +class _ManagerTermProxy(ManagerTermBase): + """Proxy a class-based manager term while preserving its lifecycle methods. + + Observation manager terms can be stateful ``ManagerTermBase`` instances that + expose ``reset()`` and ``serialize()`` in addition to being callable. This + proxy preserves that interface while swapping the env argument passed into + ``__call__`` for the observation-side proxy env. + """ + + def __init__(self, target: ManagerTermBase, proxy_env: _EnvProxy): + super().__init__(target.cfg, target._env) + self._target = target + self._proxy_env = proxy_env + + @property + def __name__(self) -> str: + """Expose the wrapped term name for compatibility and debugging.""" + return getattr(self._target, "__name__", self._target.__class__.__name__) + + def reset(self, env_ids=None) -> None: + """Forward resets to the wrapped term instance.""" + self._target.reset(env_ids=env_ids) + + def serialize(self) -> dict: + """Forward serialization to the wrapped term instance.""" + return self._target.serialize() + + def __call__(self, *args, **kwargs): + """Call the wrapped term with the proxy env in place of the real env.""" + if args: + args = (self._proxy_env, *args[1:]) + else: + args = (self._proxy_env,) + return self._target(*args, **kwargs) + + def __getattr__(self, name): + """Forward all other attribute access to the wrapped term instance.""" + return getattr(self._target, name) + + # ══════════════════════════════════════════════════════════════════ # Action-side proxy # ══════════════════════════════════════════════════════════════════ @@ -627,8 +668,15 @@ def patched_apply_action(): def _wrap_with_proxy(original_func, proxy_env): """Wrap a term function so it receives the proxy env instead of the real env.""" - def wrapped(env, **kwargs): - return original_func(proxy_env, **kwargs) + if isinstance(original_func, ManagerTermBase): + return _ManagerTermProxy(original_func, proxy_env) + + def wrapped(*args, **kwargs): + if args: + args = (proxy_env, *args[1:]) + else: + args = (proxy_env,) + return original_func(*args, **kwargs) wrapped.__name__ = getattr(original_func, "__name__", "unknown") return wrapped diff --git a/source/isaaclab_rl/test/test_rsl_rl_export_flow.py b/source/isaaclab_rl/test/test_rsl_rl_export_flow.py index 2f656aee89b2..e7908f6deba4 100644 --- a/source/isaaclab_rl/test/test_rsl_rl_export_flow.py +++ b/source/isaaclab_rl/test/test_rsl_rl_export_flow.py @@ -129,10 +129,17 @@ def test_export_flow(task_name): # Surface stdout/stderr on failure for easier debugging if result.returncode != 0: + log_txt_path = os.path.join(export_dir, "log.txt") + leapp_tail = "" + if os.path.isfile(log_txt_path): + with open(log_txt_path) as f: + last_lines = f.readlines()[-50:] + leapp_tail = f"\n--- leapp log.txt (last 50 lines) ---\n{''.join(last_lines)}" pytest.fail( f"export.py exited with code {result.returncode}.\n" f"--- stdout ---\n{result.stdout[-3000:]}\n" f"--- stderr ---\n{result.stderr[-3000:]}" + f"{leapp_tail}" ) assert os.path.isfile(os.path.join(export_dir, f"{task_name}.onnx")), "Missing .onnx export"