diff --git a/source/isaaclab_tasks/changelog.d/fix-automate-run-w-id-forward-args.rst b/source/isaaclab_tasks/changelog.d/fix-automate-run-w-id-forward-args.rst new file mode 100644 index 00000000000..db3aa00829e --- /dev/null +++ b/source/isaaclab_tasks/changelog.d/fix-automate-run-w-id-forward-args.rst @@ -0,0 +1,4 @@ +Fixed +^^^^^ + +* Fixed the AutoMate ``run_w_id.py`` wrapper to accept ``--viz``/``--visualizer`` and forward additional arguments to the delegated RL-Games train/play script. diff --git a/source/isaaclab_tasks/isaaclab_tasks/direct/automate/run_w_id.py b/source/isaaclab_tasks/isaaclab_tasks/direct/automate/run_w_id.py index c888f2c0e89..76d04145cde 100644 --- a/source/isaaclab_tasks/isaaclab_tasks/direct/automate/run_w_id.py +++ b/source/isaaclab_tasks/isaaclab_tasks/direct/automate/run_w_id.py @@ -39,29 +39,8 @@ def update_task_param(task_cfg, assembly_id, if_sbc, if_log_eval): f.writelines(updated_lines) -def main(): - parser = argparse.ArgumentParser(description="Update assembly_id and run training script.") - parser.add_argument( - "--cfg_path", - type=str, - help="Path to the file containing assembly_id.", - default="source/isaaclab_tasks/isaaclab_tasks/direct/automate/assembly_tasks_cfg.py", - ) - parser.add_argument("--assembly_id", type=str, help="New assembly ID to set.") - parser.add_argument("--checkpoint", type=str, help="Checkpoint path.") - parser.add_argument("--num_envs", type=int, default=128, help="Number of parallel environment.") - parser.add_argument("--seed", type=int, default=-1, help="Random seed.") - parser.add_argument("--train", action="store_true", help="Run training mode.") - parser.add_argument("--log_eval", action="store_true", help="Log evaluation results.") - parser.add_argument("--max_iterations", type=int, default=1500, help="Number of iteration for policy learning.") - args = parser.parse_args() - - if args.assembly_id == "ASSEMBLY_ID": - parser.error("replace ASSEMBLY_ID with an AutoMate assembly ID, for example 00032") - - update_task_param(args.cfg_path, args.assembly_id, args.train, args.log_eval) - - # build the command +def build_command(args, downstream_args): + """Build the delegated Isaac Lab command for the selected AutoMate assembly task.""" if sys.platform.startswith("win"): command = ["isaaclab.bat"] else: @@ -88,6 +67,49 @@ def main(): if args.checkpoint: command.append(f"--checkpoint={args.checkpoint}") + if args.visualizer: + command.extend(["--visualizer", args.visualizer]) + + command.extend(downstream_args) + + return command + + +def main(): + parser = argparse.ArgumentParser( + description="Update assembly_id and run training script.", + epilog="Additional arguments are forwarded to the delegated RL-Games train/play script.", + allow_abbrev=False, + ) + parser.add_argument( + "--cfg_path", + type=str, + help="Path to the file containing assembly_id.", + default="source/isaaclab_tasks/isaaclab_tasks/direct/automate/assembly_tasks_cfg.py", + ) + parser.add_argument("--assembly_id", type=str, help="New assembly ID to set.") + parser.add_argument("--checkpoint", type=str, help="Checkpoint path.") + parser.add_argument("--num_envs", type=int, default=128, help="Number of parallel environment.") + parser.add_argument("--seed", type=int, default=-1, help="Random seed.") + parser.add_argument("--train", action="store_true", help="Run training mode.") + parser.add_argument("--log_eval", action="store_true", help="Log evaluation results.") + parser.add_argument("--max_iterations", type=int, default=1500, help="Number of iteration for policy learning.") + parser.add_argument( + "--visualizer", + "--viz", + type=str, + default=None, + help="Visualizer backend(s) to forward to the delegated RL-Games script, for example 'kit'.", + ) + args, downstream_args = parser.parse_known_args() + + if args.assembly_id == "ASSEMBLY_ID": + parser.error("replace ASSEMBLY_ID with an AutoMate assembly ID, for example 00032") + + update_task_param(args.cfg_path, args.assembly_id, args.train, args.log_eval) + + command = build_command(args, downstream_args) + # Run the command subprocess.run(command, check=True)