From d838f56366db4c77d8c17627ff14f269c19d45f8 Mon Sep 17 00:00:00 2001 From: Rafael Wiltz Date: Wed, 10 Jun 2026 16:36:23 -0400 Subject: [PATCH 1/5] feat: add generic joint-space teleop device (SO-101 leader) Add a reusable joint-space device path to Isaac Teleop: a name-keyed JointStateOutput FlatBuffer schema, a JointStateTracker with live and MCAP-replay backends (registered in the live/replay factories), a JointStateSource, and a dual-mode JointStateRetargeter -- joint mirror (name remap + per-joint affine) or URDF forward-kinematics EE pose. The SO-101 leader arm is the reference instance, with a so101_leader plugin that pushes JointStateOutput (synthetic backend; the real Feetech serial read is left as a marked seam). Includes the Python schema/tracker bindings, the joint_space_device example (live over OpenXR or offline), sim-free unit tests for both retargeter modes, and device + retargeting reference docs. --- CMakeLists.txt | 1 + docs/source/device/index.rst | 1 + docs/source/device/joint_space.rst | 123 +++++ docs/source/references/retargeting/index.rst | 11 + .../references/retargeting/joint_space.rst | 128 +++++ .../python/joint_space_device_example.py | 237 +++++++++ .../joint_state_tracker_base.hpp | 23 + src/core/deviceio_trackers/cpp/CMakeLists.txt | 2 + .../deviceio_trackers/joint_state_tracker.hpp | 84 ++++ .../cpp/joint_state_tracker.cpp | 23 + .../python/deviceio_trackers_init.py | 2 + .../python/tracker_bindings.cpp | 12 + src/core/live_trackers/cpp/CMakeLists.txt | 2 + .../live_trackers/live_deviceio_factory.hpp | 3 + .../cpp/live_deviceio_factory.cpp | 19 + .../cpp/live_joint_state_tracker_impl.cpp | 64 +++ .../cpp/live_joint_state_tracker_impl.hpp | 52 ++ .../mcap/cpp/inc/mcap/recording_traits.hpp | 7 + src/core/python/deviceio_init.py | 2 + src/core/python/pyproject.toml.in | 1 + src/core/python/requirements-retargeters.txt | 14 + src/core/replay_trackers/cpp/CMakeLists.txt | 2 + .../replay_deviceio_factory.hpp | 3 + .../cpp/replay_deviceio_factory.cpp | 22 +- .../cpp/replay_joint_state_tracker_impl.cpp | 57 +++ .../cpp/replay_joint_state_tracker_impl.hpp | 39 ++ .../python/deviceio_source_nodes/__init__.py | 6 + .../deviceio_tensor_types.py | 33 ++ .../joint_state_source.py | 112 +++++ .../python/test_joint_state_retargeter.py | 461 ++++++++++++++++++ .../python/test_joint_state_source.py | 129 +++++ src/core/schema/fbs/joint_state.fbs | 73 +++ src/core/schema/python/CMakeLists.txt | 1 + src/core/schema/python/joint_state_bindings.h | 116 +++++ src/core/schema/python/schema_init.py | 10 + src/core/schema/python/schema_module.cpp | 4 + src/plugins/so101_leader/CMakeLists.txt | 16 + src/plugins/so101_leader/README.md | 31 ++ src/plugins/so101_leader/main.cpp | 49 ++ src/plugins/so101_leader/plugin.yaml | 11 + .../so101_leader/so101_leader_plugin.cpp | 116 +++++ .../so101_leader/so101_leader_plugin.hpp | 61 +++ src/retargeters/CMakeLists.txt | 2 + src/retargeters/__init__.py | 15 + src/retargeters/joint_space/__init__.py | 2 + .../joint_space/joint_state_retargeter.py | 327 +++++++++++++ 46 files changed, 2507 insertions(+), 2 deletions(-) create mode 100644 docs/source/device/joint_space.rst create mode 100644 docs/source/references/retargeting/joint_space.rst create mode 100644 examples/teleop/python/joint_space_device_example.py create mode 100644 src/core/deviceio_base/cpp/inc/deviceio_base/joint_state_tracker_base.hpp create mode 100644 src/core/deviceio_trackers/cpp/inc/deviceio_trackers/joint_state_tracker.hpp create mode 100644 src/core/deviceio_trackers/cpp/joint_state_tracker.cpp create mode 100644 src/core/live_trackers/cpp/live_joint_state_tracker_impl.cpp create mode 100644 src/core/live_trackers/cpp/live_joint_state_tracker_impl.hpp create mode 100644 src/core/replay_trackers/cpp/replay_joint_state_tracker_impl.cpp create mode 100644 src/core/replay_trackers/cpp/replay_joint_state_tracker_impl.hpp create mode 100644 src/core/retargeting_engine/python/deviceio_source_nodes/joint_state_source.py create mode 100644 src/core/retargeting_engine_tests/python/test_joint_state_retargeter.py create mode 100644 src/core/retargeting_engine_tests/python/test_joint_state_source.py create mode 100644 src/core/schema/fbs/joint_state.fbs create mode 100644 src/core/schema/python/joint_state_bindings.h create mode 100644 src/plugins/so101_leader/CMakeLists.txt create mode 100644 src/plugins/so101_leader/README.md create mode 100644 src/plugins/so101_leader/main.cpp create mode 100644 src/plugins/so101_leader/plugin.yaml create mode 100644 src/plugins/so101_leader/so101_leader_plugin.cpp create mode 100644 src/plugins/so101_leader/so101_leader_plugin.hpp create mode 100644 src/retargeters/joint_space/__init__.py create mode 100644 src/retargeters/joint_space/joint_state_retargeter.py diff --git a/CMakeLists.txt b/CMakeLists.txt index e2a7b7b0a..74b10ada9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -165,6 +165,7 @@ if(BUILD_PLUGINS) add_subdirectory(src/plugins/controller_synthetic_hands) add_subdirectory(src/plugins/generic_3axis_pedal) + add_subdirectory(src/plugins/so101_leader) add_subdirectory(src/plugins/manus) add_subdirectory(src/plugins/haptikos) if(BUILD_PLUGIN_OAK_CAMERA) diff --git a/docs/source/device/index.rst b/docs/source/device/index.rst index 8f4579af5..8a9bb6db8 100644 --- a/docs/source/device/index.rst +++ b/docs/source/device/index.rst @@ -19,6 +19,7 @@ See the `Plugins directory `_, +6 Feetech STS3215 bus servos) is the reference instance. + +At a glance +----------- + +.. list-table:: + :header-rows: 1 + :widths: 18 82 + + * - Layer + - Component + * - Schema + - :code-file:`src/core/schema/fbs/joint_state.fbs` -- ``JointState`` (name + position + + optional velocity/effort) and ``JointStateOutput`` (a vector of joints + ``device_id``). + * - Plugin + - :code-dir:`src/plugins/so101_leader` -- pushes ``JointStateOutput`` via ``SchemaPusher``. + Ships a synthetic backend; the real Feetech/serial read is a marked seam. + * - Tracker + - ``JointStateTracker`` (facade) with live (``LiveJointStateTrackerImpl``) and MCAP-replay + (``ReplayJointStateTrackerImpl``) backends, registered in the live/replay factories. + * - Source + - ``JointStateSource`` (``IDeviceIOSource``) -- converts the FlatBuffer into a name-keyed + group of joint positions for the retargeting graph. + * - Retargeter + - ``JointStateRetargeter`` -- ``joint`` (mirror) or ``ee_pose`` (URDF FK) mode. See + :doc:`/references/retargeting/joint_space`. + +Data schema +----------- + +Joints are modeled as **name -> value** records so consumers read them by name, independent of +wire order: + +.. code-block:: idl + :class: code-100col + + table JointState { + name: string (id: 0, key); // e.g. "shoulder_pan", "gripper" + position: float (id: 1); // [rad] revolute, [m] prismatic + velocity: float (id: 2); // optional (JointStateOutput.has_velocity) + effort: float (id: 3); // optional (JointStateOutput.has_effort) + valid: bool = true (id: 4); + } + + table JointStateOutput { + joints: [JointState] (id: 0); + device_id: string (id: 1); + has_velocity: bool (id: 2); + has_effort: bool (id: 3); + ee_pose: Pose (id: 4); // RESERVED: device-side FK; not consumed yet + ee_pose_valid: bool (id: 5); + } + +The gripper is just another named DOF (conventionally ``"gripper"``). ``velocity``, ``effort``, +and ``ee_pose`` are optional/reserved: the reference plugin and ``JointStateSource`` populate and +surface joint **positions** only. + +The SO-101 leader plugin +------------------------ + +``so101_leader`` reads the six SO-101 servos (``shoulder_pan, shoulder_lift, elbow_flex, +wrist_flex, wrist_roll, gripper``) and pushes them to a tensor collection. To keep the example +hardware-free and headless it ships a **synthetic backend**; the real Feetech read (via LeRobot's +``FeetechMotorsBus`` + calibration) is the marked seam in ``So101LeaderPlugin::read_hardware()``. + +.. code-block:: bash + + # Synthetic backend (no hardware), default collection id "so101_leader": + ./install/plugins/so101_leader/so101_leader_plugin + + # Reserved for the real serial backend + a custom collection id: + ./install/plugins/so101_leader/so101_leader_plugin /dev/ttyACM0 so101_leader + +The consumer side creates a ``JointStateSource(name=..., collection_id="so101_leader", +joint_names=[...])`` on the same ``collection_id``; ``TeleopSession`` discovers and polls the +``JointStateTracker`` each frame. + +Record and replay +----------------- + +The live tracker records to MCAP, and ``ReplayJointStateTrackerImpl`` replays it back with no +OpenXR runtime, so a recorded session drives the retargeting graph headlessly: + +.. code-block:: python + + from isaacteleop.deviceio import McapRecordingConfig, McapReplayConfig + from isaacteleop.teleop_session_manager import SessionMode, TeleopSession, TeleopSessionConfig + + # Record (live): TeleopSessionConfig(..., mcap_config=McapRecordingConfig("leader.mcap")) + # Replay (headless): TeleopSessionConfig(..., mode=SessionMode.REPLAY, + # mcap_config=McapReplayConfig("leader.mcap")) + +Add another joint-space device +------------------------------ + +Reuse everything above by writing only: + +#. A **plugin** that reads your hardware and fills ``JointStateOutput`` (positions; optionally + velocity/effort), modeled on :code-dir:`src/plugins/so101_leader`. +#. A **config**: a ``collection_id``, the device joint names, and -- for ``ee_pose`` mode -- a URDF + and end-effector link. + +The schema, ``JointStateTracker``, ``JointStateSource``, and ``JointStateRetargeter`` are unchanged. + +.. seealso:: + + :doc:`add_device` -- the general four-step device-plugin recipe (foot-pedal reference). + + :doc:`/references/retargeting/joint_space` -- the ``JointStateRetargeter`` (joint / EE modes), + the end-to-end example, and validation. diff --git a/docs/source/references/retargeting/index.rst b/docs/source/references/retargeting/index.rst index 81da6c89b..6cf5043ca 100644 --- a/docs/source/references/retargeting/index.rst +++ b/docs/source/references/retargeting/index.rst @@ -11,6 +11,8 @@ Source Nodes * ``HandsSource`` -- provides hand tracking data (left/right, 26 joints each). * ``ControllersSource`` -- provides motion controller data (grip pose, trigger, thumbstick, etc.). * ``Generic3AxisPedalSource`` -- provides 3-axis foot pedal data (left/right pedals, rudder). +* ``JointStateSource`` -- provides name-keyed joint positions from a generic joint-space device + (leader arm, exoskeleton, ...). See :doc:`joint_space`. * ``FullBodySource`` -- provides full-body pose (e.g. Pico tracking). Available Retargeters @@ -41,6 +43,14 @@ Available Retargeters ``hand_side`` (``"left"`` or ``"right"``), ``gripper_close_meters``, ``gripper_open_meters``, and ``controller_threshold`` for trigger-based closing. +.. dropdown:: JointStateRetargeter + + Maps a name-keyed joint-state input (from ``JointStateSource``) to an action for a generic + joint-space device -- leader arm, exoskeleton, etc. Two modes via ``JointStateRetargeterConfig``: + ``"joint"`` (lossless leader -> follower mirror with optional per-joint affine; no extra deps) + and ``"ee_pose"`` (URDF forward kinematics -> 7D EE pose + gripper, requires ``pinocchio``). + See :doc:`joint_space` for the full setup, modes, and the SO-101 example. + .. dropdown:: DexHandRetargeter / DexBiManualRetargeter Accurate hand tracking retargeter using the ``dex-retargeting`` library. It maps full hand @@ -315,3 +325,4 @@ and :doc:`Contributing Guide <../../getting_started/contributing>` for details. :caption: Retargeter setup guides sharpa + joint_space diff --git a/docs/source/references/retargeting/joint_space.rst b/docs/source/references/retargeting/joint_space.rst new file mode 100644 index 000000000..85da0412f --- /dev/null +++ b/docs/source/references/retargeting/joint_space.rst @@ -0,0 +1,128 @@ +.. SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +.. SPDX-License-Identifier: Apache-2.0 + +Retargeter: Joint-Space Device +============================== + +``JointStateRetargeter`` maps a name-keyed joint-state input (from +:doc:`/device/joint_space`'s ``JointStateSource``) onto an Isaac Lab action, in one of two modes. +It is the generic retargeter for leader arms, exoskeletons, and other joint-encoder devices; the +SO-101 leader arm is the reference instance. + +At a glance +----------- + +.. list-table:: + :header-rows: 1 + :widths: 16 30 54 + + * - Mode + - Output + - Use + * - ``joint`` + - one float per target joint (``joint_targets``) + - Lossless leader -> follower mirror for same-kinematics teleoperation. Name remap + optional + per-joint affine. No extra dependencies. + * - ``ee_pose`` + - 7-D ``ee_pose`` ``[x,y,z,qx,qy,qz,qw]`` + ``gripper_command`` + - Task-space / cross-embodiment teleoperation via URDF forward kinematics. Requires + ``pinocchio`` (the ``[retargeters]`` extra). + +``joint`` mode +-------------- + +Each target joint is filled from a device joint (by name) with an optional affine +``offset + sign * scale * value``. Defaults are an identity mirror. ``JointStateRetargeterConfig``: + +* ``device_joints`` -- ordered device DOF names (must match the source's ``joint_names`` order). +* ``target_joints`` -- robot joint names to emit (defaults to ``device_joints``). +* ``joint_map`` -- ``{device_name: target_name}`` overrides; ``scale`` / ``offset`` / ``sign`` -- + per-target affine. + +``ee_pose`` mode +---------------- + +Forward-kinematics the device joints through a URDF and emit the end-effector pose plus a gripper +command. Config: ``urdf_path``, ``ee_link``, ``gripper_joint`` (and optional ``gripper_open`` / +``gripper_close`` to emit normalized closedness in ``[0, 1]`` instead of the raw value). + +* FK uses ``pinocchio`` (imported lazily; ``joint`` mode never needs it). Install via + ``pip install 'isaacteleop[retargeters]'``. +* Assumes a fixed-base model of single-DOF joints (the common leader-arm / exoskeleton case). +* The schema's device ``ee_pose`` field is **not** consumed yet -- FK is always computed from the + joint positions. +* ``clutch=True`` rebases the EE around an origin captured on the first ``RUNNING`` frame so + engaging teleop does not jump the robot; when the optional ``robot_ee_pos`` input (the live + ``world_T_ee``) is connected, the latched home is the robot's current end-effector. + +.. note:: + + The ``joints`` input is read positionally in ``device_joints`` order, so the upstream source's + ``joint_names`` must list the same names in the same order. A name mismatch is rejected by the + graph's type check at ``connect`` time. + +Use it from Python +------------------ + +A pipeline builder returns an ``OutputCombiner`` with a single ``"action"`` key (the layout your +environment's action space expects): + +.. code-block:: python + + from isaacteleop.retargeting_engine.deviceio_source_nodes import JointStateSource + from isaacteleop.retargeting_engine.interface import OutputCombiner + from isaacteleop.retargeters import ( + JointStateRetargeter, + JointStateRetargeterConfig, + TensorReorderer, + ) + + SO101_JOINTS = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"] + + def build_so101_joint_pipeline(): + source = JointStateSource(name="leader", collection_id="so101_leader", joint_names=SO101_JOINTS) + retargeter = JointStateRetargeter( + name="leader", + mode="joint", + config=JointStateRetargeterConfig(device_joints=SO101_JOINTS, target_joints=SO101_JOINTS), + ) + head = retargeter.connect({JointStateRetargeter.JOINTS: source.output(JointStateSource.JOINTS)}) + reorderer = TensorReorderer( + input_config={"joint_targets": SO101_JOINTS}, + output_order=SO101_JOINTS, + name="action_reorderer", + input_types={"joint_targets": "scalar"}, + ) + connected = reorderer.connect({"joint_targets": head.output("joint_targets")}) + return OutputCombiner({"action": connected.output("output")}) + +For ``ee_pose`` mode, build the retargeter with ``mode="ee_pose"`` + a ``urdf_path`` / ``ee_link`` +and flatten ``ee_pose`` + ``gripper_command`` into the env's task-space action layout. + +Run the example +--------------- + +The repo ships ``examples/teleop/python/joint_space_device_example.py``: + +.. code-block:: console + + # Consumes the so101_leader plugin over OpenXR (source cloudxr.env first): + $ python joint_space_device_example.py --launch-plugin --mode joint --frames 8 + $ python joint_space_device_example.py --launch-plugin --mode ee --urdf so101_new_calib.urdf + +Validate +-------- + +Sim-free unit tests cover both modes (joint affine/remap/hold/reset, EE forward kinematics, clutch +rebasing, and the flattened action width/order): + +.. code-block:: console + + $ ctest --test-dir build -R 'retargeting_test_joint_state' --output-on-failure + +.. seealso:: + + :doc:`/device/joint_space` -- the schema, ``JointStateTracker``, ``JointStateSource``, the + SO-101 plugin, and MCAP record/replay. + + :doc:`index` -- the broader retargeting interface and pipeline-builder pattern. diff --git a/examples/teleop/python/joint_space_device_example.py b/examples/teleop/python/joint_space_device_example.py new file mode 100644 index 000000000..0e0febc93 --- /dev/null +++ b/examples/teleop/python/joint_space_device_example.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Joint-space device example: SO-101 leader arm -> Isaac Lab action. + +This CLI is both an **API reference** and an **end-to-end test** for the generic joint-space +device path on the Isaac Teleop side: + + JointStateOutput schema -> JointStateSource -> JointStateRetargeter -> TensorReorderer + | + OutputCombiner("action") + +It consumes joint state streamed by the real ``so101_leader`` plugin over the OpenXR tensor +transport via a ``TeleopSession`` (the ``JointStateSource`` auto-discovers and polls the +``JointStateTracker``). Like the other CloudXR examples/tests, it **expects the CloudXR runtime +to already be running** -- ``source ~/.cloudxr/run/cloudxr.env`` first -- and does not probe for +it. Use ``--launch-plugin`` to spawn the synthetic plugin process automatically; otherwise start +``so101_leader`` (or any device pushing the same ``collection_id``) separately. + +Two modes: + +* ``--mode joint`` -> 6-D joint mirror ``[shoulder_pan, ..., gripper]`` (no extra deps). +* ``--mode ee`` -> 8-D ``[pos_xyz, quat_xyzw, gripper]`` via URDF forward kinematics + (needs ``pinocchio`` and ``--urdf`` pointing at ``so101_new_calib.urdf``). + +Examples:: + + source ~/.cloudxr/run/cloudxr.env + python joint_space_device_example.py --launch-plugin --mode joint --frames 8 + python joint_space_device_example.py --launch-plugin --mode ee --urdf /path/to/so101_new_calib.urdf +""" + +from __future__ import annotations + +import argparse +import subprocess +import time +from pathlib import Path + +import numpy as np + +from isaacteleop.retargeting_engine.deviceio_source_nodes import JointStateSource +from isaacteleop.retargeting_engine.interface import OutputCombiner +from isaacteleop.retargeters import ( + JointStateRetargeter, + JointStateRetargeterConfig, + TensorReorderer, +) + +# Canonical SO-101 DOF names (match Simulation/SO101/so101_new_calib.urdf and the schema). +SO101_JOINTS = [ + "shoulder_pan", + "shoulder_lift", + "elbow_flex", + "wrist_flex", + "wrist_roll", + "gripper", +] + +_COLLECTION_ID = "so101_leader" +_POSE_LABELS = ["pos_x", "pos_y", "pos_z", "quat_x", "quat_y", "quat_z", "quat_w"] +_DEFAULT_PLUGIN_BIN = ( + Path(__file__).resolve().parents[3] + / "build/src/plugins/so101_leader/so101_leader_plugin" +) + + +def build_pipeline( + mode: str, source: JointStateSource, urdf_path: str | None, ee_link: str +): + """Build the action pipeline from the live ``JointStateSource`` leaf. + + Returns ``(combiner, action_labels)``.""" + if mode == "joint": + retargeter = JointStateRetargeter( + name="leader", + mode="joint", + config=JointStateRetargeterConfig( + device_joints=SO101_JOINTS, target_joints=SO101_JOINTS + ), + ) + head = retargeter.connect( + {JointStateRetargeter.JOINTS: source.output(JointStateSource.JOINTS)} + ) + reorderer = TensorReorderer( + input_config={"joint_targets": SO101_JOINTS}, + output_order=SO101_JOINTS, + name="action_reorderer", + input_types={"joint_targets": "scalar"}, + ) + connected = reorderer.connect({"joint_targets": head.output("joint_targets")}) + combiner = OutputCombiner({"action": connected.output("output")}) + return combiner, list(SO101_JOINTS) + + if not urdf_path: + raise SystemExit("--mode ee requires --urdf ") + retargeter = JointStateRetargeter( + name="leader", + mode="ee_pose", + config=JointStateRetargeterConfig( + device_joints=SO101_JOINTS, + urdf_path=urdf_path, + ee_link=ee_link, + gripper_joint="gripper", + ), + ) + head = retargeter.connect( + {JointStateRetargeter.JOINTS: source.output(JointStateSource.JOINTS)} + ) + action_labels = _POSE_LABELS + ["gripper_value"] + reorderer = TensorReorderer( + input_config={"ee_pose": _POSE_LABELS, "gripper_command": ["gripper_value"]}, + output_order=action_labels, + name="action_reorderer", + input_types={"ee_pose": "array", "gripper_command": "scalar"}, + ) + connected = reorderer.connect( + { + "ee_pose": head.output("ee_pose"), + "gripper_command": head.output("gripper_command"), + } + ) + combiner = OutputCombiner({"action": connected.output("output")}) + return combiner, action_labels + + +def run_live( + mode: str, num_frames: int, urdf_path: str | None, ee_link: str, timeout_s: float +) -> None: + """Consume the live so101_leader plugin stream through a TeleopSession over OpenXR.""" + from isaacteleop.teleop_session_manager import TeleopSession, TeleopSessionConfig + + source = JointStateSource( + name="leader", collection_id=_COLLECTION_ID, joint_names=SO101_JOINTS + ) + combiner, labels = build_pipeline(mode, source, urdf_path, ee_link) + + print( + f"mode={mode} action_dim={len(labels)} layout={labels} collection={_COLLECTION_ID!r}" + ) + print("-" * 80) + + session_config = TeleopSessionConfig( + app_name="JointSpaceDeviceLiveExample", + trackers=[], + pipeline=combiner, + plugins=[], + ) + actions: list[np.ndarray] = [] + with TeleopSession(session_config) as session: + deadline = time.time() + timeout_s + frame = 0 + while len(actions) < num_frames and time.time() < deadline: + result = session.step() + action = result.get("action") + if action is not None: + arr = np.asarray(action[0], dtype=np.float64) + actions.append(arr) + print( + f"step {frame:02d} | action = [ {' '.join(f'{v:+.3f}' for v in arr)} ]" + ) + frame += 1 + time.sleep(0.05) + + print("-" * 80) + if len(actions) < num_frames: + raise SystemExit( + f"FAILED: only {len(actions)}/{num_frames} action(s) received from the live plugin " + f"within {timeout_s:.0f}s (is the so101_leader plugin pushing? is cloudxr.env sourced?)" + ) + # A single received frame can't be "stale" -- only flag multi-frame runs that never change. + varied = len(actions) <= 1 or any( + not np.allclose(actions[i], actions[0], atol=1e-4) + for i in range(1, len(actions)) + ) + print( + f"OK: received {len(actions)} live action(s) of width {len(labels)}; varying over time: {varied}" + ) + if not varied: + raise SystemExit( + "FAILED: live actions did not vary -- stream may be stale (held-last only)" + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument("--mode", choices=["joint", "ee"], default="joint") + parser.add_argument("--frames", type=int, default=5) + parser.add_argument( + "--urdf", default=None, help="Path to so101_new_calib.urdf (ee mode)" + ) + parser.add_argument( + "--ee-link", + default="gripper_frame_link", + help="URDF end-effector frame (ee mode)", + ) + parser.add_argument( + "--launch-plugin", + action="store_true", + help="Spawn the synthetic so101_leader plugin process automatically", + ) + parser.add_argument( + "--plugin-bin", + default=str(_DEFAULT_PLUGIN_BIN), + help="Path to so101_leader_plugin", + ) + parser.add_argument( + "--timeout", + type=float, + default=20.0, + help="Seconds to wait for plugin frames", + ) + args = parser.parse_args() + + plugin_proc = None + if args.launch_plugin: + if not Path(args.plugin_bin).exists(): + raise SystemExit( + f"plugin binary not found: {args.plugin_bin} (build it first)" + ) + print(f"launching plugin: {args.plugin_bin}") + # Empty device_path -> synthetic backend; collection id must match the source. + plugin_proc = subprocess.Popen([args.plugin_bin, "", _COLLECTION_ID]) + time.sleep(1.5) # let it create its OpenXR session and start pushing + try: + run_live(args.mode, args.frames, args.urdf, args.ee_link, args.timeout) + finally: + if plugin_proc is not None: + plugin_proc.terminate() + try: + plugin_proc.wait(timeout=5) + except subprocess.TimeoutExpired: + plugin_proc.kill() + + +if __name__ == "__main__": + main() diff --git a/src/core/deviceio_base/cpp/inc/deviceio_base/joint_state_tracker_base.hpp b/src/core/deviceio_base/cpp/inc/deviceio_base/joint_state_tracker_base.hpp new file mode 100644 index 000000000..b9567e09f --- /dev/null +++ b/src/core/deviceio_base/cpp/inc/deviceio_base/joint_state_tracker_base.hpp @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "tracker.hpp" + +namespace core +{ + +struct JointStateOutputTrackedT; + +// Abstract base interface for JointStateTracker implementations. +// +// Backs a generic joint-space device (leader arm, exoskeleton, glove, ...): the implementation +// keeps the last-known JointStateOutput snapshot, which the JointStateTracker facade exposes. +class IJointStateTrackerImpl : public ITrackerImpl +{ +public: + virtual const JointStateOutputTrackedT& get_data() const = 0; +}; + +} // namespace core diff --git a/src/core/deviceio_trackers/cpp/CMakeLists.txt b/src/core/deviceio_trackers/cpp/CMakeLists.txt index 48b460d71..113bda804 100644 --- a/src/core/deviceio_trackers/cpp/CMakeLists.txt +++ b/src/core/deviceio_trackers/cpp/CMakeLists.txt @@ -10,6 +10,7 @@ add_library(deviceio_trackers STATIC controller_tracker.cpp message_channel_tracker.cpp generic_3axis_pedal_tracker.cpp + joint_state_tracker.cpp frame_metadata_tracker_oak.cpp full_body_tracker_pico.cpp inc/deviceio_trackers/head_tracker.hpp @@ -18,6 +19,7 @@ add_library(deviceio_trackers STATIC inc/deviceio_trackers/message_channel_tracker.hpp inc/deviceio_trackers/full_body_tracker_pico.hpp inc/deviceio_trackers/generic_3axis_pedal_tracker.hpp + inc/deviceio_trackers/joint_state_tracker.hpp inc/deviceio_trackers/frame_metadata_tracker_oak.hpp ) diff --git a/src/core/deviceio_trackers/cpp/inc/deviceio_trackers/joint_state_tracker.hpp b/src/core/deviceio_trackers/cpp/inc/deviceio_trackers/joint_state_tracker.hpp new file mode 100644 index 000000000..97a53b406 --- /dev/null +++ b/src/core/deviceio_trackers/cpp/inc/deviceio_trackers/joint_state_tracker.hpp @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include +#include + +namespace core +{ + +/*! + * @brief Facade for a generic joint-space device exposed as ``JointStateOutputTrackedT``. + * + * Generic across joint-space input devices (leader arms, exoskeletons, gloves, ...): the payload + * is a list of named joints (``JointStateOutput.joints``, keyed by ``JointState.name``) plus an + * optional end-effector pose. The semantics of each joint (units, calibration) are defined by the + * data producer (the device plugin). A distinct ``collection_id`` per device allows several + * joint-space devices to stream simultaneously. + * + * After each ``ITrackerSession::update()`` that includes this tracker, ``get_data(session)`` + * reflects the implementation's tracked snapshot. As with other ``SchemaTracker``-backed trackers, + * the live backend may retain the last-known sample when a tick has no new samples while the + * collection remains available (``data`` stays non-null but may be stale); ``data`` is null only + * when no sample has arrived yet or the collection is unavailable. + * + * Usage: + * @code + * auto tracker = std::make_shared("so101_leader"); + * // ... register the tracker with a session, then each tick: ... + * session->update(); + * const auto& data = tracker->get_data(*session); + * @endcode + */ +class JointStateTracker : public ITracker +{ +public: + //! Default maximum FlatBuffer size for JointStateOutput messages. + //! Large enough for a few dozen named joints with optional velocity/effort. Pusher and tracker + //! must agree on this value (it sizes the fixed tensor buffer). + static constexpr size_t DEFAULT_MAX_FLATBUFFER_SIZE = 4096; + + /*! + * @brief Constructs a JointStateTracker. + * @param collection_id Logical stream identifier; must match the device plugin / pusher. + * @param max_flatbuffer_size Upper bound for serialized ``JointStateOutput`` / record payloads. + */ + explicit JointStateTracker(const std::string& collection_id, + size_t max_flatbuffer_size = DEFAULT_MAX_FLATBUFFER_SIZE); + + std::string_view get_name() const override + { + return TRACKER_NAME; + } + + /*! + * @brief Joint-state snapshot from the session's implementation. + * + * ``tracked.data`` is null when no valid sample exists. When non-null, the nested + * ``JointStateOutputT`` (joints, device_id, optional ee_pose) is safe to read. + */ + const JointStateOutputTrackedT& get_data(const ITrackerSession& session) const; + + const std::string& collection_id() const + { + return collection_id_; + } + + size_t max_flatbuffer_size() const + { + return max_flatbuffer_size_; + } + +private: + static constexpr const char* TRACKER_NAME = "JointStateTracker"; + + std::string collection_id_; + size_t max_flatbuffer_size_; +}; + +} // namespace core diff --git a/src/core/deviceio_trackers/cpp/joint_state_tracker.cpp b/src/core/deviceio_trackers/cpp/joint_state_tracker.cpp new file mode 100644 index 000000000..5766ffef5 --- /dev/null +++ b/src/core/deviceio_trackers/cpp/joint_state_tracker.cpp @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "inc/deviceio_trackers/joint_state_tracker.hpp" + +namespace core +{ + +// ============================================================================ +// JointStateTracker +// ============================================================================ + +JointStateTracker::JointStateTracker(const std::string& collection_id, size_t max_flatbuffer_size) + : collection_id_(collection_id), max_flatbuffer_size_(max_flatbuffer_size) +{ +} + +const JointStateOutputTrackedT& JointStateTracker::get_data(const ITrackerSession& session) const +{ + return static_cast(session.get_tracker_impl(*this)).get_data(); +} + +} // namespace core diff --git a/src/core/deviceio_trackers/python/deviceio_trackers_init.py b/src/core/deviceio_trackers/python/deviceio_trackers_init.py index f867e8f54..323a499a3 100644 --- a/src/core/deviceio_trackers/python/deviceio_trackers_init.py +++ b/src/core/deviceio_trackers/python/deviceio_trackers_init.py @@ -12,6 +12,7 @@ MessageChannelTracker, FrameMetadataTrackerOak, Generic3AxisPedalTracker, + JointStateTracker, FullBodyTrackerPico, ITrackerSession, NUM_JOINTS, @@ -28,6 +29,7 @@ "FrameMetadataTrackerOak", "FullBodyTrackerPico", "Generic3AxisPedalTracker", + "JointStateTracker", "HandTracker", "HeadTracker", "ITracker", diff --git a/src/core/deviceio_trackers/python/tracker_bindings.cpp b/src/core/deviceio_trackers/python/tracker_bindings.cpp index 601c7db06..189cefc78 100644 --- a/src/core/deviceio_trackers/python/tracker_bindings.cpp +++ b/src/core/deviceio_trackers/python/tracker_bindings.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -150,6 +151,17 @@ PYBIND11_MODULE(_deviceio_trackers, m) { return self.get_data(session); }, py::arg("session"), "Get the current foot pedal tracked state (data is None when no data available)"); + py::class_>(m, "JointStateTracker") + .def(py::init(), py::arg("collection_id"), + py::arg("max_flatbuffer_size") = core::JointStateTracker::DEFAULT_MAX_FLATBUFFER_SIZE, + "Construct a JointStateTracker for the given tensor collection ID (one generic " + "joint-space device: leader arm, exoskeleton, ...)") + .def( + "get_data", + [](const core::JointStateTracker& self, const core::ITrackerSession& session) -> core::JointStateOutputTrackedT + { return self.get_data(session); }, + py::arg("session"), "Get the current joint-state tracked snapshot (data is None when no data available)"); + py::class_>( m, "FullBodyTrackerPico") .def(py::init<>()) diff --git a/src/core/live_trackers/cpp/CMakeLists.txt b/src/core/live_trackers/cpp/CMakeLists.txt index 23d105b7d..9f93e58df 100644 --- a/src/core/live_trackers/cpp/CMakeLists.txt +++ b/src/core/live_trackers/cpp/CMakeLists.txt @@ -12,6 +12,7 @@ add_library(live_trackers STATIC live_message_channel_tracker_impl.cpp live_full_body_tracker_pico_impl.cpp live_generic_3axis_pedal_tracker_impl.cpp + live_joint_state_tracker_impl.cpp live_frame_metadata_tracker_oak_impl.cpp inc/live_trackers/schema_tracker_base.hpp inc/live_trackers/schema_tracker.hpp @@ -22,6 +23,7 @@ add_library(live_trackers STATIC live_message_channel_tracker_impl.hpp live_full_body_tracker_pico_impl.hpp live_generic_3axis_pedal_tracker_impl.hpp + live_joint_state_tracker_impl.hpp live_frame_metadata_tracker_oak_impl.hpp ) diff --git a/src/core/live_trackers/cpp/inc/live_trackers/live_deviceio_factory.hpp b/src/core/live_trackers/cpp/inc/live_trackers/live_deviceio_factory.hpp index 7d6b5c4f9..24a78a475 100644 --- a/src/core/live_trackers/cpp/inc/live_trackers/live_deviceio_factory.hpp +++ b/src/core/live_trackers/cpp/inc/live_trackers/live_deviceio_factory.hpp @@ -30,6 +30,8 @@ class FullBodyTrackerPico; class IFullBodyTrackerPicoImpl; class Generic3AxisPedalTracker; class IGeneric3AxisPedalTrackerImpl; +class JointStateTracker; +class IJointStateTrackerImpl; class HandTracker; class IHandTrackerImpl; class HeadTracker; @@ -62,6 +64,7 @@ class LiveDeviceIOFactory std::unique_ptr create_full_body_tracker_pico_impl(const FullBodyTrackerPico* tracker); std::unique_ptr create_generic_3axis_pedal_tracker_impl( const Generic3AxisPedalTracker* tracker); + std::unique_ptr create_joint_state_tracker_impl(const JointStateTracker* tracker); std::unique_ptr create_frame_metadata_tracker_oak_impl( const FrameMetadataTrackerOak* tracker); diff --git a/src/core/live_trackers/cpp/live_deviceio_factory.cpp b/src/core/live_trackers/cpp/live_deviceio_factory.cpp index 2c304480c..c0b3471ba 100644 --- a/src/core/live_trackers/cpp/live_deviceio_factory.cpp +++ b/src/core/live_trackers/cpp/live_deviceio_factory.cpp @@ -9,6 +9,7 @@ #include "live_generic_3axis_pedal_tracker_impl.hpp" #include "live_hand_tracker_impl.hpp" #include "live_head_tracker_impl.hpp" +#include "live_joint_state_tracker_impl.hpp" #include "live_message_channel_tracker_impl.hpp" #include @@ -17,6 +18,7 @@ #include #include #include +#include #include #include @@ -79,6 +81,12 @@ std::unique_ptr try_create_generic_pedal_impl(LiveDeviceIOFactory& return typed ? factory.create_generic_3axis_pedal_tracker_impl(typed) : nullptr; } +std::unique_ptr try_create_joint_state_impl(LiveDeviceIOFactory& factory, const ITracker& tracker) +{ + auto* typed = dynamic_cast(&tracker); + return typed ? factory.create_joint_state_tracker_impl(typed) : nullptr; +} + std::unique_ptr try_create_oak_impl(LiveDeviceIOFactory& factory, const ITracker& tracker) { auto* typed = dynamic_cast(&tracker); @@ -102,6 +110,7 @@ inline const TrackerDispatchEntry k_tracker_dispatch[] = { { &try_add_extensions, &try_create_message_channel_impl }, { &try_add_extensions, &try_create_full_body_pico_impl }, { &try_add_extensions, &try_create_generic_pedal_impl }, + { &try_add_extensions, &try_create_joint_state_impl }, { &try_add_extensions, &try_create_oak_impl }, }; @@ -244,6 +253,16 @@ std::unique_ptr LiveDeviceIOFactory::create_gener return std::make_unique(handles_, tracker, std::move(channels)); } +std::unique_ptr LiveDeviceIOFactory::create_joint_state_tracker_impl(const JointStateTracker* tracker) +{ + std::unique_ptr channels; + if (should_record(tracker)) + { + channels = LiveJointStateTrackerImpl::create_mcap_channels(*writer_, get_name(tracker)); + } + return std::make_unique(handles_, tracker, std::move(channels)); +} + std::unique_ptr LiveDeviceIOFactory::create_frame_metadata_tracker_oak_impl( const FrameMetadataTrackerOak* tracker) { diff --git a/src/core/live_trackers/cpp/live_joint_state_tracker_impl.cpp b/src/core/live_trackers/cpp/live_joint_state_tracker_impl.cpp new file mode 100644 index 000000000..6aa679810 --- /dev/null +++ b/src/core/live_trackers/cpp/live_joint_state_tracker_impl.cpp @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "live_joint_state_tracker_impl.hpp" + +#include +#include + +namespace core +{ + +namespace +{ + +SchemaTrackerConfig make_joint_state_tensor_config(const JointStateTracker* tracker) +{ + SchemaTrackerConfig cfg; + cfg.collection_id = tracker->collection_id(); + cfg.max_flatbuffer_size = tracker->max_flatbuffer_size(); + cfg.tensor_identifier = "joint_state"; + cfg.localized_name = "JointStateTracker"; + return cfg; +} + +} // namespace + +// ============================================================================ +// LiveJointStateTrackerImpl +// ============================================================================ + +std::unique_ptr LiveJointStateTrackerImpl::create_mcap_channels(mcap::McapWriter& writer, + std::string_view base_name) +{ + return std::make_unique( + writer, base_name, JointStateRecordingTraits::schema_name, + std::vector(JointStateRecordingTraits::recording_channels.begin(), + JointStateRecordingTraits::recording_channels.end())); +} + +LiveJointStateTrackerImpl::LiveJointStateTrackerImpl(const OpenXRSessionHandles& handles, + const JointStateTracker* tracker, + std::unique_ptr mcap_channels) + : mcap_channels_(std::move(mcap_channels)), + m_schema_reader(handles, + make_joint_state_tensor_config(tracker), + mcap_channels_.get(), + /*mcap_channel_index=*/0, + /*mcap_channel_tracked_index=*/1) +{ +} + +void LiveJointStateTrackerImpl::update(int64_t /*monotonic_time_ns*/) +{ + // Policy: SchemaTracker throws on critical OpenXR/tensor API failures. + // Missing collection/no new data are treated as common non-fatal cases. + m_schema_reader.update(m_tracked.data); +} + +const JointStateOutputTrackedT& LiveJointStateTrackerImpl::get_data() const +{ + return m_tracked; +} + +} // namespace core diff --git a/src/core/live_trackers/cpp/live_joint_state_tracker_impl.hpp b/src/core/live_trackers/cpp/live_joint_state_tracker_impl.hpp new file mode 100644 index 000000000..b39a7e5ae --- /dev/null +++ b/src/core/live_trackers/cpp/live_joint_state_tracker_impl.hpp @@ -0,0 +1,52 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "inc/live_trackers/schema_tracker.hpp" + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace core +{ + +using JointStateMcapChannels = McapTrackerChannels; +using JointStateSchemaTracker = SchemaTracker; + +class LiveJointStateTrackerImpl : public IJointStateTrackerImpl +{ +public: + static std::vector required_extensions() + { + return SchemaTrackerBase::get_required_extensions(); + } + static std::unique_ptr create_mcap_channels(mcap::McapWriter& writer, + std::string_view base_name); + + LiveJointStateTrackerImpl(const OpenXRSessionHandles& handles, + const JointStateTracker* tracker, + std::unique_ptr mcap_channels); + + LiveJointStateTrackerImpl(const LiveJointStateTrackerImpl&) = delete; + LiveJointStateTrackerImpl& operator=(const LiveJointStateTrackerImpl&) = delete; + LiveJointStateTrackerImpl(LiveJointStateTrackerImpl&&) = delete; + LiveJointStateTrackerImpl& operator=(LiveJointStateTrackerImpl&&) = delete; + + void update(int64_t monotonic_time_ns) override; + const JointStateOutputTrackedT& get_data() const override; + +private: + std::unique_ptr mcap_channels_; + JointStateSchemaTracker m_schema_reader; + JointStateOutputTrackedT m_tracked; +}; + +} // namespace core diff --git a/src/core/mcap/cpp/inc/mcap/recording_traits.hpp b/src/core/mcap/cpp/inc/mcap/recording_traits.hpp index 8eb960396..cf0031e23 100644 --- a/src/core/mcap/cpp/inc/mcap/recording_traits.hpp +++ b/src/core/mcap/cpp/inc/mcap/recording_traits.hpp @@ -52,6 +52,13 @@ struct PedalRecordingTraits static constexpr std::array replay_channels = { "pedals_tracked" }; }; +struct JointStateRecordingTraits +{ + static constexpr std::string_view schema_name = "core.JointStateOutputRecord"; + static constexpr std::array recording_channels = { "joint_state", "joint_state_tracked" }; + static constexpr std::array replay_channels = { "joint_state_tracked" }; +}; + struct OakRecordingTraits { static constexpr std::string_view schema_name = "core.FrameMetadataOakRecord"; diff --git a/src/core/python/deviceio_init.py b/src/core/python/deviceio_init.py index ea4a5aafe..a640a10e4 100644 --- a/src/core/python/deviceio_init.py +++ b/src/core/python/deviceio_init.py @@ -17,6 +17,7 @@ MessageChannelTracker, FrameMetadataTrackerOak, Generic3AxisPedalTracker, + JointStateTracker, FullBodyTrackerPico, NUM_JOINTS, JOINT_PALM, @@ -60,6 +61,7 @@ "MessageChannelTracker", "FrameMetadataTrackerOak", "Generic3AxisPedalTracker", + "JointStateTracker", "FullBodyTrackerPico", "OpenXRSessionHandles", "DeviceIOSession", diff --git a/src/core/python/pyproject.toml.in b/src/core/python/pyproject.toml.in index f87db8f17..c196d6019 100644 --- a/src/core/python/pyproject.toml.in +++ b/src/core/python/pyproject.toml.in @@ -50,6 +50,7 @@ packages = [ "isaacteleop.retargeting_engine.utilities", "isaacteleop.retargeters", "isaacteleop.retargeters.G1", + "isaacteleop.retargeters.joint_space", "isaacteleop.retargeting_engine_ui", "isaacteleop.teleop_session_manager", "isaacteleop.cloudxr", diff --git a/src/core/python/requirements-retargeters.txt b/src/core/python/requirements-retargeters.txt index 9cd5b3e99..e62f8f579 100644 --- a/src/core/python/requirements-retargeters.txt +++ b/src/core/python/requirements-retargeters.txt @@ -10,6 +10,20 @@ torch>=2.7.0 # and NumPy 1.x, which cannot coexist with Pinocchio 3.x in combined environments. nlopt>=2.6.2 +# Pinocchio for JointStateRetargeter(mode="ee_pose"): URDF forward kinematics for +# joint-space leader arms / exoskeletons. Declared explicitly (rather than relying +# on dex-retargeting pulling it transitively) so EE-mode is reproducible. +# `mode="joint"` never imports it. +# +# Constraints MUST mirror requirements-grounding.txt: CI installs every extra into +# one environment, so a higher `pin` floor here (e.g. pin>=4, which needs +# cmeel-urdfdom>=6) is unsatisfiable against grounding's cmeel-urdfdom<5. The cmeel +# caps keep pin's compiled bindings linking the sonames they were built against; +# newer cmeel-urdfdom 6.x / cmeel-tinyxml2 11.x break `import pinocchio` at runtime. +pin>=2.7.0 +cmeel-urdfdom<5 +cmeel-tinyxml2<11 + # NOTE: The Sharpa retargeter's Pinocchio/Pink/daqp/loop-rate-limiters deps live # in `requirements-grounding.txt`, transitively via the `robotic_grounding` # wheel. They were briefly listed here while sharpa_hand_retargeter.py imported diff --git a/src/core/replay_trackers/cpp/CMakeLists.txt b/src/core/replay_trackers/cpp/CMakeLists.txt index 3647af299..770cc4ff3 100644 --- a/src/core/replay_trackers/cpp/CMakeLists.txt +++ b/src/core/replay_trackers/cpp/CMakeLists.txt @@ -10,6 +10,7 @@ add_library(replay_trackers STATIC replay_controller_tracker_impl.cpp replay_full_body_tracker_pico_impl.cpp replay_generic_3axis_pedal_tracker_impl.cpp + replay_joint_state_tracker_impl.cpp replay_message_channel_tracker_impl.cpp inc/replay_trackers/replay_deviceio_factory.hpp replay_hand_tracker_impl.hpp @@ -17,6 +18,7 @@ add_library(replay_trackers STATIC replay_controller_tracker_impl.hpp replay_full_body_tracker_pico_impl.hpp replay_generic_3axis_pedal_tracker_impl.hpp + replay_joint_state_tracker_impl.hpp replay_message_channel_tracker_impl.hpp ) diff --git a/src/core/replay_trackers/cpp/inc/replay_trackers/replay_deviceio_factory.hpp b/src/core/replay_trackers/cpp/inc/replay_trackers/replay_deviceio_factory.hpp index c8272babb..ddc04b063 100644 --- a/src/core/replay_trackers/cpp/inc/replay_trackers/replay_deviceio_factory.hpp +++ b/src/core/replay_trackers/cpp/inc/replay_trackers/replay_deviceio_factory.hpp @@ -21,6 +21,8 @@ class FullBodyTrackerPico; class IFullBodyTrackerPicoImpl; class Generic3AxisPedalTracker; class IGeneric3AxisPedalTrackerImpl; +class JointStateTracker; +class IJointStateTrackerImpl; class HandTracker; class IHandTrackerImpl; class HeadTracker; @@ -50,6 +52,7 @@ class ReplayDeviceIOFactory std::unique_ptr create_full_body_tracker_pico_impl(const FullBodyTrackerPico* tracker); std::unique_ptr create_generic_3axis_pedal_tracker_impl( const Generic3AxisPedalTracker* tracker); + std::unique_ptr create_joint_state_tracker_impl(const JointStateTracker* tracker); std::unique_ptr create_message_channel_tracker_impl(const MessageChannelTracker* tracker); private: diff --git a/src/core/replay_trackers/cpp/replay_deviceio_factory.cpp b/src/core/replay_trackers/cpp/replay_deviceio_factory.cpp index a3d6c3b6a..4941f3a63 100644 --- a/src/core/replay_trackers/cpp/replay_deviceio_factory.cpp +++ b/src/core/replay_trackers/cpp/replay_deviceio_factory.cpp @@ -8,6 +8,7 @@ #include "replay_generic_3axis_pedal_tracker_impl.hpp" #include "replay_hand_tracker_impl.hpp" #include "replay_head_tracker_impl.hpp" +#include "replay_joint_state_tracker_impl.hpp" #include "replay_message_channel_tracker_impl.hpp" #include @@ -15,6 +16,7 @@ #include #include #include +#include #include #include @@ -71,6 +73,12 @@ std::unique_ptr try_create_generic_pedal_impl(ReplayDeviceIOFactor return typed ? factory.create_generic_3axis_pedal_tracker_impl(typed) : nullptr; } +std::unique_ptr try_create_joint_state_impl(ReplayDeviceIOFactory& factory, const ITracker& tracker) +{ + auto* typed = dynamic_cast(&tracker); + return typed ? factory.create_joint_state_tracker_impl(typed) : nullptr; +} + std::unique_ptr try_create_message_channel_impl(ReplayDeviceIOFactory& factory, const ITracker& tracker) { auto* typed = dynamic_cast(&tracker); @@ -80,8 +88,13 @@ std::unique_ptr try_create_message_channel_impl(ReplayDeviceIOFact using TryCreateFn = std::unique_ptr (*)(ReplayDeviceIOFactory&, const ITracker&); inline const TryCreateFn k_tracker_dispatch[] = { - &try_create_head_impl, &try_create_hand_impl, &try_create_controller_impl, - &try_create_full_body_pico_impl, &try_create_generic_pedal_impl, &try_create_message_channel_impl, + &try_create_head_impl, + &try_create_hand_impl, + &try_create_controller_impl, + &try_create_full_body_pico_impl, + &try_create_generic_pedal_impl, + &try_create_joint_state_impl, + &try_create_message_channel_impl, }; } // namespace @@ -148,6 +161,11 @@ std::unique_ptr ReplayDeviceIOFactory::create_gen return std::make_unique(open_reader(filename_), get_name(tracker)); } +std::unique_ptr ReplayDeviceIOFactory::create_joint_state_tracker_impl(const JointStateTracker* tracker) +{ + return std::make_unique(open_reader(filename_), get_name(tracker)); +} + std::unique_ptr ReplayDeviceIOFactory::create_message_channel_tracker_impl( const MessageChannelTracker* tracker) { diff --git a/src/core/replay_trackers/cpp/replay_joint_state_tracker_impl.cpp b/src/core/replay_trackers/cpp/replay_joint_state_tracker_impl.cpp new file mode 100644 index 000000000..00567e051 --- /dev/null +++ b/src/core/replay_trackers/cpp/replay_joint_state_tracker_impl.cpp @@ -0,0 +1,57 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "replay_joint_state_tracker_impl.hpp" + +#include +#include +#include + +#include +#include +#include +#include + +namespace core +{ + +// ============================================================================ +// ReplayJointStateTrackerImpl +// ============================================================================ + +ReplayJointStateTrackerImpl::ReplayJointStateTrackerImpl(std::unique_ptr reader, + std::string_view base_name) + : mcap_viewers_(std::make_unique( + std::move(reader), + base_name, + std::vector( + JointStateRecordingTraits::replay_channels.begin(), JointStateRecordingTraits::replay_channels.end()))) +{ +} + +const JointStateOutputTrackedT& ReplayJointStateTrackerImpl::get_data() const +{ + return tracked_; +} + +void ReplayJointStateTrackerImpl::update(int64_t /*monotonic_time_ns*/) +{ + auto record = mcap_viewers_->read(0); + if (record) + { + tracked_.data = std::move(record->data); + warned_no_data_ = false; + } + else + { + // EOF / sparse streams call this every frame; log once per gap, not per frame. + if (!warned_no_data_) + { + std::cerr << "ReplayJointStateTrackerImpl: joint state data not found" << std::endl; + warned_no_data_ = true; + } + tracked_.data.reset(); + } +} + +} // namespace core diff --git a/src/core/replay_trackers/cpp/replay_joint_state_tracker_impl.hpp b/src/core/replay_trackers/cpp/replay_joint_state_tracker_impl.hpp new file mode 100644 index 000000000..63dd9e174 --- /dev/null +++ b/src/core/replay_trackers/cpp/replay_joint_state_tracker_impl.hpp @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace core +{ + +using JointStateMcapViewers = McapTrackerViewers; + +class ReplayJointStateTrackerImpl : public IJointStateTrackerImpl +{ +public: + ReplayJointStateTrackerImpl(std::unique_ptr reader, std::string_view base_name); + + ReplayJointStateTrackerImpl(const ReplayJointStateTrackerImpl&) = delete; + ReplayJointStateTrackerImpl& operator=(const ReplayJointStateTrackerImpl&) = delete; + ReplayJointStateTrackerImpl(ReplayJointStateTrackerImpl&&) = delete; + ReplayJointStateTrackerImpl& operator=(ReplayJointStateTrackerImpl&&) = delete; + + void update(int64_t monotonic_time_ns) override; + const JointStateOutputTrackedT& get_data() const override; + +private: + JointStateOutputTrackedT tracked_; + std::unique_ptr mcap_viewers_; + // Warn only on the first frame of a no-data gap (EOF / sparse stream) to avoid per-frame spam. + bool warned_no_data_ = false; +}; + +} // namespace core diff --git a/src/core/retargeting_engine/python/deviceio_source_nodes/__init__.py b/src/core/retargeting_engine/python/deviceio_source_nodes/__init__.py index 4945cbe97..7b0a0efee 100644 --- a/src/core/retargeting_engine/python/deviceio_source_nodes/__init__.py +++ b/src/core/retargeting_engine/python/deviceio_source_nodes/__init__.py @@ -9,6 +9,7 @@ from .hands_source import HandsSource from .controllers_source import ControllersSource from .pedals_source import Generic3AxisPedalSource +from .joint_state_source import JointStateSource from .full_body_source import FullBodySource from .message_channel_source import MessageChannelSource from .message_channel_sink import MessageChannelSink @@ -23,11 +24,13 @@ HandPoseTrackedType, ControllerSnapshotTrackedType, Generic3AxisPedalOutputTrackedType, + JointStateOutputTrackedType, FullBodyPosePicoTrackedType, DeviceIOHeadPoseTracked, DeviceIOHandPoseTracked, DeviceIOControllerSnapshotTracked, DeviceIOGeneric3AxisPedalOutputTracked, + DeviceIOJointStateOutputTracked, DeviceIOFullBodyPosePicoTracked, MessageChannelMessagesTrackedType, MessageChannelConnectionStatus, @@ -44,6 +47,7 @@ "HandsSource", "ControllersSource", "Generic3AxisPedalSource", + "JointStateSource", "FullBodySource", "MessageChannelSource", "MessageChannelSink", @@ -55,6 +59,7 @@ "HandPoseTrackedType", "ControllerSnapshotTrackedType", "Generic3AxisPedalOutputTrackedType", + "JointStateOutputTrackedType", "FullBodyPosePicoTrackedType", "MessageChannelMessagesTrackedType", "MessageChannelConnectionStatus", @@ -63,6 +68,7 @@ "DeviceIOHandPoseTracked", "DeviceIOControllerSnapshotTracked", "DeviceIOGeneric3AxisPedalOutputTracked", + "DeviceIOJointStateOutputTracked", "DeviceIOFullBodyPosePicoTracked", "DeviceIOMessageChannelMessagesTracked", "MessageChannelMessagesTrackedGroup", diff --git a/src/core/retargeting_engine/python/deviceio_source_nodes/deviceio_tensor_types.py b/src/core/retargeting_engine/python/deviceio_source_nodes/deviceio_tensor_types.py index 77aaa4e4f..89f92b5db 100644 --- a/src/core/retargeting_engine/python/deviceio_source_nodes/deviceio_tensor_types.py +++ b/src/core/retargeting_engine/python/deviceio_source_nodes/deviceio_tensor_types.py @@ -18,6 +18,7 @@ HandPoseTrackedT, ControllerSnapshotTrackedT, Generic3AxisPedalOutputTrackedT, + JointStateOutputTrackedT, FullBodyPosePicoTrackedT, MessageChannelMessagesTrackedT, ) @@ -99,6 +100,26 @@ def validate_value(self, value: Any) -> None: ) +class JointStateOutputTrackedType(TensorType): + """JointStateOutputTrackedT wrapper type from DeviceIO JointStateTracker.""" + + def __init__(self, name: str) -> None: + super().__init__(name) + + def _check_instance_compatibility(self, other: TensorType) -> bool: + if not isinstance(other, JointStateOutputTrackedType): + raise TypeError( + f"Expected JointStateOutputTrackedType, got {type(other).__name__}" + ) + return True + + def validate_value(self, value: Any) -> None: + if not isinstance(value, JointStateOutputTrackedT): + raise TypeError( + f"Expected JointStateOutputTrackedT for '{self.name}', got {type(value).__name__}" + ) + + class FullBodyPosePicoTrackedType(TensorType): """FullBodyPosePicoTrackedT wrapper type from DeviceIO FullBodyTrackerPico.""" @@ -211,6 +232,18 @@ def DeviceIOGeneric3AxisPedalOutputTracked() -> TensorGroupType: ) +def DeviceIOJointStateOutputTracked() -> TensorGroupType: + """Tracked joint-state data from DeviceIO JointStateTracker. + + Contains: + joint_state_tracked: JointStateOutputTrackedT wrapper (always set; .data is None when inactive) + """ + return TensorGroupType( + "deviceio_joint_state_output", + [JointStateOutputTrackedType("joint_state_tracked")], + ) + + def DeviceIOFullBodyPosePicoTracked() -> TensorGroupType: """Tracked full body pose data from DeviceIO FullBodyTrackerPico. diff --git a/src/core/retargeting_engine/python/deviceio_source_nodes/joint_state_source.py b/src/core/retargeting_engine/python/deviceio_source_nodes/joint_state_source.py new file mode 100644 index 000000000..96f8c76e9 --- /dev/null +++ b/src/core/retargeting_engine/python/deviceio_source_nodes/joint_state_source.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Joint-State Source Node - DeviceIO to Retargeting Engine converter. + +Converts raw ``JointStateOutput`` flatbuffer data (from a generic joint-space device such as a +leader arm or exoskeleton) into a name-keyed tensor group with one ``FloatType`` per joint +position, ready for consumption by ``JointStateRetargeter`` (or a ``TensorReorderer``) downstream. + +The set of joints is fixed at construction (``joint_names``) so the retargeting graph has a static +input spec; the per-frame schema names are looked up against it, so wiring is order-independent. +""" + +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from .interface import IDeviceIOSource +from ..interface.retargeter_core_types import RetargeterIO, RetargeterIOType +from ..interface.tensor_group import TensorGroup +from ..interface.tensor_group_type import OptionalType, TensorGroupType +from ..tensor_types import FloatType +from .deviceio_tensor_types import DeviceIOJointStateOutputTracked + +if TYPE_CHECKING: + from isaacteleop.deviceio import ITracker + from isaacteleop.schema import JointStateOutputTrackedT + + +class JointStateSource(IDeviceIOSource): + """Stateless converter: DeviceIO ``JointStateOutput`` -> name-keyed joint-position group. + + Inputs: + - "deviceio_joint_state": Raw ``JointStateOutput`` flatbuffer from ``JointStateTracker``. + + Outputs (Optional -- absent when the device is inactive): + - :data:`JOINTS`: one ``FloatType`` per ``joint_names`` entry (joint position [rad or m]). + + Usage:: + + source = JointStateSource(name="leader", collection_id="so101_leader", + joint_names=["shoulder_pan", ..., "gripper"]) + # In a TeleopSession the tracker is discovered from the pipeline and polled each frame. + + Note: + ``joint_names`` defines the output group layout and must match the downstream consumer's + expected order (e.g. ``JointStateRetargeterConfig.device_joints``). Only joint *positions* + are surfaced; the schema's ``velocity`` / ``effort`` / ``valid`` / ``ee_pose`` fields are + not exposed yet (reserved for future use). + """ + + JOINTS = "joints" + + def __init__(self, name: str, collection_id: str, joint_names: list[str]) -> None: + """Initialize the joint-state source node. + + Args: + name: Unique name for this source node. + collection_id: Tensor collection ID for the device (must match the plugin / pusher). + joint_names: Ordered device DOF names; defines the static output spec and the order + the downstream pipeline consumes. + """ + import isaacteleop.deviceio as deviceio + + self._tracker = deviceio.JointStateTracker(collection_id) + self._collection_id = collection_id + self._joint_names = list(joint_names) + super().__init__(name) + + def get_tracker(self) -> "ITracker": + """Return the ``JointStateTracker`` instance for ``TeleopSession`` to initialize.""" + return self._tracker + + def poll_tracker(self, deviceio_session: Any) -> RetargeterIO: + """Poll the tracker and wrap the raw tracked data for the compute step.""" + tracked = self._tracker.get_data(deviceio_session) + tg = TensorGroup(DeviceIOJointStateOutputTracked()) + tg[0] = tracked + return {"deviceio_joint_state": tg} + + def input_spec(self) -> RetargeterIOType: + """Declare the raw DeviceIO joint-state input.""" + return {"deviceio_joint_state": DeviceIOJointStateOutputTracked()} + + def output_spec(self) -> RetargeterIOType: + """Declare the name-keyed joint-position output (Optional -- may be absent).""" + return { + self.JOINTS: OptionalType( + TensorGroupType(self.JOINTS, [FloatType(n) for n in self._joint_names]) + ) + } + + def _compute_fn(self, inputs: RetargeterIO, outputs: RetargeterIO, context) -> None: + """Convert ``JointStateOutputTrackedT`` to a name-keyed joint-position group. + + Calls ``set_none()`` on the output when the device is inactive. + """ + tracked: "JointStateOutputTrackedT" = inputs["deviceio_joint_state"][0] + data = tracked.data + + out = outputs[self.JOINTS] + if data is None: + out.set_none() + return + + by_name: dict[str, float] = {} + for joint in data.joints: + name = joint.name.decode() if isinstance(joint.name, bytes) else joint.name + by_name[name] = float(joint.position) + + for i, name in enumerate(self._joint_names): + out[i] = by_name.get(name, 0.0) diff --git a/src/core/retargeting_engine_tests/python/test_joint_state_retargeter.py b/src/core/retargeting_engine_tests/python/test_joint_state_retargeter.py new file mode 100644 index 000000000..2ff62efde --- /dev/null +++ b/src/core/retargeting_engine_tests/python/test_joint_state_retargeter.py @@ -0,0 +1,461 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Sim-free unit tests for the generic JointStateRetargeter (leader arms, exoskeletons, ...). + +Covers both modes at the ``BaseRetargeter.compute`` level and end-to-end through a +``TensorReorderer`` + ``OutputCombiner`` pipeline (the object an Isaac Lab pipeline_builder +returns), with no ``gym.make``, USD, GPU, or XR device: + +* ``mode="joint"`` -- name-keyed remap + per-joint affine, hold-last, reset. +* ``mode="ee_pose"`` -- URDF forward kinematics (guarded on ``pinocchio``) + gripper command. +""" + +import importlib.util +import math +import os +import tempfile + +import numpy as np +import pytest + +from isaacteleop.retargeting_engine.interface import ( + ComputeContext, + ExecutionEvents, + ExecutionState, + OutputCombiner, + OptionalTensorGroup, + TensorGroup, +) +from isaacteleop.retargeting_engine.interface.retargeter_core_types import GraphTime +from isaacteleop.retargeting_engine.interface.tensor_group_type import ( + OptionalTensorGroupType, +) +from isaacteleop.retargeting_engine.tensor_types import TransformMatrix +from isaacteleop.retargeters import ( + JointStateRetargeter, + JointStateRetargeterConfig, + TensorReorderer, +) + +_HAS_PINOCCHIO = importlib.util.find_spec("pinocchio") is not None + +SO101_JOINTS = [ + "shoulder_pan", + "shoulder_lift", + "elbow_flex", + "wrist_flex", + "wrist_roll", + "gripper", +] + +# Minimal 2-revolute-joint arm with a tool frame, enough to exercise FK without external assets. +_MINIMAL_URDF = """ + + + + + + + + + + + + + + + + + + +""" + + +def _ctx( + reset: bool = False, state: ExecutionState = ExecutionState.RUNNING +) -> ComputeContext: + return ComputeContext( + graph_time=GraphTime(sim_time_ns=0, real_time_ns=0), + execution_events=ExecutionEvents(reset=reset, execution_state=state), + ) + + +def _build_io(node): + inputs = {} + for k, v in node.input_spec().items(): + inputs[k] = ( + OptionalTensorGroup(v) + if isinstance(v, OptionalTensorGroupType) + else TensorGroup(v) + ) + outputs = {} + for k, v in node.output_spec().items(): + outputs[k] = ( + OptionalTensorGroup(v) + if isinstance(v, OptionalTensorGroupType) + else TensorGroup(v) + ) + return inputs, outputs + + +def _joints_group(node, device_joints, positions): + """Build a present, name-keyed joints TensorGroup for the retargeter's JOINTS input.""" + inner = node.input_spec()[JointStateRetargeter.JOINTS].inner_type + group = TensorGroup(inner) + for i, name in enumerate(device_joints): + group[i] = float(positions.get(name, 0.0)) + return group + + +def _world_T_ee(translation) -> TensorGroup: + """Build a TransformMatrix group with identity rotation and the given translation.""" + group = TensorGroup(TransformMatrix()) + matrix = np.eye(4, dtype=np.float32) + matrix[:3, 3] = np.asarray(translation, dtype=np.float32) + group[0] = matrix + return group + + +# =========================================================================== +# Joint mode +# =========================================================================== + + +class TestJointMode: + def test_output_spec_matches_target_joints(self): + r = JointStateRetargeter( + "r", "joint", JointStateRetargeterConfig(device_joints=SO101_JOINTS) + ) + spec = r.output_spec() + assert list(spec) == ["joint_targets"] + # Defaults target_joints to device_joints (identity mirror). + names = [t.name for t in spec["joint_targets"].types] + assert names == SO101_JOINTS + + def test_identity_mirror(self): + r = JointStateRetargeter( + "r", "joint", JointStateRetargeterConfig(device_joints=SO101_JOINTS) + ) + inputs, outputs = _build_io(r) + positions = {n: 0.1 * (i + 1) for i, n in enumerate(SO101_JOINTS)} + inputs[JointStateRetargeter.JOINTS] = _joints_group(r, SO101_JOINTS, positions) + r.compute(inputs, outputs, _ctx(reset=True)) + out = outputs["joint_targets"] + for i, n in enumerate(SO101_JOINTS): + assert float(out[i]) == pytest.approx(positions[n]) + + def test_affine_scale_offset_sign(self): + cfg = JointStateRetargeterConfig( + device_joints=["a", "b"], + target_joints=["a", "b"], + scale={"a": 2.0}, + offset={"b": 0.5}, + sign={"a": -1.0}, + ) + r = JointStateRetargeter("r", "joint", cfg) + inputs, outputs = _build_io(r) + inputs[JointStateRetargeter.JOINTS] = _joints_group( + r, ["a", "b"], {"a": 0.3, "b": 0.4} + ) + r.compute(inputs, outputs, _ctx(reset=True)) + out = outputs["joint_targets"] + assert float(out[0]) == pytest.approx(-1.0 * 2.0 * 0.3) # sign * scale * value + assert float(out[1]) == pytest.approx(0.5 + 0.4) # offset + value + + def test_name_remap(self): + # Device joint "lead_a" feeds robot joint "robot_a". + cfg = JointStateRetargeterConfig( + device_joints=["lead_a", "lead_b"], + target_joints=["robot_a", "robot_b"], + joint_map={"lead_a": "robot_a", "lead_b": "robot_b"}, + ) + r = JointStateRetargeter("r", "joint", cfg) + inputs, outputs = _build_io(r) + inputs[JointStateRetargeter.JOINTS] = _joints_group( + r, ["lead_a", "lead_b"], {"lead_a": 0.7, "lead_b": -0.2} + ) + r.compute(inputs, outputs, _ctx(reset=True)) + out = outputs["joint_targets"] + assert float(out[0]) == pytest.approx(0.7) + assert float(out[1]) == pytest.approx(-0.2) + + def test_hold_last_on_dropped_frame(self): + r = JointStateRetargeter( + "r", "joint", JointStateRetargeterConfig(device_joints=["a", "b"]) + ) + inputs, outputs = _build_io(r) + inputs[JointStateRetargeter.JOINTS] = _joints_group( + r, ["a", "b"], {"a": 0.9, "b": 0.1} + ) + r.compute(inputs, outputs, _ctx()) + # Next frame: joints input absent -> hold last commanded targets. + inputs2, outputs2 = _build_io(r) + r.compute(inputs2, outputs2, _ctx()) + assert float(outputs2["joint_targets"][0]) == pytest.approx(0.9) + assert float(outputs2["joint_targets"][1]) == pytest.approx(0.1) + + def test_reset_zeros_targets(self): + r = JointStateRetargeter( + "r", "joint", JointStateRetargeterConfig(device_joints=["a"]) + ) + inputs, outputs = _build_io(r) + inputs[JointStateRetargeter.JOINTS] = _joints_group(r, ["a"], {"a": 0.5}) + r.compute(inputs, outputs, _ctx()) + # Reset with no input -> targets cleared to zero. + inputs2, outputs2 = _build_io(r) + r.compute(inputs2, outputs2, _ctx(reset=True)) + assert float(outputs2["joint_targets"][0]) == pytest.approx(0.0) + + +# =========================================================================== +# EE-pose mode (URDF forward kinematics) -- guarded on pinocchio +# =========================================================================== + + +@pytest.fixture(scope="module") +def minimal_urdf_path(): + fd, path = tempfile.mkstemp(suffix=".urdf") + with os.fdopen(fd, "w") as f: + f.write(_MINIMAL_URDF) + yield path + os.remove(path) + + +@pytest.mark.skipif(not _HAS_PINOCCHIO, reason="pinocchio not installed") +class TestEePoseMode: + def _make(self, urdf_path, **overrides): + cfg = JointStateRetargeterConfig( + device_joints=["j1", "j2", "gripper"], + urdf_path=urdf_path, + ee_link="tool", + gripper_joint="gripper", + **overrides, + ) + return JointStateRetargeter("ee", "ee_pose", cfg) + + def test_requires_urdf_and_ee_link(self): + with pytest.raises(ValueError): + JointStateRetargeter( + "ee", "ee_pose", JointStateRetargeterConfig(device_joints=["j1"]) + ) + + def test_output_spec_is_pose_plus_gripper(self, minimal_urdf_path): + r = self._make(minimal_urdf_path) + spec = r.output_spec() + assert set(spec) == {"ee_pose", "gripper_command"} + assert spec["ee_pose"].types[0].shape == (7,) + + def test_fk_unit_quaternion_and_shape(self, minimal_urdf_path): + r = self._make(minimal_urdf_path) + inputs, outputs = _build_io(r) + inputs[JointStateRetargeter.JOINTS] = _joints_group( + r, ["j1", "j2", "gripper"], {"j1": 0.0, "j2": 0.0, "gripper": 0.0} + ) + r.compute(inputs, outputs, _ctx(reset=True)) + pose = np.asarray(outputs["ee_pose"][0], dtype=np.float64) + assert pose.shape == (7,) + assert np.linalg.norm(pose[3:7]) == pytest.approx(1.0, abs=1e-5) + + def test_fk_moves_with_joints(self, minimal_urdf_path): + r = self._make(minimal_urdf_path) + inputs, outputs = _build_io(r) + inputs[JointStateRetargeter.JOINTS] = _joints_group( + r, ["j1", "j2", "gripper"], {"j1": 0.0, "j2": 0.0, "gripper": 0.0} + ) + r.compute(inputs, outputs, _ctx(reset=True)) + pos0 = np.asarray(outputs["ee_pose"][0], dtype=np.float64)[:3].copy() + + inputs2, outputs2 = _build_io(r) + inputs2[JointStateRetargeter.JOINTS] = _joints_group( + r, ["j1", "j2", "gripper"], {"j1": math.pi / 2, "j2": 0.0, "gripper": 0.0} + ) + r.compute(inputs2, outputs2, _ctx()) + pos1 = np.asarray(outputs2["ee_pose"][0], dtype=np.float64)[:3] + assert np.linalg.norm(pos1 - pos0) > 0.05 + + def test_gripper_raw_passthrough(self, minimal_urdf_path): + r = self._make(minimal_urdf_path) + inputs, outputs = _build_io(r) + inputs[JointStateRetargeter.JOINTS] = _joints_group( + r, ["j1", "j2", "gripper"], {"j1": 0.0, "j2": 0.0, "gripper": 0.42} + ) + r.compute(inputs, outputs, _ctx(reset=True)) + assert float(outputs["gripper_command"][0]) == pytest.approx(0.42) + + def test_gripper_normalized_closedness(self, minimal_urdf_path): + r = self._make(minimal_urdf_path, gripper_open=0.0, gripper_close=2.0) + inputs, outputs = _build_io(r) + inputs[JointStateRetargeter.JOINTS] = _joints_group( + r, ["j1", "j2", "gripper"], {"j1": 0.0, "j2": 0.0, "gripper": 1.0} + ) + r.compute(inputs, outputs, _ctx(reset=True)) + # 1.0 is halfway between open (0) and close (2) -> closedness 0.5. + assert float(outputs["gripper_command"][0]) == pytest.approx(0.5) + + +@pytest.mark.skipif(not _HAS_PINOCCHIO, reason="pinocchio not installed") +class TestEeClutch: + """Clutch rebasing: no jump on engage, then track FK deltas off the latched home.""" + + def _make(self, urdf_path): + return JointStateRetargeter( + "ee", + "ee_pose", + JointStateRetargeterConfig( + device_joints=["j1", "j2", "gripper"], + urdf_path=urdf_path, + ee_link="tool", + clutch=True, + ), + ) + + def _fk_pos(self, urdf_path, joint_values): + """Absolute FK position from a non-clutch retargeter (reference for delta checks).""" + nc = JointStateRetargeter( + "nc", + "ee_pose", + JointStateRetargeterConfig( + device_joints=["j1", "j2", "gripper"], + urdf_path=urdf_path, + ee_link="tool", + ), + ) + inputs, outputs = _build_io(nc) + inputs[JointStateRetargeter.JOINTS] = _joints_group( + nc, ["j1", "j2", "gripper"], joint_values + ) + nc.compute(inputs, outputs, _ctx(reset=True)) + return np.asarray(outputs["ee_pose"][0], dtype=np.float64)[:3] + + def test_clutch_adds_robot_ee_pos_input(self, minimal_urdf_path): + r = self._make(minimal_urdf_path) + assert JointStateRetargeter.ROBOT_EE_POS_INPUT in r.input_spec() + + def test_engage_latches_robot_ee_home_no_jump(self, minimal_urdf_path): + r = self._make(minimal_urdf_path) + inputs, outputs = _build_io(r) + inputs[JointStateRetargeter.JOINTS] = _joints_group( + r, ["j1", "j2", "gripper"], {"j1": 0.5, "j2": -0.3, "gripper": 0.0} + ) + home = [0.31, -0.12, 0.44] + inputs[JointStateRetargeter.ROBOT_EE_POS_INPUT] = _world_T_ee(home) + # First RUNNING frame: EE sits at the robot's current EE (home), not the leader's FK pose. + r.compute(inputs, outputs, _ctx(reset=True, state=ExecutionState.RUNNING)) + pose = np.asarray(outputs["ee_pose"][0], dtype=np.float64) + np.testing.assert_allclose(pose[:3], home, atol=1e-5) + + def test_not_running_holds_and_does_not_latch(self, minimal_urdf_path): + r = self._make(minimal_urdf_path) + inputs, outputs = _build_io(r) + inputs[JointStateRetargeter.JOINTS] = _joints_group( + r, ["j1", "j2", "gripper"], {"j1": 0.5, "j2": 0.0, "gripper": 0.0} + ) + r.compute(inputs, outputs, _ctx(state=ExecutionState.STOPPED)) + pose = np.asarray(outputs["ee_pose"][0], dtype=np.float64) + np.testing.assert_allclose(pose[:3], [0.0, 0.0, 0.0], atol=1e-9) # held seed + assert r._origin is None # not latched while stopped + + def test_motion_after_engage_adds_fk_delta(self, minimal_urdf_path): + j0 = {"j1": 0.2, "j2": -0.1, "gripper": 0.0} + j1 = {"j1": 0.8, "j2": 0.3, "gripper": 0.0} + fk0 = self._fk_pos(minimal_urdf_path, j0) + fk1 = self._fk_pos(minimal_urdf_path, j1) + home = np.array([0.3, 0.1, 0.5]) + + r = self._make(minimal_urdf_path) + inputs, outputs = _build_io(r) + inputs[JointStateRetargeter.JOINTS] = _joints_group( + r, ["j1", "j2", "gripper"], j0 + ) + inputs[JointStateRetargeter.ROBOT_EE_POS_INPUT] = _world_T_ee(home) + r.compute( + inputs, outputs, _ctx(reset=True, state=ExecutionState.RUNNING) + ) # engage + np.testing.assert_allclose( + np.asarray(outputs["ee_pose"][0], dtype=np.float64)[:3], home, atol=1e-5 + ) + + inputs2, outputs2 = _build_io(r) + inputs2[JointStateRetargeter.JOINTS] = _joints_group( + r, ["j1", "j2", "gripper"], j1 + ) + r.compute(inputs2, outputs2, _ctx(state=ExecutionState.RUNNING)) + pos = np.asarray(outputs2["ee_pose"][0], dtype=np.float64)[:3] + np.testing.assert_allclose(pos, home + (fk1 - fk0), atol=1e-5) + + +# =========================================================================== +# End-to-end pipeline (retargeter -> TensorReorderer -> OutputCombiner) +# =========================================================================== + + +def _run_pipeline(combiner, leaf_name, joints_group, ctx): + result = combiner.execute_pipeline( + {leaf_name: {JointStateRetargeter.JOINTS: joints_group}}, ctx + ) + return np.asarray(result["action"][0], dtype=np.float64) + + +class TestPipeline: + def test_joint_pipeline_action_width_and_order(self): + r = JointStateRetargeter( + "leader", "joint", JointStateRetargeterConfig(device_joints=SO101_JOINTS) + ) + reorderer = TensorReorderer( + input_config={"joint_targets": SO101_JOINTS}, + output_order=SO101_JOINTS, + name="action_reorderer", + input_types={"joint_targets": "scalar"}, + ) + connected = reorderer.connect({"joint_targets": r.output("joint_targets")}) + combiner = OutputCombiner({"action": connected.output("output")}) + + positions = {n: 0.1 * (i + 1) for i, n in enumerate(SO101_JOINTS)} + jg = _joints_group(r, SO101_JOINTS, positions) + action = _run_pipeline(combiner, r.name, jg, _ctx(reset=True)) + assert action.shape == (6,) + np.testing.assert_allclose( + action, [positions[n] for n in SO101_JOINTS], atol=1e-6 + ) + + @pytest.mark.skipif(not _HAS_PINOCCHIO, reason="pinocchio not installed") + def test_ee_pipeline_action_width(self, minimal_urdf_path): + pose_labels = [ + "pos_x", + "pos_y", + "pos_z", + "quat_x", + "quat_y", + "quat_z", + "quat_w", + ] + r = JointStateRetargeter( + "leader", + "ee_pose", + JointStateRetargeterConfig( + device_joints=["j1", "j2", "gripper"], + urdf_path=minimal_urdf_path, + ee_link="tool", + ), + ) + reorderer = TensorReorderer( + input_config={"ee_pose": pose_labels, "gripper_command": ["gripper_value"]}, + output_order=pose_labels + ["gripper_value"], + name="action_reorderer", + input_types={"ee_pose": "array", "gripper_command": "scalar"}, + ) + connected = reorderer.connect( + { + "ee_pose": r.output("ee_pose"), + "gripper_command": r.output("gripper_command"), + } + ) + combiner = OutputCombiner({"action": connected.output("output")}) + + jg = _joints_group( + r, ["j1", "j2", "gripper"], {"j1": 0.2, "j2": -0.3, "gripper": 0.5} + ) + action = _run_pipeline(combiner, r.name, jg, _ctx(reset=True)) + assert action.shape == (8,) + assert action[7] == pytest.approx(0.5) # gripper passthrough diff --git a/src/core/retargeting_engine_tests/python/test_joint_state_source.py b/src/core/retargeting_engine_tests/python/test_joint_state_source.py new file mode 100644 index 000000000..69d15f819 --- /dev/null +++ b/src/core/retargeting_engine_tests/python/test_joint_state_source.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the JointStateSource DeviceIO converter. + +Exercises the stateless converter from a raw ``JointStateOutput`` FlatBuffer (constructed via the +real schema Python bindings) into the name-keyed joint-position tensor group consumed downstream, +with no OpenXR device involved. +""" + +import pytest + +from isaacteleop.retargeting_engine.deviceio_source_nodes import JointStateSource +from isaacteleop.retargeting_engine.interface.base_retargeter import _make_output_group +from isaacteleop.retargeting_engine.interface.tensor_group import TensorGroup +from isaacteleop.schema import JointState, JointStateOutput, JointStateOutputTrackedT + +SO101_JOINTS = [ + "shoulder_pan", + "shoulder_lift", + "elbow_flex", + "wrist_flex", + "wrist_roll", + "gripper", +] + + +def _make_inputs(source, raw: dict) -> dict: + spec = source.input_spec() + result = {} + for name, objects in raw.items(): + tg = TensorGroup(spec[name]) + for i, obj in enumerate(objects): + tg[i] = obj + result[name] = tg + return result + + +def _outputs(source): + return {name: _make_output_group(gt) for name, gt in source.output_spec().items()} + + +def _make_output(joint_values: dict) -> JointStateOutput: + out = JointStateOutput() + out.device_id = "so101_leader" + out.joints = [JointState(name, pos) for name, pos in joint_values.items()] + return out + + +class TestJointStateSource: + def test_creation_and_tracker(self): + src = JointStateSource( + name="leader", collection_id="so101_leader", joint_names=SO101_JOINTS + ) + assert src.name == "leader" + tracker = src.get_tracker() + assert tracker is not None + assert tracker.get_name() == "JointStateTracker" + + def test_input_output_spec(self): + src = JointStateSource( + name="leader", collection_id="so101_leader", joint_names=SO101_JOINTS + ) + assert list(src.input_spec()) == ["deviceio_joint_state"] + out_spec = src.output_spec() + assert list(out_spec) == [JointStateSource.JOINTS] + assert out_spec[JointStateSource.JOINTS].is_optional + names = [t.name for t in out_spec[JointStateSource.JOINTS].inner_type.types] + assert names == SO101_JOINTS + + def test_active_conversion(self): + src = JointStateSource( + name="leader", collection_id="so101_leader", joint_names=SO101_JOINTS + ) + values = {n: round(0.1 * (i + 1), 3) for i, n in enumerate(SO101_JOINTS)} + inputs = _make_inputs( + src, + {"deviceio_joint_state": [JointStateOutputTrackedT(_make_output(values))]}, + ) + outputs = _outputs(src) + src.compute(inputs, outputs) + + group = outputs[JointStateSource.JOINTS] + assert not group.is_none + for i, n in enumerate(SO101_JOINTS): + assert float(group[i]) == pytest.approx(values[n]) + + def test_name_order_independent(self): + """Joints arriving in a different order than joint_names are mapped by name.""" + src = JointStateSource( + name="leader", collection_id="so101_leader", joint_names=["a", "b", "c"] + ) + # Schema joints intentionally in reverse order. + out = _make_output({"c": 3.0, "a": 1.0, "b": 2.0}) + inputs = _make_inputs( + src, {"deviceio_joint_state": [JointStateOutputTrackedT(out)]} + ) + outputs = _outputs(src) + src.compute(inputs, outputs) + group = outputs[JointStateSource.JOINTS] + assert float(group[0]) == pytest.approx(1.0) # a + assert float(group[1]) == pytest.approx(2.0) # b + assert float(group[2]) == pytest.approx(3.0) # c + + def test_missing_joint_defaults_zero(self): + src = JointStateSource( + name="leader", collection_id="so101_leader", joint_names=["a", "missing"] + ) + out = _make_output({"a": 1.5}) + inputs = _make_inputs( + src, {"deviceio_joint_state": [JointStateOutputTrackedT(out)]} + ) + outputs = _outputs(src) + src.compute(inputs, outputs) + group = outputs[JointStateSource.JOINTS] + assert float(group[0]) == pytest.approx(1.5) + assert float(group[1]) == pytest.approx(0.0) + + def test_inactive_sets_none(self): + src = JointStateSource( + name="leader", collection_id="so101_leader", joint_names=SO101_JOINTS + ) + # TrackedT with no data -> device inactive. + inputs = _make_inputs( + src, {"deviceio_joint_state": [JointStateOutputTrackedT()]} + ) + outputs = _outputs(src) + src.compute(inputs, outputs) + assert outputs[JointStateSource.JOINTS].is_none diff --git a/src/core/schema/fbs/joint_state.fbs b/src/core/schema/fbs/joint_state.fbs new file mode 100644 index 000000000..54570ac66 --- /dev/null +++ b/src/core/schema/fbs/joint_state.fbs @@ -0,0 +1,73 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +include "pose.fbs"; +include "timestamp.fbs"; + +namespace core; + +// One actuated degree of freedom as a (name -> value) record. `position` is the primary +// value; `velocity` / `effort` are optional (populated per JointStateOutput.has_velocity / +// has_effort). `name` is the lookup key, so downstream consumers read joints BY NAME +// (order-independent). The gripper is modeled as just another named DOF (conventionally the +// last one, name == "gripper") -- there is no dedicated gripper field. +table JointState { + // Stable DOF name, e.g. "shoulder_pan", "gripper". Lookup key for the joints vector. + name: string (id: 0, key); + + // DOF position. [rad] for revolute joints, [m] for prismatic joints. + position: float (id: 1); + + // DOF velocity, parallel to position. [rad/s | m/s]. Meaningful only when has_velocity. + velocity: float (id: 2); + + // DOF effort (motor torque / current / force). [N*m | A | N depending on device]. + // Meaningful only when has_effort. Present for exoskeletons / bilateral force-feedback rigs. + effort: float (id: 3); + + // Per-DOF validity. Defaults to true; set false when a single joint read is stale/missing. + valid: bool = true (id: 4); +} + +// Per-frame state of a generic joint-space input device (leader arm, exoskeleton, glove, or +// any joint-encoder source), as name:value joint records. All fields are present when the +// parent Tracked/Record wrapper's data is non-null. +table JointStateOutput { + // One entry per actuated DOF, keyed by JointState.name. The reference JointStateSource maps + // these into the configured joint order BY NAME (so wire order does not matter). `name` is a + // FlatBuffers key, so a C++ consumer may LookupByKey if the vector is built sorted + // (flatbuffers::CreateVectorOfSortedTables); the reference plugin/source do not require it. + joints: [JointState] (id: 0); + + // Stable device identity for routing / calibration / multi-device setups + // (e.g. "so101_leader", or a serial number). Static per session. + device_id: string (id: 1); + + // Which optional per-joint channels are populated this session. + has_velocity: bool (id: 2); + has_effort: bool (id: 3); + + // OPTIONAL end-effector pose in the device base frame, for devices that run on-device forward + // kinematics. RESERVED: the reference JointStateRetargeter computes FK from the joint + // positions and does not consume this field yet; it is provided so device-side FK can be + // adopted without a schema change. Position [m], orientation quaternion [x, y, z, w]. + ee_pose: Pose (id: 4); + ee_pose_valid: bool (id: 5); +} + +// Tracked wrapper for the in-memory tracker API (data is null when the device is inactive). +table JointStateOutputTracked { + data: JointStateOutput (id: 0); +} + +// MCAP recording wrapper for JointStateOutput. +// +// Record types are the root types written to MCAP channels by the McapRecorder. +// Trackers serialize into Record types via their serialize() method, but the +// public query API returns the inner data type directly. +table JointStateOutputRecord { + data: JointStateOutput (id: 0); + timestamp: DeviceDataTimestamp (id: 1); +} + +root_type JointStateOutputRecord; diff --git a/src/core/schema/python/CMakeLists.txt b/src/core/schema/python/CMakeLists.txt index d948e1417..4d02a4e8d 100644 --- a/src/core/schema/python/CMakeLists.txt +++ b/src/core/schema/python/CMakeLists.txt @@ -7,6 +7,7 @@ pybind11_add_module(schema_py full_body_bindings.h hand_bindings.h head_bindings.h + joint_state_bindings.h message_channel_bindings.h pedals_bindings.h pose_bindings.h diff --git a/src/core/schema/python/joint_state_bindings.h b/src/core/schema/python/joint_state_bindings.h new file mode 100644 index 000000000..0370f9877 --- /dev/null +++ b/src/core/schema/python/joint_state_bindings.h @@ -0,0 +1,116 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +// Python bindings for the JointState FlatBuffer schema. +// Types: JointState (table), JointStateOutput (table), and the Tracked / Record wrappers. + +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace py = pybind11; + +namespace core +{ + +inline void bind_joint_state(py::module& m) +{ + // One named DOF (name -> position [+ optional velocity/effort/valid]). + py::class_>(m, "JointState") + .def(py::init([]() { return std::make_shared(); })) + .def(py::init( + [](const std::string& name, float position, float velocity, float effort, bool valid) + { + auto obj = std::make_shared(); + obj->name = name; + obj->position = position; + obj->velocity = velocity; + obj->effort = effort; + obj->valid = valid; + return obj; + }), + py::arg("name"), py::arg("position") = 0.0f, py::arg("velocity") = 0.0f, py::arg("effort") = 0.0f, + py::arg("valid") = true) + .def_property( + "name", [](const JointStateT& self) { return self.name; }, + [](JointStateT& self, const std::string& val) { self.name = val; }) + .def_property( + "position", [](const JointStateT& self) { return self.position; }, + [](JointStateT& self, float val) { self.position = val; }) + .def_property( + "velocity", [](const JointStateT& self) { return self.velocity; }, + [](JointStateT& self, float val) { self.velocity = val; }) + .def_property( + "effort", [](const JointStateT& self) { return self.effort; }, + [](JointStateT& self, float val) { self.effort = val; }) + .def_property( + "valid", [](const JointStateT& self) { return self.valid; }, + [](JointStateT& self, bool val) { self.valid = val; }) + .def("__repr__", [](const JointStateT& self) + { return "JointState(name=" + self.name + ", position=" + std::to_string(self.position) + ")"; }); + + // Per-frame device state: a list of named joints plus identity / capability flags. + py::class_>(m, "JointStateOutput") + .def(py::init([]() { return std::make_shared(); })) + .def_property( + "joints", [](const JointStateOutputT& self) { return self.joints; }, + [](JointStateOutputT& self, std::vector> val) { self.joints = std::move(val); }) + .def_property( + "device_id", [](const JointStateOutputT& self) { return self.device_id; }, + [](JointStateOutputT& self, const std::string& val) { self.device_id = val; }) + .def_property( + "has_velocity", [](const JointStateOutputT& self) { return self.has_velocity; }, + [](JointStateOutputT& self, bool val) { self.has_velocity = val; }) + .def_property( + "has_effort", [](const JointStateOutputT& self) { return self.has_effort; }, + [](JointStateOutputT& self, bool val) { self.has_effort = val; }) + .def_property( + "ee_pose_valid", [](const JointStateOutputT& self) { return self.ee_pose_valid; }, + [](JointStateOutputT& self, bool val) { self.ee_pose_valid = val; }) + .def("__repr__", + [](const JointStateOutputT& self) { + return "JointStateOutput(device_id=" + self.device_id + + ", joints=" + std::to_string(self.joints.size()) + ")"; + }); + + py::class_>(m, "JointStateOutputRecord") + .def(py::init<>()) + .def(py::init( + [](const JointStateOutputT& data, const DeviceDataTimestamp& timestamp) + { + auto obj = std::make_shared(); + obj->data = std::make_shared(data); + obj->timestamp = std::make_shared(timestamp); + return obj; + }), + py::arg("data"), py::arg("timestamp")) + .def_property_readonly( + "data", [](const JointStateOutputRecordT& self) -> std::shared_ptr { return self.data; }) + .def_readonly("timestamp", &JointStateOutputRecordT::timestamp); + + py::class_>(m, "JointStateOutputTrackedT") + .def(py::init<>()) + .def(py::init( + [](const JointStateOutputT& data) + { + auto obj = std::make_shared(); + obj->data = std::make_shared(data); + return obj; + }), + py::arg("data")) + .def_property_readonly( + "data", [](const JointStateOutputTrackedT& self) -> std::shared_ptr { return self.data; }) + .def("__repr__", + [](const JointStateOutputTrackedT& self) { + return std::string("JointStateOutputTrackedT(data=") + (self.data ? "JointStateOutput(...)" : "None") + + ")"; + }); +} + +} // namespace core diff --git a/src/core/schema/python/schema_init.py b/src/core/schema/python/schema_init.py index 3f3aeb108..478847dad 100644 --- a/src/core/schema/python/schema_init.py +++ b/src/core/schema/python/schema_init.py @@ -35,6 +35,11 @@ Generic3AxisPedalOutput, Generic3AxisPedalOutputTrackedT, Generic3AxisPedalOutputRecord, + # Joint-state types (generic joint-space devices: leader arms, exoskeletons, ...). + JointState, + JointStateOutput, + JointStateOutputTrackedT, + JointStateOutputRecord, # Message channel types. MessageChannelMessages, MessageChannelMessagesTrackedT, @@ -82,6 +87,11 @@ "Generic3AxisPedalOutput", "Generic3AxisPedalOutputTrackedT", "Generic3AxisPedalOutputRecord", + # Joint-state types (generic joint-space devices). + "JointState", + "JointStateOutput", + "JointStateOutputTrackedT", + "JointStateOutputRecord", # Message channel types. "MessageChannelMessages", "MessageChannelMessagesTrackedT", diff --git a/src/core/schema/python/schema_module.cpp b/src/core/schema/python/schema_module.cpp index e20dae586..b08e3bfcd 100644 --- a/src/core/schema/python/schema_module.cpp +++ b/src/core/schema/python/schema_module.cpp @@ -10,6 +10,7 @@ #include "full_body_bindings.h" #include "hand_bindings.h" #include "head_bindings.h" +#include "joint_state_bindings.h" #include "message_channel_bindings.h" #include "oak_bindings.h" #include "pedals_bindings.h" @@ -40,6 +41,9 @@ PYBIND11_MODULE(_schema, m) // Bind pedals types (Generic3AxisPedalOutput table). core::bind_pedals(m); + // Bind joint-state types (JointState, JointStateOutput tables) for generic joint-space devices. + core::bind_joint_state(m); + // Bind message channel types (MessageChannelMessages table). core::bind_message_channel(m); diff --git a/src/plugins/so101_leader/CMakeLists.txt b/src/plugins/so101_leader/CMakeLists.txt new file mode 100644 index 000000000..01d4c228e --- /dev/null +++ b/src/plugins/so101_leader/CMakeLists.txt @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +add_executable(so101_leader_plugin + main.cpp + so101_leader_plugin.cpp +) + +target_link_libraries(so101_leader_plugin PRIVATE + pusherio::pusherio + oxr::oxr_core + isaacteleop_schema +) + +install(TARGETS so101_leader_plugin RUNTIME DESTINATION plugins/so101_leader) +install(FILES plugin.yaml README.md DESTINATION plugins/so101_leader) diff --git a/src/plugins/so101_leader/README.md b/src/plugins/so101_leader/README.md new file mode 100644 index 000000000..73fe08632 --- /dev/null +++ b/src/plugins/so101_leader/README.md @@ -0,0 +1,31 @@ + + +# SO-101 Leader Arm plugin + +Streams the SO-101 (5-DOF arm + gripper) leader joint angles as a `JointStateOutput` FlatBuffer +over the OpenXR tensor transport, using the generic **joint-space device** path +(`JointStateTracker` / `JointStateSource` / `JointStateRetargeter`). + +The SO-101 reads 6 Feetech STS3215 bus servos over a serial port. To keep the example +hardware-free and headless, the plugin ships a **synthetic backend** by default; the real +Feetech/serial read is the marked seam in `So101LeaderPlugin::read_hardware()`. + +## Run + +```bash +# Synthetic backend (no hardware): +./install/plugins/so101_leader/so101_leader_plugin + +# With a serial device path + custom collection id (real backend is a TODO seam): +./install/plugins/so101_leader/so101_leader_plugin /dev/ttyACM0 so101_leader +``` + +The consumer side creates a `JointStateTracker("so101_leader")` (via +`JointStateSource(name=..., collection_id="so101_leader", joint_names=[...])`) on the same +`collection_id`. See `examples/teleop/python/joint_space_device_example.py` for the retargeting +pipeline (joint-mirror and task-space EE modes). + +DOF order / names: `shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper`. diff --git a/src/plugins/so101_leader/main.cpp b/src/plugins/so101_leader/main.cpp new file mode 100644 index 000000000..9450a2a18 --- /dev/null +++ b/src/plugins/so101_leader/main.cpp @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "so101_leader_plugin.hpp" + +#include +#include +#include +#include +#include + +using namespace plugins::so101_leader; + +int main(int argc, char** argv) +try +{ + // Empty device_path selects the synthetic backend (no hardware required). + const std::string device_path = (argc > 1) ? argv[1] : ""; + const std::string collection_id = (argc > 2) ? argv[2] : "so101_leader"; + + std::cout << "SO-101 Leader Arm (device: " << (device_path.empty() ? "" : device_path) + << ", collection: " << collection_id << ")" << std::endl; + + So101LeaderPlugin plugin(device_path, collection_id); + + // Push joint state at 90 Hz. + const auto frame_duration = std::chrono::nanoseconds(1000000000 / 90); + const auto program_start = std::chrono::steady_clock::now(); + std::size_t frame_count = 0; + + while (true) + { + plugin.update(); + frame_count++; + std::this_thread::sleep_until(program_start + frame_duration * frame_count); + } + + return 0; +} +catch (const std::exception& e) +{ + std::cerr << argv[0] << ": " << e.what() << std::endl; + return 1; +} +catch (...) +{ + std::cerr << argv[0] << ": Unknown error" << std::endl; + return 1; +} diff --git a/src/plugins/so101_leader/plugin.yaml b/src/plugins/so101_leader/plugin.yaml new file mode 100644 index 000000000..4cf0c9eb9 --- /dev/null +++ b/src/plugins/so101_leader/plugin.yaml @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +name: so101_leader +description: "SO-101 (5-DOF + gripper) leader arm joint angles as JointStateOutput" +command: "./so101_leader_plugin" +version: "1.0.0" +devices: + - path: "/arm/so101_leader" + type: "joint_state" + description: "SO-101 leader arm (6 Feetech STS3215 servos); synthetic backend by default" diff --git a/src/plugins/so101_leader/so101_leader_plugin.cpp b/src/plugins/so101_leader/so101_leader_plugin.cpp new file mode 100644 index 000000000..abb94db30 --- /dev/null +++ b/src/plugins/so101_leader/so101_leader_plugin.cpp @@ -0,0 +1,116 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "so101_leader_plugin.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace plugins +{ +namespace so101_leader +{ + +namespace +{ + +// Must agree with JointStateTracker::DEFAULT_MAX_FLATBUFFER_SIZE on the consumer side; sizes the +// fixed tensor buffer (6 named joints + optional channels fit comfortably). +constexpr size_t kMaxFlatbufferSize = 4096; + +// SO-101 DOF order (matches Simulation/SO101/so101_new_calib.urdf and the schema name keys). +constexpr std::array kJointNames = { "shoulder_pan", "shoulder_lift", "elbow_flex", + "wrist_flex", "wrist_roll", "gripper" }; + +constexpr double kSynthAmplitude = 0.6; // [rad] arm-joint motion amplitude for the synthetic signal +constexpr double kSynthPeriodFrames = 90.0; // one cycle per ~1 s at 90 Hz + +} // namespace + +So101LeaderPlugin::So101LeaderPlugin(const std::string& device_path, const std::string& collection_id) + : device_path_(device_path), + collection_id_(collection_id), + session_(std::make_shared("So101LeaderPlugin", core::SchemaPusher::get_required_extensions())), + pusher_(session_->get_handles(), + core::SchemaPusherConfig{ .collection_id = collection_id, + .max_flatbuffer_size = kMaxFlatbufferSize, + .tensor_identifier = "joint_state", + .localized_name = "SO-101 Leader Arm", + .app_name = "So101LeaderPlugin" }) +{ + // This reference ships the synthetic backend only; the real Feetech read is a seam in + // read_hardware(). A device_path is accepted for that future backend but ignored for now. + if (!device_path_.empty()) + { + std::cout << "So101LeaderPlugin: device path " << device_path_ + << " given, but the Feetech serial backend is not yet implemented; " + "using synthetic data (see read_hardware())" + << std::endl; + } + else + { + std::cout << "So101LeaderPlugin: using synthetic joint backend" << std::endl; + } +} + +So101LeaderPlugin::~So101LeaderPlugin() = default; + +void So101LeaderPlugin::read_hardware() +{ + // SEAM: real hardware read goes here. + // + // For the SO-101 leader this reads the 6 Feetech STS3215 bus servos over `device_path_` + // (using LeRobot's calibration to convert ticks -> radians) into positions_, in kJointNames + // order. Until that is wired up, synthesize a smooth, phase-shifted trajectory so the full + // device -> tracker -> retargeter path can run with no hardware. + const double phase = 2.0 * std::numbers::pi * static_cast(frame_) / kSynthPeriodFrames; + for (size_t i = 0; i < kJointNames.size() - 1; ++i) + { + positions_[i] = kSynthAmplitude * std::sin(phase + 0.5 * static_cast(i)); + } + // Gripper: normalized open/close oscillation in [0, 1]. + positions_[kJointNames.size() - 1] = 0.5 * (1.0 + std::sin(phase)); +} + +void So101LeaderPlugin::push_current_state() +{ + core::JointStateOutputT out; + out.device_id = collection_id_; + out.has_velocity = false; + out.has_effort = false; + out.ee_pose_valid = false; + for (size_t i = 0; i < kJointNames.size(); ++i) + { + auto joint = std::make_shared(); + joint->name = kJointNames[i]; + joint->position = static_cast(positions_[i]); + joint->valid = true; + out.joints.push_back(std::move(joint)); + } + + const auto sample_time_ns = core::os_monotonic_now_ns(); + + flatbuffers::FlatBufferBuilder builder(kMaxFlatbufferSize); + auto offset = core::JointStateOutput::Pack(builder, &out); + builder.Finish(offset); + pusher_.push_buffer(builder.GetBufferPointer(), builder.GetSize(), sample_time_ns, sample_time_ns); +} + +void So101LeaderPlugin::update() +{ + read_hardware(); + push_current_state(); + ++frame_; +} + +} // namespace so101_leader +} // namespace plugins diff --git a/src/plugins/so101_leader/so101_leader_plugin.hpp b/src/plugins/so101_leader/so101_leader_plugin.hpp new file mode 100644 index 000000000..85df319f6 --- /dev/null +++ b/src/plugins/so101_leader/so101_leader_plugin.hpp @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include +#include + +namespace core +{ +class OpenXRSession; +} + +namespace plugins +{ +namespace so101_leader +{ + +/*! + * @brief Streams SO-101 (5-DOF + gripper) leader-arm joint angles as ``JointStateOutput`` via + * OpenXR ``SchemaPusher``, on the generic joint-space device path. + * + * The SO-101 reads 6 Feetech STS3215 bus servos over a serial port (LeRobot's ``FeetechMotorsBus`` + * + calibration). To keep the example hardware-free and headless, this plugin ships a + * **synthetic backend** that emits a smooth joint trajectory; ``read_hardware()`` is the marked + * seam where a real serial/Feetech read replaces the synthetic values. + */ +class So101LeaderPlugin +{ +public: + /*! + * @param device_path Serial device path (e.g. /dev/ttyACM0) for the real Feetech backend + * (see read_hardware()); the synthetic backend ignores it. Empty for synthetic-only. + * @param collection_id Tensor collection id; must match the consumer's JointStateTracker. + * Also used as the JointStateOutput.device_id. + */ + So101LeaderPlugin(const std::string& device_path, const std::string& collection_id); + ~So101LeaderPlugin(); + + void update(); + +private: + // Fill positions_ (in kJointNames order) with the latest joint angles. This reference ships a + // synthetic trajectory; SEAM: replace the body with a real Feetech/serial read for hardware. + void read_hardware(); + void push_current_state(); + + std::string device_path_; + std::string collection_id_; + int64_t frame_ = 0; + double positions_[6] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 }; + + std::shared_ptr session_; + core::SchemaPusher pusher_; +}; + +} // namespace so101_leader +} // namespace plugins diff --git a/src/retargeters/CMakeLists.txt b/src/retargeters/CMakeLists.txt index 381c23c9f..a802e2059 100644 --- a/src/retargeters/CMakeLists.txt +++ b/src/retargeters/CMakeLists.txt @@ -10,6 +10,8 @@ if(BUILD_PYTHON_BINDINGS) "${CMAKE_BINARY_DIR}/python_package/$/isaacteleop/retargeters/__pycache__" COMMAND ${CMAKE_COMMAND} -E rm -rf "${CMAKE_BINARY_DIR}/python_package/$/isaacteleop/retargeters/G1/__pycache__" + COMMAND ${CMAKE_COMMAND} -E rm -rf + "${CMAKE_BINARY_DIR}/python_package/$/isaacteleop/retargeters/joint_space/__pycache__" COMMAND ${CMAKE_COMMAND} -E rm -f "${CMAKE_BINARY_DIR}/python_package/$/isaacteleop/retargeters/CMakeLists.txt" COMMAND ${CMAKE_COMMAND} -E rm -f diff --git a/src/retargeters/__init__.py b/src/retargeters/__init__.py index afd470c5b..24a6d2569 100644 --- a/src/retargeters/__init__.py +++ b/src/retargeters/__init__.py @@ -16,6 +16,7 @@ - LocomotionRootCmdRetargeter: Locomotion from controller inputs - FootPedalRootCmdRetargeter: Root command from 3-axis foot pedal (horizontal/vertical + rudder) - GripperRetargeter: Pinch-based gripper control + - JointStateRetargeter: Generic joint-space device (leader arm, exoskeleton) -> joint or EE action - SharpaHandRetargeter: Pinocchio/Pink IK-based retargeting for Sharpa hand - SharpaBiManualRetargeter: Bimanual version of SharpaHandRetargeter - Se3AbsRetargeter: Absolute EE pose control @@ -98,6 +99,17 @@ # .gripper_retargeter "GripperRetargeter": (".gripper_retargeter", "GripperRetargeter", None), "GripperRetargeterConfig": (".gripper_retargeter", "GripperRetargeterConfig", None), + # .joint_space (generic joint-space devices: leader arms, exoskeletons, ...) + "JointStateRetargeter": ( + ".joint_space.joint_state_retargeter", + "JointStateRetargeter", + None, + ), + "JointStateRetargeterConfig": ( + ".joint_space.joint_state_retargeter", + "JointStateRetargeterConfig", + None, + ), # .se3_retargeter (requires retargeters-lite extra: scipy) "Se3AbsRetargeter": (".se3_retargeter", "Se3AbsRetargeter", "retargeters-lite"), "Se3RelRetargeter": (".se3_retargeter", "Se3RelRetargeter", "retargeters-lite"), @@ -189,6 +201,9 @@ def __getattr__(name: str): # Manipulator retargeters "GripperRetargeter", "GripperRetargeterConfig", + # Generic joint-space device retargeters (leader arms, exoskeletons, ...) + "JointStateRetargeter", + "JointStateRetargeterConfig", "Se3AbsRetargeter", "Se3RelRetargeter", "Se3RetargeterConfig", diff --git a/src/retargeters/joint_space/__init__.py b/src/retargeters/joint_space/__init__.py new file mode 100644 index 000000000..52a7a9daf --- /dev/null +++ b/src/retargeters/joint_space/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/src/retargeters/joint_space/joint_state_retargeter.py b/src/retargeters/joint_space/joint_state_retargeter.py new file mode 100644 index 000000000..28b2059e2 --- /dev/null +++ b/src/retargeters/joint_space/joint_state_retargeter.py @@ -0,0 +1,327 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Generic joint-space device retargeter (leader arms, exoskeletons, ...). + +A single retargeter that maps a name-keyed joint-state input (produced by +:class:`~isaacteleop.retargeting_engine.deviceio_source_nodes.JointStateSource`, which converts +the ``JointStateOutput`` FlatBuffer schema) into an Isaac Lab action, in one of two modes: + +* ``mode="joint"`` -- pass the device joints straight through to robot joint targets, remapped + by name with an optional per-joint affine (``offset + sign * scale * value``). This is the + lossless leader -> follower mirror used for same-kinematics teleoperation. No extra + dependencies. +* ``mode="ee_pose"`` -- forward-kinematics the device joints (via a URDF, using ``pinocchio``) + into a 7-D end-effector pose ``[x, y, z, qx, qy, qz, qw]`` plus a scalar gripper command, + for task-space / cross-embodiment teleoperation. ``pinocchio`` is imported lazily so that + ``mode="joint"`` never requires it. + +The output element names are chosen so a downstream +:class:`~isaacteleop.retargeters.TensorReorderer` can flatten them into the exact action layout +an Isaac Lab environment expects. See ``examples/teleop/python/joint_space_device_example.py`` +for an end-to-end reference. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +import numpy as np + +from isaacteleop.retargeting_engine.interface import ( + BaseRetargeter, + RetargeterIOType, +) +from isaacteleop.retargeting_engine.interface.execution_events import ExecutionState +from isaacteleop.retargeting_engine.interface.retargeter_core_types import RetargeterIO +from isaacteleop.retargeting_engine.interface.tensor_group_type import ( + OptionalType, + TensorGroupType, +) +from isaacteleop.retargeting_engine.tensor_types import ( + DLDataType, + FloatType, + NDArrayType, + TransformMatrix, +) + +# Output group / element keys (single source of truth for the pipeline wiring). +JOINT_TARGETS_KEY = "joint_targets" +EE_POSE_KEY = "ee_pose" +GRIPPER_COMMAND_KEY = "gripper_command" +GRIPPER_ELEMENT_LABEL = "gripper_value" + +_IDENTITY_QUAT = np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float32) +_MATRIX_INDEX = 0 + + +@dataclass +class JointStateRetargeterConfig: + """Configuration for :class:`JointStateRetargeter`. + + Args: + device_joints: Ordered device DOF names, matching ``JointStateOutput.joints[*].name`` + and the names declared by the upstream ``JointStateSource``. + target_joints: ``joint`` mode only -- ordered robot joint names to emit, one output + element each. Defaults to ``device_joints`` (identity mirror). + joint_map: ``joint`` mode only -- ``{device_name: target_name}`` overrides; any target + not covered maps from the device joint of the same name. + scale: ``joint`` mode only -- per-target multiplicative gain (e.g. gear ratio / unit + conversion). Defaults to 1.0. + offset: ``joint`` mode only -- per-target additive offset [rad or m]. Defaults to 0.0. + sign: ``joint`` mode only -- per-target sign (+1 / -1). Defaults to +1.0. + urdf_path: ``ee_pose`` mode only -- path to the device URDF used for forward kinematics. + ee_link: ``ee_pose`` mode only -- URDF frame/link name of the end-effector (tool) frame. + gripper_joint: ``ee_pose`` mode only -- device DOF name treated as the gripper. + gripper_open: ``ee_pose`` mode only -- gripper DOF value at fully open. When both + ``gripper_open`` and ``gripper_close`` are set, the emitted gripper command is the + normalized closedness in ``[0, 1]``; otherwise the raw value is passed through. + gripper_close: ``ee_pose`` mode only -- gripper DOF value at fully closed. + clutch: ``ee_pose`` mode only -- when true, rebase the EE position around an origin + captured on the first ``RUNNING`` frame so engaging teleop does not jump the robot; + the home is the live ``robot_ee_pos`` input when connected, else the FK pose at + engage. When false (default) the absolute FK pose is emitted. + """ + + device_joints: list[str] + target_joints: list[str] = field(default_factory=list) + joint_map: dict[str, str] = field(default_factory=dict) + scale: dict[str, float] = field(default_factory=dict) + offset: dict[str, float] = field(default_factory=dict) + sign: dict[str, float] = field(default_factory=dict) + urdf_path: str | None = None + ee_link: str | None = None + gripper_joint: str = "gripper" + gripper_open: float | None = None + gripper_close: float | None = None + clutch: bool = False + + +class JointStateRetargeter(BaseRetargeter): + """Maps a name-keyed joint-state input to a robot action in ``joint`` or ``ee_pose`` mode. + + Input (both modes): + - :data:`JOINTS` -- Optional name-keyed group with one ``FloatType`` per + ``config.device_joints`` entry (the joint positions). + + Input (``ee_pose`` mode, only when ``config.clutch``): + - :data:`ROBOT_EE_POS_INPUT` -- Optional ``world_T_ee`` 4x4 transform of the robot's + current end-effector, used to latch the clutch home on engage. + + Output (``joint`` mode): + - :data:`JOINT_TARGETS_KEY` -- one ``FloatType`` per ``config.target_joints`` entry. + + Output (``ee_pose`` mode): + - :data:`EE_POSE_KEY` -- a single 7-D ``NDArray`` ``[x, y, z, qx, qy, qz, qw]``. + - :data:`GRIPPER_COMMAND_KEY` -- a single ``FloatType`` gripper command. + + Note: + The :data:`JOINTS` input is read positionally in ``config.device_joints`` order, so the + upstream producer (e.g. ``JointStateSource``) must declare the same names in the same + order; a name mismatch is rejected by the graph's type check at ``connect`` time. + ``ee_pose`` mode ignores the schema's ``ee_pose`` field and always computes forward + kinematics from the joint positions. + """ + + JOINTS = "joints" + ROBOT_EE_POS_INPUT = "robot_ee_pos" + + def __init__( + self, name: str, mode: str, config: JointStateRetargeterConfig + ) -> None: + """Initialize the joint-space retargeter. + + Args: + name: Name identifier for this retargeter node. + mode: ``"joint"`` or ``"ee_pose"``. + config: Device / mode configuration. + """ + if mode not in ("joint", "ee_pose"): + raise ValueError(f"mode must be 'joint' or 'ee_pose', got: {mode!r}") + self._mode = mode + self._cfg = config + + if mode == "joint": + self._target_joints = list(config.target_joints) or list( + config.device_joints + ) + # Per target joint, the device joint that feeds it (inverse of joint_map, which is + # device -> target). Targets not covered map from the device joint of the same name. + self._device_for_target: dict[str, str] = { + tgt: next((d for d, t in config.joint_map.items() if t == tgt), tgt) + for tgt in self._target_joints + } + self._last_targets = np.zeros(len(self._target_joints), dtype=np.float32) + else: + if not config.urdf_path or not config.ee_link: + raise ValueError( + "ee_pose mode requires config.urdf_path and config.ee_link" + ) + self._fk = _UrdfForwardKinematics(config.urdf_path, config.ee_link) + self._origin: np.ndarray | None = None + self._home = np.zeros(3, dtype=np.float64) + self._last_pose = np.concatenate([np.zeros(3), _IDENTITY_QUAT]).astype( + np.float32 + ) + self._last_gripper = 0.0 + + super().__init__(name=name) + + # ------------------------------------------------------------------ specs + + def input_spec(self) -> RetargeterIOType: + joints_type = TensorGroupType( + self.JOINTS, [FloatType(n) for n in self._cfg.device_joints] + ) + spec: RetargeterIOType = {self.JOINTS: OptionalType(joints_type)} + if self._mode == "ee_pose" and self._cfg.clutch: + spec[self.ROBOT_EE_POS_INPUT] = OptionalType(TransformMatrix()) + return spec + + def output_spec(self) -> RetargeterIOType: + if self._mode == "joint": + return { + JOINT_TARGETS_KEY: TensorGroupType( + JOINT_TARGETS_KEY, [FloatType(n) for n in self._target_joints] + ) + } + return { + EE_POSE_KEY: TensorGroupType( + EE_POSE_KEY, + [ + NDArrayType( + "pose", shape=(7,), dtype=DLDataType.FLOAT, dtype_bits=32 + ) + ], + ), + GRIPPER_COMMAND_KEY: TensorGroupType( + GRIPPER_COMMAND_KEY, [FloatType(GRIPPER_ELEMENT_LABEL)] + ), + } + + # ---------------------------------------------------------------- compute + + def _compute_fn(self, inputs: RetargeterIO, outputs: RetargeterIO, context) -> None: + if self._mode == "joint": + self._compute_joint(inputs, outputs, context) + else: + self._compute_ee(inputs, outputs, context) + + def _read_positions(self, joints_group) -> dict[str, float]: + """Read the name-keyed joint group into a ``{device_name: position}`` dict.""" + return { + name: float(joints_group[i]) + for i, name in enumerate(self._cfg.device_joints) + } + + def _compute_joint(self, inputs, outputs, context) -> None: + if context.execution_events.reset: + self._last_targets = np.zeros(len(self._target_joints), dtype=np.float32) + + out = outputs[JOINT_TARGETS_KEY] + jin = inputs[self.JOINTS] + if jin.is_none: + for i in range(len(self._target_joints)): + out[i] = float(self._last_targets[i]) + return + + positions = self._read_positions(jin) + for i, tgt in enumerate(self._target_joints): + raw = positions.get(self._device_for_target[tgt], 0.0) + value = ( + self._cfg.offset.get(tgt, 0.0) + + self._cfg.sign.get(tgt, 1.0) * self._cfg.scale.get(tgt, 1.0) * raw + ) + self._last_targets[i] = value + out[i] = float(value) + + def _compute_ee(self, inputs, outputs, context) -> None: + running = context.execution_events.execution_state == ExecutionState.RUNNING + if context.execution_events.reset or (self._cfg.clutch and not running): + self._origin = None + self._last_pose = np.concatenate([np.zeros(3), _IDENTITY_QUAT]).astype( + np.float32 + ) + self._last_gripper = 0.0 + + ee_out = outputs[EE_POSE_KEY] + grip_out = outputs[GRIPPER_COMMAND_KEY] + jin = inputs[self.JOINTS] + if jin.is_none: + ee_out[0] = self._last_pose + grip_out[0] = self._last_gripper + return + + positions = self._read_positions(jin) + fk_pose = self._fk.solve(positions) # [x, y, z, qx, qy, qz, qw] + + if self._cfg.clutch: + if self._origin is None: + if not running: + ee_out[0] = self._last_pose + grip_out[0] = self._compute_gripper(positions) + return + self._origin = fk_pose[:3].copy() + self._home = self._latch_home(inputs, fk_pose) + position = self._home + (fk_pose[:3] - self._origin) + else: + position = fk_pose[:3] + + self._last_pose = np.concatenate([position, fk_pose[3:7]]).astype(np.float32) + self._last_gripper = self._compute_gripper(positions) + ee_out[0] = self._last_pose + grip_out[0] = self._last_gripper + + def _latch_home(self, inputs: RetargeterIO, fk_pose: np.ndarray) -> np.ndarray: + """Home for clutch rebasing: live robot EE position if connected, else the FK pose.""" + ee_inp = inputs.get(self.ROBOT_EE_POS_INPUT) + if ee_inp is not None and not ee_inp.is_none: + world_T_ee = np.from_dlpack(ee_inp[_MATRIX_INDEX]).astype(np.float64) + return world_T_ee[:3, 3].copy() + return fk_pose[:3].copy() + + def _compute_gripper(self, positions: dict[str, float]) -> float: + raw = positions.get(self._cfg.gripper_joint, 0.0) + lo, hi = self._cfg.gripper_open, self._cfg.gripper_close + if lo is None or hi is None or hi == lo: + return float(raw) + c = (raw - lo) / (hi - lo) + return float(min(1.0, max(0.0, c))) + + +class _UrdfForwardKinematics: + """Lazy ``pinocchio`` forward-kinematics helper for a URDF end-effector frame.""" + + def __init__(self, urdf_path: str, ee_link: str) -> None: + try: + import pinocchio as pin # noqa: F401 + except ImportError as exc: + raise ModuleNotFoundError( + "JointStateRetargeter(mode='ee_pose') requires pinocchio.\n" + "Install it with: pip install 'isaacteleop[retargeters]' (or: pip install pin)" + ) from exc + self._pin = pin + self._model = pin.buildModelFromUrdf(urdf_path) + self._data = self._model.createData() + if not self._model.existFrame(ee_link): + raise ValueError(f"ee_link {ee_link!r} not found in URDF {urdf_path!r}") + self._frame_id = self._model.getFrameId(ee_link) + + def solve(self, positions: dict[str, float]) -> np.ndarray: + """Forward-kinematics the named joint positions to a 7-D ``[x,y,z,qx,qy,qz,qw]`` pose. + + Assumes a fixed-base model of single-DOF joints (revolute/prismatic) -- the common case + for leader arms and exoskeletons -- writing one configuration value per named joint. + Names not present in the URDF (e.g. the gripper) are ignored for the EE pose. + """ + pin = self._pin + q = pin.neutral(self._model) + for name, value in positions.items(): + if self._model.existJointName(name): + jid = self._model.getJointId(name) + q[self._model.joints[jid].idx_q] = value + pin.forwardKinematics(self._model, self._data, q) + pin.updateFramePlacements(self._model, self._data) + return np.asarray( + pin.SE3ToXYZQUAT(self._data.oMf[self._frame_id]), dtype=np.float32 + ) From 5f9b6aba10ca9e8a1f16d10a23601190ea560470 Mon Sep 17 00:00:00 2001 From: Rafael Wiltz Date: Thu, 11 Jun 2026 10:59:24 -0400 Subject: [PATCH 2/5] feat: add FEETECH serial backend to SO-101 leader plugin Wire the SO-101 leader plugin to real hardware so the joint-space pipeline can run end-to-end with a physical arm. Add FeetechBus, a minimal half-duplex client for the FEETECH SMS/STS bus servos (STS3215) that speaks the same wire protocol as the FEETECH SCServo SDK / LeRobot FeetechMotorsBus without an SDK dependency (POSIX termios only): disable torque so the leader can be back-driven, then read Present_Position each frame and convert ticks to radians with per-joint calibration. When no serial device path is given the plugin keeps its synthetic trajectory, so CI and the headless example still run hardware-free. A device path activates the live backend; calibration (servo id, sign, home tick) comes from an optional file and defaults to ids 1..6 / +1 / center 2048. The serial backend is POSIX-only; Windows compiles to a throwing stub and uses the synthetic fallback. --- docs/source/device/joint_space.rst | 19 +- src/plugins/so101_leader/CMakeLists.txt | 1 + src/plugins/so101_leader/README.md | 54 ++- src/plugins/so101_leader/feetech_bus.cpp | 312 ++++++++++++++++++ src/plugins/so101_leader/feetech_bus.hpp | 61 ++++ src/plugins/so101_leader/main.cpp | 7 +- src/plugins/so101_leader/plugin.yaml | 2 +- .../so101_leader/so101_leader_plugin.cpp | 135 ++++++-- .../so101_leader/so101_leader_plugin.hpp | 47 ++- 9 files changed, 594 insertions(+), 44 deletions(-) create mode 100644 src/plugins/so101_leader/feetech_bus.cpp create mode 100644 src/plugins/so101_leader/feetech_bus.hpp diff --git a/docs/source/device/joint_space.rst b/docs/source/device/joint_space.rst index 1edc34269..d0ea3d99d 100644 --- a/docs/source/device/joint_space.rst +++ b/docs/source/device/joint_space.rst @@ -27,7 +27,8 @@ At a glance optional velocity/effort) and ``JointStateOutput`` (a vector of joints + ``device_id``). * - Plugin - :code-dir:`src/plugins/so101_leader` -- pushes ``JointStateOutput`` via ``SchemaPusher``. - Ships a synthetic backend; the real Feetech/serial read is a marked seam. + Reads the FEETECH STS3215 servos over serial (``FeetechBus``); synthetic fallback when no + device path is given. * - Tracker - ``JointStateTracker`` (facade) with live (``LiveJointStateTrackerImpl``) and MCAP-replay (``ReplayJointStateTrackerImpl``) backends, registered in the live/replay factories. @@ -72,17 +73,23 @@ The SO-101 leader plugin ------------------------ ``so101_leader`` reads the six SO-101 servos (``shoulder_pan, shoulder_lift, elbow_flex, -wrist_flex, wrist_roll, gripper``) and pushes them to a tensor collection. To keep the example -hardware-free and headless it ships a **synthetic backend**; the real Feetech read (via LeRobot's -``FeetechMotorsBus`` + calibration) is the marked seam in ``So101LeaderPlugin::read_hardware()``. +wrist_flex, wrist_roll, gripper``) and pushes them to a tensor collection. With a serial device +path it talks to the FEETECH STS3215 bus servos directly via ``FeetechBus`` -- the same SMS/STS +wire protocol the FEETECH SCServo SDK / LeRobot's ``FeetechMotorsBus`` use, with no SDK dependency: +it disables torque (so the leader can be back-driven) and reads ``Present_Position`` each frame, +converting ticks to radians with per-joint calibration. With no device path it falls back to a +**synthetic** trajectory so the pipeline runs hardware-free (CI and the headless example). .. code-block:: bash # Synthetic backend (no hardware), default collection id "so101_leader": ./install/plugins/so101_leader/so101_leader_plugin - # Reserved for the real serial backend + a custom collection id: - ./install/plugins/so101_leader/so101_leader_plugin /dev/ttyACM0 so101_leader + # Real SO-101 leader on a serial port (Linux), optional calibration file: + ./install/plugins/so101_leader/so101_leader_plugin /dev/ttyACM0 so101_leader so101_leader.calib + +See the :code-file:`plugin README ` for hardware setup +(unique servo ids, gear removal, back-driving) and the calibration file format. The consumer side creates a ``JointStateSource(name=..., collection_id="so101_leader", joint_names=[...])`` on the same ``collection_id``; ``TeleopSession`` discovers and polls the diff --git a/src/plugins/so101_leader/CMakeLists.txt b/src/plugins/so101_leader/CMakeLists.txt index 01d4c228e..7dfa8cdd3 100644 --- a/src/plugins/so101_leader/CMakeLists.txt +++ b/src/plugins/so101_leader/CMakeLists.txt @@ -4,6 +4,7 @@ add_executable(so101_leader_plugin main.cpp so101_leader_plugin.cpp + feetech_bus.cpp ) target_link_libraries(so101_leader_plugin PRIVATE diff --git a/src/plugins/so101_leader/README.md b/src/plugins/so101_leader/README.md index 73fe08632..f4ec5088e 100644 --- a/src/plugins/so101_leader/README.md +++ b/src/plugins/so101_leader/README.md @@ -9,9 +9,15 @@ Streams the SO-101 (5-DOF arm + gripper) leader joint angles as a `JointStateOut over the OpenXR tensor transport, using the generic **joint-space device** path (`JointStateTracker` / `JointStateSource` / `JointStateRetargeter`). -The SO-101 reads 6 Feetech STS3215 bus servos over a serial port. To keep the example -hardware-free and headless, the plugin ships a **synthetic backend** by default; the real -Feetech/serial read is the marked seam in `So101LeaderPlugin::read_hardware()`. +The SO-101 leader is 6 FEETECH STS3215 bus servos on a half-duplex TTL serial bus (the same +hardware [TheRobotStudio/SO-ARM100](https://github.com/TheRobotStudio/SO-ARM100) and HuggingFace +LeRobot drive via the FEETECH SCServo SDK). `FeetechBus` (`feetech_bus.{hpp,cpp}`) speaks that +SMS/STS wire protocol directly — no SDK dependency — and implements just what a *leader* needs: +disable torque so the arm can be back-driven by hand, then read `Present_Position` (register 56, +4096 ticks / 360°) each frame. Ticks are converted to radians with per-joint calibration. + +When no serial device is given, the plugin falls back to a **synthetic** trajectory so the +device → tracker → retargeter pipeline runs with no hardware (CI and the headless example). ## Run @@ -19,10 +25,48 @@ Feetech/serial read is the marked seam in `So101LeaderPlugin::read_hardware()`. # Synthetic backend (no hardware): ./install/plugins/so101_leader/so101_leader_plugin -# With a serial device path + custom collection id (real backend is a TODO seam): -./install/plugins/so101_leader/so101_leader_plugin /dev/ttyACM0 so101_leader +# Real SO-101 leader on a serial port (Linux), default collection id "so101_leader": +./install/plugins/so101_leader/so101_leader_plugin /dev/ttyACM0 + +# ... with a custom collection id and a calibration file: +./install/plugins/so101_leader/so101_leader_plugin /dev/ttyACM0 so101_leader so101_leader.calib ``` +Args are positional: `[device_path] [collection_id] [calibration_file]`. The serial backend is +Linux/macOS only (POSIX `termios`); the STS bus runs at 1,000,000 bps by default. + +### Hardware setup (per SO-ARM100 / LeRobot) + +- Assemble the leader arm and **remove the gearbox gears** so the joints move freely and only the + position encoders are used (the plugin disables torque on connect, but the leader is meant to be + back-driven). +- Give each servo a unique id `1..6` on the bus and set them all to the same baud rate. Use the + FEETECH tool (`FT_SCServo_Debug_Qt` on Ubuntu) or LeRobot's `lerobot-setup-motors` to do this. +- Make sure your user can access the serial device (e.g. add it to the `dialout` group). + +### Calibration file (optional) + +Whitespace-separated, one joint per line; `#` starts a comment. Columns: +`name servo_id sign(+1/-1) home_ticks(0..4095)`. The conversion is +`angle [rad] = sign * (ticks - home_ticks) * 2π / 4096`. + +``` +# joint id sign home_ticks +shoulder_pan 1 1 2048 +shoulder_lift 2 1 2048 +elbow_flex 3 1 2048 +wrist_flex 4 1 2048 +wrist_roll 5 1 2048 +gripper 6 1 2048 +``` + +Defaults (no file): ids `1..6` in DOF order, `sign +1`, `home_ticks 2048` (servo center). Set +`home_ticks` to each servo's raw `Present_Position` at the joint's URDF-zero pose, and `sign` to +`-1` for any joint whose servo turns opposite the URDF convention (LeRobot's `drive_mode`). For +**joint-mirror** mode the retargeter's per-joint `offset`/`sign`/`scale` can also absorb +calibration; for **EE (URDF FK)** mode the joint angles must already match the URDF, so set +`home_ticks`/`sign` here. + The consumer side creates a `JointStateTracker("so101_leader")` (via `JointStateSource(name=..., collection_id="so101_leader", joint_names=[...])`) on the same `collection_id`. See `examples/teleop/python/joint_space_device_example.py` for the retargeting diff --git a/src/plugins/so101_leader/feetech_bus.cpp b/src/plugins/so101_leader/feetech_bus.cpp new file mode 100644 index 000000000..e3b7599c1 --- /dev/null +++ b/src/plugins/so101_leader/feetech_bus.cpp @@ -0,0 +1,312 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#include "feetech_bus.hpp" + +#include +#include + +#ifndef _WIN32 + +# include + +# include +# include +# include +# include +# include +# include + +namespace plugins +{ +namespace so101_leader +{ + +namespace +{ + +// FEETECH SMS/STS protocol constants. +constexpr uint8_t kHeader = 0xFF; +constexpr uint8_t kInstRead = 0x02; +constexpr uint8_t kInstWrite = 0x03; +constexpr uint8_t kRegTorqueEnable = 40; // 1 byte +constexpr uint8_t kRegPresentPosition = 56; // 2 bytes, little-endian (SMS/STS) +constexpr int kReadTimeoutMs = 20; + +// Map a numeric baud rate to the matching termios speed constant. Only the rates a FEETECH bus +// realistically uses are supported; anything else throws (rather than silently mis-configuring). +speed_t to_speed(int baud) +{ + switch (baud) + { +# ifdef B1000000 + case 1000000: + return B1000000; +# endif +# ifdef B500000 + case 500000: + return B500000; +# endif + case 115200: + return B115200; + case 57600: + return B57600; + case 38400: + return B38400; + default: + throw std::runtime_error("FeetechBus: unsupported baud rate " + std::to_string(baud) + + " (STS servos default to 1000000)"); + } +} + +} // namespace + +FeetechBus::FeetechBus(const std::string& port, int baud) +{ + fd_ = ::open(port.c_str(), O_RDWR | O_NOCTTY | O_NONBLOCK); + if (fd_ < 0) + { + throw std::runtime_error("FeetechBus: cannot open '" + port + "': " + std::strerror(errno)); + } + + termios tty{}; + if (::tcgetattr(fd_, &tty) != 0) + { + const std::string msg = std::strerror(errno); + ::close(fd_); + fd_ = -1; + throw std::runtime_error("FeetechBus: tcgetattr failed on '" + port + "': " + msg); + } + + ::cfmakeraw(&tty); + const speed_t spd = to_speed(baud); + ::cfsetispeed(&tty, spd); + ::cfsetospeed(&tty, spd); + + // 8N1, local, receiver enabled, no flow control. select()-driven reads (VMIN/VTIME = 0). + tty.c_cflag |= (CLOCAL | CREAD); + tty.c_cflag &= ~CSTOPB; + tty.c_cflag &= ~PARENB; + tty.c_cflag &= ~CSIZE; + tty.c_cflag |= CS8; +# ifdef CRTSCTS + tty.c_cflag &= ~CRTSCTS; +# endif + tty.c_cc[VMIN] = 0; + tty.c_cc[VTIME] = 0; + + if (::tcsetattr(fd_, TCSANOW, &tty) != 0) + { + const std::string msg = std::strerror(errno); + ::close(fd_); + fd_ = -1; + throw std::runtime_error("FeetechBus: tcsetattr failed on '" + port + "': " + msg); + } + + ::tcflush(fd_, TCIOFLUSH); +} + +FeetechBus::~FeetechBus() +{ + if (fd_ >= 0) + { + ::close(fd_); + } +} + +bool FeetechBus::write_packet(uint8_t id, uint8_t instruction, const uint8_t* params, uint8_t param_count) +{ + const uint8_t length = static_cast(param_count + 2); + std::vector pkt; + pkt.reserve(static_cast(param_count) + 6); + pkt.push_back(kHeader); + pkt.push_back(kHeader); + pkt.push_back(id); + pkt.push_back(length); + pkt.push_back(instruction); + + uint32_t checksum = id + length + instruction; + for (uint8_t i = 0; i < param_count; ++i) + { + pkt.push_back(params[i]); + checksum += params[i]; + } + pkt.push_back(static_cast(~checksum & 0xFF)); + + // Drop any stale/echoed bytes from a previous transaction before issuing this one. + ::tcflush(fd_, TCIFLUSH); + + size_t written = 0; + while (written < pkt.size()) + { + const ssize_t n = ::write(fd_, pkt.data() + written, pkt.size() - written); + if (n < 0) + { + if (errno == EAGAIN || errno == EINTR) + { + continue; + } + return false; + } + written += static_cast(n); + } + ::tcdrain(fd_); + return true; +} + +bool FeetechBus::read_byte(uint8_t& out, int timeout_ms) +{ + fd_set rfds; + FD_ZERO(&rfds); + FD_SET(fd_, &rfds); + timeval tv{}; + tv.tv_sec = timeout_ms / 1000; + tv.tv_usec = (timeout_ms % 1000) * 1000; + + const int ready = ::select(fd_ + 1, &rfds, nullptr, nullptr, &tv); + if (ready <= 0) + { + return false; // timeout or error + } + const ssize_t n = ::read(fd_, &out, 1); + return n == 1; +} + +bool FeetechBus::read_status(uint8_t expected_id, uint8_t* data_out, uint8_t expected_data_len) +{ + // Sync to the 0xFF 0xFF header (tolerates leading noise from bus turnaround). + int prev = -1; + bool synced = false; + for (int i = 0; i < 64 && !synced; ++i) + { + uint8_t b = 0; + if (!read_byte(b, kReadTimeoutMs)) + { + return false; + } + if (prev == kHeader && b == kHeader) + { + synced = true; + } + prev = b; + } + if (!synced) + { + return false; + } + + uint8_t id = 0; + uint8_t length = 0; + if (!read_byte(id, kReadTimeoutMs) || !read_byte(length, kReadTimeoutMs)) + { + return false; + } + // length = error(1) + data(length-2) + checksum(1); guard against malformed lengths. + if (length < 2 || length > 16) + { + return false; + } + + std::vector rest(length); // error + data... + checksum + for (uint8_t i = 0; i < length; ++i) + { + if (!read_byte(rest[i], kReadTimeoutMs)) + { + return false; + } + } + + uint32_t checksum = id + length; + for (uint8_t i = 0; i + 1 < length; ++i) + { + checksum += rest[i]; + } + const uint8_t expected_checksum = static_cast(~checksum & 0xFF); + if (expected_checksum != rest[length - 1] || id != expected_id) + { + return false; + } + + const uint8_t data_len = static_cast(length - 2); + if (data_len != expected_data_len) + { + return false; + } + for (uint8_t i = 0; i < expected_data_len; ++i) + { + data_out[i] = rest[i + 1]; // skip the error byte at rest[0] + } + return true; +} + +bool FeetechBus::read_position(uint8_t id, uint16_t& ticks_out) +{ + const uint8_t params[2] = { kRegPresentPosition, 0x02 }; + if (!write_packet(id, kInstRead, params, 2)) + { + return false; + } + uint8_t data[2] = { 0, 0 }; + if (!read_status(id, data, 2)) + { + return false; + } + ticks_out = static_cast(data[0]) | static_cast(data[1] << 8); + return true; +} + +bool FeetechBus::disable_torque(uint8_t id) +{ + const uint8_t params[2] = { kRegTorqueEnable, 0x00 }; + if (!write_packet(id, kInstWrite, params, 2)) + { + return false; + } + return read_status(id, nullptr, 0); +} + +} // namespace so101_leader +} // namespace plugins + +#else // _WIN32 + +namespace plugins +{ +namespace so101_leader +{ + +FeetechBus::FeetechBus(const std::string& /*port*/, int /*baud*/) +{ + throw std::runtime_error( + "FeetechBus: the serial backend is only implemented on POSIX " + "(Linux/macOS); run the SO-101 leader on Linux, or omit the device " + "path to use the synthetic backend"); +} + +FeetechBus::~FeetechBus() = default; + +bool FeetechBus::write_packet(uint8_t, uint8_t, const uint8_t*, uint8_t) +{ + return false; +} +bool FeetechBus::read_status(uint8_t, uint8_t*, uint8_t) +{ + return false; +} +bool FeetechBus::read_byte(uint8_t&, int) +{ + return false; +} +bool FeetechBus::read_position(uint8_t, uint16_t&) +{ + return false; +} +bool FeetechBus::disable_torque(uint8_t) +{ + return false; +} + +} // namespace so101_leader +} // namespace plugins + +#endif // _WIN32 diff --git a/src/plugins/so101_leader/feetech_bus.hpp b/src/plugins/so101_leader/feetech_bus.hpp new file mode 100644 index 000000000..290751cce --- /dev/null +++ b/src/plugins/so101_leader/feetech_bus.hpp @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +namespace plugins +{ +namespace so101_leader +{ + +/*! + * @brief Minimal half-duplex serial client for FEETECH SMS/STS bus servos (e.g. the STS3215 used + * by the SO-101 / SO-ARM100). + * + * Implements the same wire protocol as the FEETECH SCServo SDK / LeRobot's ``FeetechMotorsBus``, + * but only the subset a *leader* arm needs: + * - read ``Present_Position`` (register 56, 2 bytes, little-endian for the SMS/STS series), and + * - disable torque (register 40) so the arm can be back-driven by hand. + * + * Wire format (Dynamixel-like): ``FF FF ID LEN INST PARAM... CHK``, with + * ``LEN = param_count + 2`` and ``CHK = ~(ID + LEN + INST + PARAMS) & 0xFF``. Default bus speed is + * 1,000,000 bps, 8N1 (the STS factory default). Assumes an auto-direction USB-TTL adapter + * (e.g. FE-URT-1 / Waveshare bus-servo adapter) that does not echo transmitted bytes. + * + * POSIX only (Linux/macOS); constructing on Windows throws. + */ +class FeetechBus +{ +public: + //! Open and configure @p port (e.g. ``/dev/ttyACM0``) at @p baud. Throws ``std::runtime_error`` + //! on failure (or always, on Windows). + explicit FeetechBus(const std::string& port, int baud = 1000000); + ~FeetechBus(); + + FeetechBus(const FeetechBus&) = delete; + FeetechBus& operator=(const FeetechBus&) = delete; + FeetechBus(FeetechBus&&) = delete; + FeetechBus& operator=(FeetechBus&&) = delete; + + //! Read ``Present_Position`` [ticks, 0..4095 over 360 deg] for servo @p id. Returns false on + //! timeout / malformed response so the caller can hold the last value instead of faulting. + bool read_position(uint8_t id, uint16_t& ticks_out); + + //! Write ``Torque_Enable = 0`` (register 40) so servo @p id goes limp and can be moved by hand. + bool disable_torque(uint8_t id); + +private: + bool write_packet(uint8_t id, uint8_t instruction, const uint8_t* params, uint8_t param_count); + //! Read a status packet; copies @p expected_data_len data bytes (after the error byte) into + //! @p data_out (may be null when @p expected_data_len is 0). Validates header, id, and checksum. + bool read_status(uint8_t expected_id, uint8_t* data_out, uint8_t expected_data_len); + bool read_byte(uint8_t& out, int timeout_ms); + + int fd_ = -1; +}; + +} // namespace so101_leader +} // namespace plugins diff --git a/src/plugins/so101_leader/main.cpp b/src/plugins/so101_leader/main.cpp index 9450a2a18..e5393e9a3 100644 --- a/src/plugins/so101_leader/main.cpp +++ b/src/plugins/so101_leader/main.cpp @@ -14,14 +14,17 @@ using namespace plugins::so101_leader; int main(int argc, char** argv) try { + // Usage: so101_leader_plugin [device_path] [collection_id] [calibration_file] // Empty device_path selects the synthetic backend (no hardware required). const std::string device_path = (argc > 1) ? argv[1] : ""; const std::string collection_id = (argc > 2) ? argv[2] : "so101_leader"; + const std::string calibration_path = (argc > 3) ? argv[3] : ""; std::cout << "SO-101 Leader Arm (device: " << (device_path.empty() ? "" : device_path) - << ", collection: " << collection_id << ")" << std::endl; + << ", collection: " << collection_id + << (calibration_path.empty() ? "" : ", calibration: " + calibration_path) << ")" << std::endl; - So101LeaderPlugin plugin(device_path, collection_id); + So101LeaderPlugin plugin(device_path, collection_id, calibration_path); // Push joint state at 90 Hz. const auto frame_duration = std::chrono::nanoseconds(1000000000 / 90); diff --git a/src/plugins/so101_leader/plugin.yaml b/src/plugins/so101_leader/plugin.yaml index 4cf0c9eb9..c0dd6db30 100644 --- a/src/plugins/so101_leader/plugin.yaml +++ b/src/plugins/so101_leader/plugin.yaml @@ -8,4 +8,4 @@ version: "1.0.0" devices: - path: "/arm/so101_leader" type: "joint_state" - description: "SO-101 leader arm (6 Feetech STS3215 servos); synthetic backend by default" + description: "SO-101 leader arm (6 FEETECH STS3215 servos over serial; synthetic fallback when no device path)" diff --git a/src/plugins/so101_leader/so101_leader_plugin.cpp b/src/plugins/so101_leader/so101_leader_plugin.cpp index abb94db30..8a77b323d 100644 --- a/src/plugins/so101_leader/so101_leader_plugin.cpp +++ b/src/plugins/so101_leader/so101_leader_plugin.cpp @@ -3,6 +3,8 @@ #include "so101_leader_plugin.hpp" +#include "feetech_bus.hpp" + #include #include #include @@ -11,9 +13,12 @@ #include #include #include +#include #include #include #include +#include +#include namespace plugins { @@ -28,15 +33,21 @@ namespace constexpr size_t kMaxFlatbufferSize = 4096; // SO-101 DOF order (matches Simulation/SO101/so101_new_calib.urdf and the schema name keys). -constexpr std::array kJointNames = { "shoulder_pan", "shoulder_lift", "elbow_flex", - "wrist_flex", "wrist_roll", "gripper" }; +constexpr std::array kJointNames = { "shoulder_pan", "shoulder_lift", "elbow_flex", + "wrist_flex", "wrist_roll", "gripper" }; + +// FEETECH STS3215: 12-bit magnetic encoder, 4096 ticks per 360 deg. +constexpr double kTicksToRadians = 2.0 * std::numbers::pi / 4096.0; +constexpr int kFeetechBaud = 1000000; // STS factory default constexpr double kSynthAmplitude = 0.6; // [rad] arm-joint motion amplitude for the synthetic signal constexpr double kSynthPeriodFrames = 90.0; // one cycle per ~1 s at 90 Hz } // namespace -So101LeaderPlugin::So101LeaderPlugin(const std::string& device_path, const std::string& collection_id) +So101LeaderPlugin::So101LeaderPlugin(const std::string& device_path, + const std::string& collection_id, + const std::string& calibration_path) : device_path_(device_path), collection_id_(collection_id), session_(std::make_shared("So101LeaderPlugin", core::SchemaPusher::get_required_extensions())), @@ -47,38 +58,115 @@ So101LeaderPlugin::So101LeaderPlugin(const std::string& device_path, const std:: .localized_name = "SO-101 Leader Arm", .app_name = "So101LeaderPlugin" }) { - // This reference ships the synthetic backend only; the real Feetech read is a seam in - // read_hardware(). A device_path is accepted for that future backend but ignored for now. + // Defaults: servo ids 1..6 in DOF order, no sign flip, centered at the servo midpoint (2048). + for (int i = 0; i < kNumJoints; ++i) + { + calibration_[i] = JointCalibration{ static_cast(i + 1), 1.0, 2048 }; + } + if (!calibration_path.empty()) + { + load_calibration(calibration_path); + } + if (!device_path_.empty()) { - std::cout << "So101LeaderPlugin: device path " << device_path_ - << " given, but the Feetech serial backend is not yet implemented; " - "using synthetic data (see read_hardware())" - << std::endl; + // Throws on POSIX if the port can't be opened; throws unconditionally on Windows. + bus_ = std::make_unique(device_path_, kFeetechBaud); + std::cout << "So101LeaderPlugin: FEETECH serial backend on " << device_path_ << std::endl; + + // Leader arm: disable torque so the operator can back-drive it by hand. + for (int i = 0; i < kNumJoints; ++i) + { + if (!bus_->disable_torque(calibration_[i].servo_id)) + { + std::cerr << "So101LeaderPlugin: warning: failed to disable torque on servo " + << static_cast(calibration_[i].servo_id) << " (is it powered / on the bus?)" << std::endl; + } + } } else { - std::cout << "So101LeaderPlugin: using synthetic joint backend" << std::endl; + std::cout << "So101LeaderPlugin: using synthetic joint backend (no device path)" << std::endl; } } So101LeaderPlugin::~So101LeaderPlugin() = default; -void So101LeaderPlugin::read_hardware() +void So101LeaderPlugin::load_calibration(const std::string& path) { - // SEAM: real hardware read goes here. - // - // For the SO-101 leader this reads the 6 Feetech STS3215 bus servos over `device_path_` - // (using LeRobot's calibration to convert ticks -> radians) into positions_, in kJointNames - // order. Until that is wired up, synthesize a smooth, phase-shifted trajectory so the full - // device -> tracker -> retargeter path can run with no hardware. + std::ifstream file(path); + if (!file) + { + std::cerr << "So101LeaderPlugin: warning: cannot open calibration file '" << path << "'; using defaults" + << std::endl; + return; + } + + std::string line; + int line_no = 0; + while (std::getline(file, line)) + { + ++line_no; + if (const auto hash = line.find('#'); hash != std::string::npos) + { + line.erase(hash); + } + + std::istringstream iss(line); + std::string name; + int servo_id = 0; + double sign = 1.0; + int home_ticks = 2048; + if (!(iss >> name >> servo_id >> sign >> home_ticks)) + { + continue; // blank / comment-only / malformed line + } + + int idx = -1; + for (int i = 0; i < kNumJoints; ++i) + { + if (name == kJointNames[i]) + { + idx = i; + break; + } + } + if (idx < 0) + { + std::cerr << "So101LeaderPlugin: warning: unknown joint '" << name << "' at " << path << ":" << line_no + << std::endl; + continue; + } + calibration_[idx] = JointCalibration{ static_cast(servo_id), (sign < 0.0 ? -1.0 : 1.0), home_ticks }; + } +} + +void So101LeaderPlugin::read_synthetic() +{ + // Smooth, phase-shifted trajectory so the full device -> tracker -> retargeter path can run + // with no hardware. const double phase = 2.0 * std::numbers::pi * static_cast(frame_) / kSynthPeriodFrames; - for (size_t i = 0; i < kJointNames.size() - 1; ++i) + for (int i = 0; i < kNumJoints - 1; ++i) { positions_[i] = kSynthAmplitude * std::sin(phase + 0.5 * static_cast(i)); } // Gripper: normalized open/close oscillation in [0, 1]. - positions_[kJointNames.size() - 1] = 0.5 * (1.0 + std::sin(phase)); + positions_[kNumJoints - 1] = 0.5 * (1.0 + std::sin(phase)); +} + +void So101LeaderPlugin::read_hardware() +{ + // Read the 6 FEETECH STS3215 bus servos and convert ticks -> radians with per-joint + // calibration. A failed read holds the last value so a transient bus hiccup never faults. + for (int i = 0; i < kNumJoints; ++i) + { + uint16_t ticks = 0; + if (bus_->read_position(calibration_[i].servo_id, ticks)) + { + positions_[i] = + calibration_[i].sign * (static_cast(ticks) - calibration_[i].home_ticks) * kTicksToRadians; + } + } } void So101LeaderPlugin::push_current_state() @@ -107,7 +195,14 @@ void So101LeaderPlugin::push_current_state() void So101LeaderPlugin::update() { - read_hardware(); + if (bus_) + { + read_hardware(); + } + else + { + read_synthetic(); + } push_current_state(); ++frame_; } diff --git a/src/plugins/so101_leader/so101_leader_plugin.hpp b/src/plugins/so101_leader/so101_leader_plugin.hpp index 85df319f6..bb36078e3 100644 --- a/src/plugins/so101_leader/so101_leader_plugin.hpp +++ b/src/plugins/so101_leader/so101_leader_plugin.hpp @@ -19,39 +19,66 @@ namespace plugins namespace so101_leader { +class FeetechBus; + +//! Number of SO-101 DOFs: 5-DOF arm + gripper. +inline constexpr int kNumJoints = 6; + /*! * @brief Streams SO-101 (5-DOF + gripper) leader-arm joint angles as ``JointStateOutput`` via * OpenXR ``SchemaPusher``, on the generic joint-space device path. * - * The SO-101 reads 6 Feetech STS3215 bus servos over a serial port (LeRobot's ``FeetechMotorsBus`` - * + calibration). To keep the example hardware-free and headless, this plugin ships a - * **synthetic backend** that emits a smooth joint trajectory; ``read_hardware()`` is the marked - * seam where a real serial/Feetech read replaces the synthetic values. + * The SO-101 reads 6 FEETECH STS3215 bus servos over a serial port. When a serial @p device_path + * is given, the plugin talks to the servos directly via :class:`FeetechBus` (the same SMS/STS wire + * protocol LeRobot's ``FeetechMotorsBus`` uses): it disables torque so the arm can be back-driven + * and reads ``Present_Position`` each frame, converting ticks to radians with per-joint calibration. + * With no device path it falls back to a **synthetic** trajectory so the device -> tracker -> + * retargeter pipeline can run with no hardware (used by CI and the headless example). */ class So101LeaderPlugin { public: /*! - * @param device_path Serial device path (e.g. /dev/ttyACM0) for the real Feetech backend - * (see read_hardware()); the synthetic backend ignores it. Empty for synthetic-only. + * @param device_path Serial device path (e.g. /dev/ttyACM0) for the real FEETECH backend. + * Empty selects the synthetic backend. * @param collection_id Tensor collection id; must match the consumer's JointStateTracker. * Also used as the JointStateOutput.device_id. + * @param calibration_path Optional calibration file (see load_calibration()); empty uses + * defaults (servo ids 1..6 in DOF order, sign +1, home tick 2048). */ - So101LeaderPlugin(const std::string& device_path, const std::string& collection_id); + So101LeaderPlugin(const std::string& device_path, + const std::string& collection_id, + const std::string& calibration_path = ""); ~So101LeaderPlugin(); void update(); private: - // Fill positions_ (in kJointNames order) with the latest joint angles. This reference ships a - // synthetic trajectory; SEAM: replace the body with a real Feetech/serial read for hardware. + //! Per-joint mapping from a FEETECH servo to a joint angle, mirroring LeRobot's calibration: + //! ``angle [rad] = sign * (ticks - home_ticks) * 2*pi / 4096``. + struct JointCalibration + { + uint8_t servo_id; + double sign; // +1 / -1 (LeRobot drive_mode) + int home_ticks; // raw tick at the joint's zero pose (LeRobot homing reference); 2048 = servo center + }; + + //! Fill positions_ from the live servos (held last on a failed read). SEAM for other backends. void read_hardware(); + //! Synthetic smooth trajectory used when no serial device is attached. + void read_synthetic(); void push_current_state(); + //! Parse a whitespace-separated calibration file: ``name servo_id sign home_ticks`` per line + //! (``#`` comments allowed). Unknown joint names are ignored; missing joints keep defaults. + void load_calibration(const std::string& path); std::string device_path_; std::string collection_id_; int64_t frame_ = 0; - double positions_[6] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 }; + double positions_[kNumJoints] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 }; + JointCalibration calibration_[kNumJoints]; + + std::unique_ptr bus_; // null => synthetic backend std::shared_ptr session_; core::SchemaPusher pusher_; From cba31977b62231fdddfce6f5b4b54dc1ba4a6572 Mon Sep 17 00:00:00 2001 From: Rafael Wiltz Date: Thu, 11 Jun 2026 11:31:16 -0400 Subject: [PATCH 3/5] feat: SO-101 leader sync-read + calibrate mode Read all six FEETECH servos in a single SYNC READ (instruction 0x82, one bus round-trip) instead of six sequential request/response pairs, matching LeRobot's sync_read and cutting per-frame latency. Add a `calibrate` subcommand that mirrors LeRobot's homing step: disable torque, prompt the operator to hold the zero pose, average a few sync reads, and print/write a calibration file (servo id, sign, home tick) in the format the plugin consumes. Runs off the serial bus only -- no OpenXR runtime required. --- src/plugins/so101_leader/README.md | 20 +++- src/plugins/so101_leader/feetech_bus.cpp | 67 ++++++++++- src/plugins/so101_leader/feetech_bus.hpp | 15 ++- src/plugins/so101_leader/main.cpp | 10 ++ .../so101_leader/so101_leader_plugin.cpp | 111 +++++++++++++++++- .../so101_leader/so101_leader_plugin.hpp | 10 ++ 6 files changed, 223 insertions(+), 10 deletions(-) diff --git a/src/plugins/so101_leader/README.md b/src/plugins/so101_leader/README.md index f4ec5088e..fefccd437 100644 --- a/src/plugins/so101_leader/README.md +++ b/src/plugins/so101_leader/README.md @@ -14,7 +14,8 @@ hardware [TheRobotStudio/SO-ARM100](https://github.com/TheRobotStudio/SO-ARM100) LeRobot drive via the FEETECH SCServo SDK). `FeetechBus` (`feetech_bus.{hpp,cpp}`) speaks that SMS/STS wire protocol directly — no SDK dependency — and implements just what a *leader* needs: disable torque so the arm can be back-driven by hand, then read `Present_Position` (register 56, -4096 ticks / 360°) each frame. Ticks are converted to radians with per-joint calibration. +4096 ticks / 360°) for all six servos in a **single SYNC READ** per frame (one bus round-trip, not +six — matching LeRobot's `sync_read`). Ticks are converted to radians with per-joint calibration. When no serial device is given, the plugin falls back to a **synthetic** trajectory so the device → tracker → retargeter pipeline runs with no hardware (CI and the headless example). @@ -35,6 +36,23 @@ device → tracker → retargeter pipeline runs with no hardware (CI and the hea Args are positional: `[device_path] [collection_id] [calibration_file]`. The serial backend is Linux/macOS only (POSIX `termios`); the STS bus runs at 1,000,000 bps by default. +### Generate a calibration file + +The `calibrate` subcommand reads the live servo positions and writes a calibration file (it does +**not** need the OpenXR runtime): + +```bash +# Hold the arm at its zero/home pose; this prompts, then averages a few sync reads: +./install/plugins/so101_leader/so101_leader_plugin calibrate /dev/ttyACM0 so101_leader.calib + +# Omit the output path to just print the current ticks (a "dump" for inspection): +./install/plugins/so101_leader/so101_leader_plugin calibrate /dev/ttyACM0 +``` + +It disables torque, prompts you to hold the zero pose, captures each servo's `home_ticks`, and +writes the file below (all `sign` default to `+1` — flip any inverted joint by hand afterward). +This mirrors LeRobot's `lerobot-calibrate` homing step. + ### Hardware setup (per SO-ARM100 / LeRobot) - Assemble the leader arm and **remove the gearbox gears** so the joints move freely and only the diff --git a/src/plugins/so101_leader/feetech_bus.cpp b/src/plugins/so101_leader/feetech_bus.cpp index e3b7599c1..c359c5917 100644 --- a/src/plugins/so101_leader/feetech_bus.cpp +++ b/src/plugins/so101_leader/feetech_bus.cpp @@ -27,8 +27,10 @@ namespace // FEETECH SMS/STS protocol constants. constexpr uint8_t kHeader = 0xFF; +constexpr uint8_t kBroadcastId = 0xFE; constexpr uint8_t kInstRead = 0x02; constexpr uint8_t kInstWrite = 0x03; +constexpr uint8_t kInstSyncRead = 0x82; constexpr uint8_t kRegTorqueEnable = 40; // 1 byte constexpr uint8_t kRegPresentPosition = 56; // 2 bytes, little-endian (SMS/STS) constexpr int kReadTimeoutMs = 20; @@ -172,7 +174,7 @@ bool FeetechBus::read_byte(uint8_t& out, int timeout_ms) return n == 1; } -bool FeetechBus::read_status(uint8_t expected_id, uint8_t* data_out, uint8_t expected_data_len) +bool FeetechBus::read_status_any(uint8_t* data_out, uint8_t expected_data_len, uint8_t& id_out) { // Sync to the 0xFF 0xFF header (tolerates leading noise from bus turnaround). int prev = -1; @@ -222,7 +224,7 @@ bool FeetechBus::read_status(uint8_t expected_id, uint8_t* data_out, uint8_t exp checksum += rest[i]; } const uint8_t expected_checksum = static_cast(~checksum & 0xFF); - if (expected_checksum != rest[length - 1] || id != expected_id) + if (expected_checksum != rest[length - 1]) { return false; } @@ -236,9 +238,16 @@ bool FeetechBus::read_status(uint8_t expected_id, uint8_t* data_out, uint8_t exp { data_out[i] = rest[i + 1]; // skip the error byte at rest[0] } + id_out = id; return true; } +bool FeetechBus::read_status(uint8_t expected_id, uint8_t* data_out, uint8_t expected_data_len) +{ + uint8_t id = 0; + return read_status_any(data_out, expected_data_len, id) && id == expected_id; +} + bool FeetechBus::read_position(uint8_t id, uint16_t& ticks_out) { const uint8_t params[2] = { kRegPresentPosition, 0x02 }; @@ -255,6 +264,52 @@ bool FeetechBus::read_position(uint8_t id, uint16_t& ticks_out) return true; } +bool FeetechBus::sync_read_positions(const std::vector& ids, + std::vector& positions, + std::vector& ok) +{ + positions.assign(ids.size(), 0); + ok.assign(ids.size(), 0); + if (ids.empty()) + { + return true; + } + + // SYNC READ (0x82) to the broadcast id: params are [reg, read_len, id0, id1, ...]. Each + // addressed servo then replies with its own status packet in list order. + std::vector params; + params.reserve(ids.size() + 2); + params.push_back(kRegPresentPosition); + params.push_back(0x02); + params.insert(params.end(), ids.begin(), ids.end()); + if (!write_packet(kBroadcastId, kInstSyncRead, params.data(), static_cast(params.size()))) + { + return false; + } + + // Read up to one reply per requested servo, matching by id so a non-responding servo doesn't + // misalign the rest (the first missing reply just ends the burst). + for (size_t k = 0; k < ids.size(); ++k) + { + uint8_t data[2] = { 0, 0 }; + uint8_t resp_id = 0; + if (!read_status_any(data, 2, resp_id)) + { + break; + } + for (size_t i = 0; i < ids.size(); ++i) + { + if (ids[i] == resp_id && !ok[i]) + { + positions[i] = static_cast(data[0]) | static_cast(data[1] << 8); + ok[i] = 1; + break; + } + } + } + return true; +} + bool FeetechBus::disable_torque(uint8_t id) { const uint8_t params[2] = { kRegTorqueEnable, 0x00 }; @@ -289,6 +344,10 @@ bool FeetechBus::write_packet(uint8_t, uint8_t, const uint8_t*, uint8_t) { return false; } +bool FeetechBus::read_status_any(uint8_t*, uint8_t, uint8_t&) +{ + return false; +} bool FeetechBus::read_status(uint8_t, uint8_t*, uint8_t) { return false; @@ -301,6 +360,10 @@ bool FeetechBus::read_position(uint8_t, uint16_t&) { return false; } +bool FeetechBus::sync_read_positions(const std::vector&, std::vector&, std::vector&) +{ + return false; +} bool FeetechBus::disable_torque(uint8_t) { return false; diff --git a/src/plugins/so101_leader/feetech_bus.hpp b/src/plugins/so101_leader/feetech_bus.hpp index 290751cce..4decdf418 100644 --- a/src/plugins/so101_leader/feetech_bus.hpp +++ b/src/plugins/so101_leader/feetech_bus.hpp @@ -5,6 +5,7 @@ #include #include +#include namespace plugins { @@ -44,13 +45,23 @@ class FeetechBus //! timeout / malformed response so the caller can hold the last value instead of faulting. bool read_position(uint8_t id, uint16_t& ticks_out); + //! Read ``Present_Position`` for all @p ids in a single SYNC READ (one bus round-trip instead + //! of one request/response per servo -- the low-latency path). @p positions and @p ok are + //! resized to ``ids.size()`` and filled in parallel; ``ok[i] == 0`` means servo ``ids[i]`` did + //! not reply (its position is left 0 and the caller should hold its last value). Returns false + //! only if the request itself could not be sent. + bool sync_read_positions(const std::vector& ids, std::vector& positions, std::vector& ok); + //! Write ``Torque_Enable = 0`` (register 40) so servo @p id goes limp and can be moved by hand. bool disable_torque(uint8_t id); private: bool write_packet(uint8_t id, uint8_t instruction, const uint8_t* params, uint8_t param_count); - //! Read a status packet; copies @p expected_data_len data bytes (after the error byte) into - //! @p data_out (may be null when @p expected_data_len is 0). Validates header, id, and checksum. + //! Read one status packet; copies @p expected_data_len data bytes (after the error byte) into + //! @p data_out (may be null when @p expected_data_len is 0), validates header/length/checksum, + //! and reports the responder id in @p id_out. + bool read_status_any(uint8_t* data_out, uint8_t expected_data_len, uint8_t& id_out); + //! As read_status_any(), but also requires the responder id to equal @p expected_id. bool read_status(uint8_t expected_id, uint8_t* data_out, uint8_t expected_data_len); bool read_byte(uint8_t& out, int timeout_ms); diff --git a/src/plugins/so101_leader/main.cpp b/src/plugins/so101_leader/main.cpp index e5393e9a3..edaaa244f 100644 --- a/src/plugins/so101_leader/main.cpp +++ b/src/plugins/so101_leader/main.cpp @@ -14,6 +14,16 @@ using namespace plugins::so101_leader; int main(int argc, char** argv) try { + // Calibration/dump mode: so101_leader_plugin calibrate [output_file] + // Reads the current servo positions (hold the arm at its zero pose) and optionally writes a + // calibration file. No OpenXR runtime required. + if (argc > 1 && std::string(argv[1]) == "calibrate") + { + const std::string device_path = (argc > 2) ? argv[2] : ""; + const std::string output_path = (argc > 3) ? argv[3] : ""; + return run_calibration(device_path, output_path); + } + // Usage: so101_leader_plugin [device_path] [collection_id] [calibration_file] // Empty device_path selects the synthetic backend (no hardware required). const std::string device_path = (argc > 1) ? argv[1] : ""; diff --git a/src/plugins/so101_leader/so101_leader_plugin.cpp b/src/plugins/so101_leader/so101_leader_plugin.cpp index 8a77b323d..d7e603ab3 100644 --- a/src/plugins/so101_leader/so101_leader_plugin.cpp +++ b/src/plugins/so101_leader/so101_leader_plugin.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -19,6 +20,8 @@ #include #include #include +#include +#include namespace plugins { @@ -68,6 +71,14 @@ So101LeaderPlugin::So101LeaderPlugin(const std::string& device_path, load_calibration(calibration_path); } + // Servo id list (DOF order) and reusable scratch for the per-frame sync read. + for (int i = 0; i < kNumJoints; ++i) + { + servo_ids_.push_back(calibration_[i].servo_id); + } + read_ticks_.assign(kNumJoints, 0); + read_ok_.assign(kNumJoints, 0); + if (!device_path_.empty()) { // Throws on POSIX if the port can't be opened; throws unconditionally on Windows. @@ -156,15 +167,19 @@ void So101LeaderPlugin::read_synthetic() void So101LeaderPlugin::read_hardware() { - // Read the 6 FEETECH STS3215 bus servos and convert ticks -> radians with per-joint - // calibration. A failed read holds the last value so a transient bus hiccup never faults. + // One SYNC READ for all six servos (a single bus round-trip) instead of six request/response + // pairs -- lower latency. Convert ticks -> radians with per-joint calibration; a servo that + // doesn't reply holds its last value so a transient bus hiccup never faults. + if (!bus_->sync_read_positions(servo_ids_, read_ticks_, read_ok_)) + { + return; // request could not be sent; hold all + } for (int i = 0; i < kNumJoints; ++i) { - uint16_t ticks = 0; - if (bus_->read_position(calibration_[i].servo_id, ticks)) + if (read_ok_[i]) { positions_[i] = - calibration_[i].sign * (static_cast(ticks) - calibration_[i].home_ticks) * kTicksToRadians; + calibration_[i].sign * (static_cast(read_ticks_[i]) - calibration_[i].home_ticks) * kTicksToRadians; } } } @@ -207,5 +222,91 @@ void So101LeaderPlugin::update() ++frame_; } +int run_calibration(const std::string& device_path, const std::string& output_path) +{ + if (device_path.empty()) + { + std::cerr << "calibrate: a serial device path is required (e.g. /dev/ttyACM0)" << std::endl; + return 2; + } + + FeetechBus bus(device_path, kFeetechBaud); + + // Default bus ids 1..6 in DOF order; back-drive the arm by disabling torque. + std::vector ids; + for (int i = 0; i < kNumJoints; ++i) + { + ids.push_back(static_cast(i + 1)); + bus.disable_torque(ids.back()); + } + + std::cout << "Hold the SO-101 leader at its zero/home pose, then press ENTER..." << std::flush; + std::string line; + std::getline(std::cin, line); + + // Average a few sync reads to smooth encoder jitter. + constexpr int kSamples = 8; + std::vector sums(kNumJoints, 0); + std::vector counts(kNumJoints, 0); + for (int s = 0; s < kSamples; ++s) + { + std::vector ticks; + std::vector ok; + if (bus.sync_read_positions(ids, ticks, ok)) + { + for (int i = 0; i < kNumJoints; ++i) + { + if (ok[i]) + { + sums[i] += ticks[i]; + ++counts[i]; + } + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } + + std::ofstream out; + const bool write_file = !output_path.empty(); + if (write_file) + { + out.open(output_path); + if (!out) + { + std::cerr << "calibrate: cannot write '" << output_path << "'" << std::endl; + return 2; + } + out << "# SO-101 leader calibration (generated by `so101_leader_plugin calibrate`)\n"; + out << "# name id sign home_ticks\n"; + } + + bool all_ok = true; + std::cout << "\nMeasured home positions:" << std::endl; + for (int i = 0; i < kNumJoints; ++i) + { + const int home = counts[i] > 0 ? static_cast((sums[i] + counts[i] / 2) / counts[i]) : 2048; + if (counts[i] == 0) + { + all_ok = false; + std::cerr << " warning: no reply from servo " << static_cast(ids[i]) << " (" << kJointNames[i] + << "); writing default 2048" << std::endl; + } + std::cout << " " << kJointNames[i] << " id=" << static_cast(ids[i]) << " home_ticks=" << home + << std::endl; + if (write_file) + { + out << kJointNames[i] << " " << static_cast(ids[i]) << " 1 " << home << "\n"; + } + } + + if (write_file) + { + out.close(); + std::cout << "Wrote calibration to " << output_path << std::endl; + } + std::cout << "Set 'sign' to -1 for any joint that moves opposite the URDF convention." << std::endl; + return all_ok ? 0 : 1; +} + } // namespace so101_leader } // namespace plugins diff --git a/src/plugins/so101_leader/so101_leader_plugin.hpp b/src/plugins/so101_leader/so101_leader_plugin.hpp index bb36078e3..faac534e4 100644 --- a/src/plugins/so101_leader/so101_leader_plugin.hpp +++ b/src/plugins/so101_leader/so101_leader_plugin.hpp @@ -8,6 +8,7 @@ #include #include #include +#include namespace core { @@ -79,10 +80,19 @@ class So101LeaderPlugin JointCalibration calibration_[kNumJoints]; std::unique_ptr bus_; // null => synthetic backend + std::vector servo_ids_; // calibration_[*].servo_id in DOF order (sync-read request) + std::vector read_ticks_; // sync-read scratch (reused each frame) + std::vector read_ok_; // sync-read scratch: per-servo reply flag std::shared_ptr session_; core::SchemaPusher pusher_; }; +//! Calibration/dump helper: open @p device_path, back-drive-enable the servos, read the current +//! joint positions (hold the arm at its zero pose first), print them, and -- if @p output_path is +//! non-empty -- write a calibration file in the format ``load_calibration()`` consumes. Does not +//! create an OpenXR session. Returns a process exit code (0 = all servos read). +int run_calibration(const std::string& device_path, const std::string& output_path); + } // namespace so101_leader } // namespace plugins From 0a0172be07e4bd5f514036dc3f56438e33f74d2e Mon Sep 17 00:00:00 2001 From: Rafael Wiltz Date: Thu, 11 Jun 2026 11:52:34 -0400 Subject: [PATCH 4/5] feat: range-of-motion sweep in SO-101 leader calibrate Extend the `calibrate` subcommand with a range-of-motion sweep that mirrors LeRobot's lerobot-calibrate: after the homing step, the operator moves every joint through its range while per-joint min/max ticks are recorded (ENTER to finish). The calibration file gains optional range_min/range_max columns; reads are clamped to that range, which guards against encoder-wrap spikes and out-of-range jitter. The command also prints the gripper's range endpoints in radians for the retargeter's gripper_open/gripper_close. Range columns are optional and default to the full 0..4095 (clamp no-op), so existing four-column calibration files and the no-file defaults are unchanged. --- src/plugins/so101_leader/README.md | 54 +++++--- .../so101_leader/so101_leader_plugin.cpp | 126 ++++++++++++++---- .../so101_leader/so101_leader_plugin.hpp | 20 ++- 3 files changed, 148 insertions(+), 52 deletions(-) diff --git a/src/plugins/so101_leader/README.md b/src/plugins/so101_leader/README.md index fefccd437..a57d3a6b3 100644 --- a/src/plugins/so101_leader/README.md +++ b/src/plugins/so101_leader/README.md @@ -38,20 +38,27 @@ Linux/macOS only (POSIX `termios`); the STS bus runs at 1,000,000 bps by default ### Generate a calibration file -The `calibrate` subcommand reads the live servo positions and writes a calibration file (it does -**not** need the OpenXR runtime): +The `calibrate` subcommand reads the live servos and writes a calibration file (it does **not** need +the OpenXR runtime). It mirrors LeRobot's `lerobot-calibrate`: a homing step then a range-of-motion +sweep. ```bash -# Hold the arm at its zero/home pose; this prompts, then averages a few sync reads: +# Two interactive steps: (1) hold the arm at mid-range, ENTER; (2) sweep every joint, ENTER: ./install/plugins/so101_leader/so101_leader_plugin calibrate /dev/ttyACM0 so101_leader.calib -# Omit the output path to just print the current ticks (a "dump" for inspection): +# Omit the output path to just print the measurements (a "dump" for inspection): ./install/plugins/so101_leader/so101_leader_plugin calibrate /dev/ttyACM0 ``` -It disables torque, prompts you to hold the zero pose, captures each servo's `home_ticks`, and -writes the file below (all `sign` default to `+1` — flip any inverted joint by hand afterward). -This mirrors LeRobot's `lerobot-calibrate` homing step. +It disables torque, then: +1. **Home** — prompts you to hold all joints at the middle of their range and averages a few sync + reads into each servo's `home_ticks` (the middle pose is also the SO-101 URDF/operating zero). +2. **Range sweep** — while you move every joint through its full range, it tracks per-joint + `range_min`/`range_max` until you press ENTER. + +It writes the file below (all `sign` default to `+1` — flip any inverted joint by hand) and prints +the gripper's range endpoints in radians, which you drop into the retargeter's +`gripper_open`/`gripper_close`. ### Hardware setup (per SO-ARM100 / LeRobot) @@ -65,25 +72,28 @@ This mirrors LeRobot's `lerobot-calibrate` homing step. ### Calibration file (optional) Whitespace-separated, one joint per line; `#` starts a comment. Columns: -`name servo_id sign(+1/-1) home_ticks(0..4095)`. The conversion is -`angle [rad] = sign * (ticks - home_ticks) * 2π / 4096`. +`name servo_id sign(+1/-1) home_ticks(0..4095) [range_min range_max]` (the two range columns are +optional). The conversion is +`angle [rad] = sign * (clamp(ticks, range_min, range_max) - home_ticks) * 2π / 4096`. ``` -# joint id sign home_ticks -shoulder_pan 1 1 2048 -shoulder_lift 2 1 2048 -elbow_flex 3 1 2048 -wrist_flex 4 1 2048 -wrist_roll 5 1 2048 -gripper 6 1 2048 +# name id sign home_ticks range_min range_max +shoulder_pan 1 1 2048 800 3300 +shoulder_lift 2 1 2048 900 3100 +elbow_flex 3 1 2048 700 3400 +wrist_flex 4 1 2048 800 3300 +wrist_roll 5 1 2048 0 4095 +gripper 6 1 2048 2000 3000 ``` -Defaults (no file): ids `1..6` in DOF order, `sign +1`, `home_ticks 2048` (servo center). Set -`home_ticks` to each servo's raw `Present_Position` at the joint's URDF-zero pose, and `sign` to -`-1` for any joint whose servo turns opposite the URDF convention (LeRobot's `drive_mode`). For -**joint-mirror** mode the retargeter's per-joint `offset`/`sign`/`scale` can also absorb -calibration; for **EE (URDF FK)** mode the joint angles must already match the URDF, so set -`home_ticks`/`sign` here. +Defaults (no file, or only the first four columns): ids `1..6` in DOF order, `sign +1`, +`home_ticks 2048` (servo center), full range `0..4095` (clamp is a no-op). Set `home_ticks` to each +servo's raw `Present_Position` at the joint's URDF-zero pose, `sign` to `-1` for any joint whose +servo turns opposite the URDF convention (LeRobot's `drive_mode`), and the optional `range_min/max` +to the swept extremes (reads are clamped to them; `range_min < range_max` required or they're +ignored). For **joint-mirror** mode the retargeter's per-joint `offset`/`sign`/`scale` can also +absorb calibration; for **EE (URDF FK)** mode the joint angles must already match the URDF, so set +`home_ticks`/`sign` here. The `calibrate` subcommand fills all of this in for you. The consumer side creates a `JointStateTracker("so101_leader")` (via `JointStateSource(name=..., collection_id="so101_leader", joint_names=[...])`) on the same diff --git a/src/plugins/so101_leader/so101_leader_plugin.cpp b/src/plugins/so101_leader/so101_leader_plugin.cpp index d7e603ab3..c20cd5470 100644 --- a/src/plugins/so101_leader/so101_leader_plugin.cpp +++ b/src/plugins/so101_leader/so101_leader_plugin.cpp @@ -10,7 +10,9 @@ #include #include +#include #include +#include #include #include #include @@ -61,10 +63,11 @@ So101LeaderPlugin::So101LeaderPlugin(const std::string& device_path, .localized_name = "SO-101 Leader Arm", .app_name = "So101LeaderPlugin" }) { - // Defaults: servo ids 1..6 in DOF order, no sign flip, centered at the servo midpoint (2048). + // Defaults: servo ids 1..6 in DOF order, no sign flip, centered at the servo midpoint (2048), + // full tick range (so the clamp is a no-op until a calibration file narrows it). for (int i = 0; i < kNumJoints; ++i) { - calibration_[i] = JointCalibration{ static_cast(i + 1), 1.0, 2048 }; + calibration_[i] = JointCalibration{ static_cast(i + 1), 1.0, 2048, 0, 4095 }; } if (!calibration_path.empty()) { @@ -133,6 +136,15 @@ void So101LeaderPlugin::load_calibration(const std::string& path) continue; // blank / comment-only / malformed line } + // Optional range_min range_max columns (from the calibrate sweep); else full range. + int range_min = 0; + int range_max = 4095; + if (int a = 0, b = 0; (iss >> a >> b) && a >= 0 && b <= 4095 && a < b) + { + range_min = a; + range_max = b; + } + int idx = -1; for (int i = 0; i < kNumJoints; ++i) { @@ -148,7 +160,8 @@ void So101LeaderPlugin::load_calibration(const std::string& path) << std::endl; continue; } - calibration_[idx] = JointCalibration{ static_cast(servo_id), (sign < 0.0 ? -1.0 : 1.0), home_ticks }; + calibration_[idx] = JointCalibration{ static_cast(servo_id), (sign < 0.0 ? -1.0 : 1.0), home_ticks, + range_min, range_max }; } } @@ -178,8 +191,9 @@ void So101LeaderPlugin::read_hardware() { if (read_ok_[i]) { - positions_[i] = - calibration_[i].sign * (static_cast(read_ticks_[i]) - calibration_[i].home_ticks) * kTicksToRadians; + const int ticks = + std::clamp(static_cast(read_ticks_[i]), calibration_[i].range_min, calibration_[i].range_max); + positions_[i] = calibration_[i].sign * (ticks - calibration_[i].home_ticks) * kTicksToRadians; } } } @@ -222,6 +236,48 @@ void So101LeaderPlugin::update() ++frame_; } +namespace +{ + +//! Read the servos a few times and return the per-joint averaged tick (or 2048 if a servo never +//! replied). @p ok_out[i] reflects whether servo @p ids[i] replied at least once. +std::vector averaged_positions(FeetechBus& bus, const std::vector& ids, int samples, std::vector& ok_out) +{ + std::vector sums(ids.size(), 0); + std::vector counts(ids.size(), 0); + for (int s = 0; s < samples; ++s) + { + std::vector ticks; + std::vector ok; + if (bus.sync_read_positions(ids, ticks, ok)) + { + for (size_t i = 0; i < ids.size(); ++i) + { + if (ok[i]) + { + sums[i] += ticks[i]; + ++counts[i]; + } + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } + + std::vector out(ids.size(), 2048); + ok_out.assign(ids.size(), false); + for (size_t i = 0; i < ids.size(); ++i) + { + if (counts[i] > 0) + { + out[i] = static_cast((sums[i] + counts[i] / 2) / counts[i]); + ok_out[i] = true; + } + } + return out; +} + +} // namespace + int run_calibration(const std::string& device_path, const std::string& output_path) { if (device_path.empty()) @@ -240,15 +296,30 @@ int run_calibration(const std::string& device_path, const std::string& output_pa bus.disable_torque(ids.back()); } - std::cout << "Hold the SO-101 leader at its zero/home pose, then press ENTER..." << std::flush; + // Step 1: home (zero) capture. Holding the middle of the range matches LeRobot's homing step + // and, for the SO-101, the URDF/operating zero convention used by EE-mode forward kinematics. + std::cout << "Step 1/2: move all joints to the MIDDLE of their range of motion, then press ENTER..." << std::flush; std::string line; std::getline(std::cin, line); - - // Average a few sync reads to smooth encoder jitter. - constexpr int kSamples = 8; - std::vector sums(kNumJoints, 0); - std::vector counts(kNumJoints, 0); - for (int s = 0; s < kSamples; ++s) + std::vector home_ok; + const std::vector home = averaged_positions(bus, ids, 8, home_ok); + + // Step 2: range-of-motion sweep -- track per-joint min/max while the operator moves the arm, + // until ENTER is pressed (mirrors LeRobot's record_ranges_of_motion). Seed with home so the + // range always contains the zero pose. + std::vector range_min = home; + std::vector range_max = home; + std::cout << "Step 2/2: move EVERY joint through its full range of motion, then press ENTER to finish..." + << std::endl; + std::atomic stop{ false }; + std::thread waiter( + [&stop]() + { + std::string l; + std::getline(std::cin, l); + stop.store(true); + }); + while (!stop.load()) { std::vector ticks; std::vector ok; @@ -258,13 +329,14 @@ int run_calibration(const std::string& device_path, const std::string& output_pa { if (ok[i]) { - sums[i] += ticks[i]; - ++counts[i]; + range_min[i] = std::min(range_min[i], static_cast(ticks[i])); + range_max[i] = std::max(range_max[i], static_cast(ticks[i])); } } } - std::this_thread::sleep_for(std::chrono::milliseconds(20)); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } + waiter.join(); std::ofstream out; const bool write_file = !output_path.empty(); @@ -277,28 +349,36 @@ int run_calibration(const std::string& device_path, const std::string& output_pa return 2; } out << "# SO-101 leader calibration (generated by `so101_leader_plugin calibrate`)\n"; - out << "# name id sign home_ticks\n"; + out << "# name id sign home_ticks range_min range_max\n"; } bool all_ok = true; - std::cout << "\nMeasured home positions:" << std::endl; + std::cout << "\nMeasured calibration (ticks; angle = sign * (ticks - home) * 2pi/4096):" << std::endl; for (int i = 0; i < kNumJoints; ++i) { - const int home = counts[i] > 0 ? static_cast((sums[i] + counts[i] / 2) / counts[i]) : 2048; - if (counts[i] == 0) + if (!home_ok[i]) { all_ok = false; std::cerr << " warning: no reply from servo " << static_cast(ids[i]) << " (" << kJointNames[i] - << "); writing default 2048" << std::endl; + << "); writing defaults" << std::endl; } - std::cout << " " << kJointNames[i] << " id=" << static_cast(ids[i]) << " home_ticks=" << home - << std::endl; + std::cout << " " << kJointNames[i] << " id=" << static_cast(ids[i]) << " home=" << home[i] + << " range=[" << range_min[i] << ", " << range_max[i] << "]" << std::endl; if (write_file) { - out << kJointNames[i] << " " << static_cast(ids[i]) << " 1 " << home << "\n"; + out << kJointNames[i] << " " << static_cast(ids[i]) << " 1 " << home[i] << " " << range_min[i] << " " + << range_max[i] << "\n"; } } + // Gripper endpoints in radians (relative to home) for the retargeter's gripper_open/gripper_close. + const int g = kNumJoints - 1; + const double grip_lo = (range_min[g] - home[g]) * kTicksToRadians; + const double grip_hi = (range_max[g] - home[g]) * kTicksToRadians; + std::cout << "\nGripper '" << kJointNames[g] << "' range endpoints (radians, relative to home): " << grip_lo + << " .. " << grip_hi << "\n -> set JointStateRetargeterConfig.gripper_open / gripper_close to these " + << "(whichever matches your open/closed convention)." << std::endl; + if (write_file) { out.close(); diff --git a/src/plugins/so101_leader/so101_leader_plugin.hpp b/src/plugins/so101_leader/so101_leader_plugin.hpp index faac534e4..03f9f371f 100644 --- a/src/plugins/so101_leader/so101_leader_plugin.hpp +++ b/src/plugins/so101_leader/so101_leader_plugin.hpp @@ -56,12 +56,14 @@ class So101LeaderPlugin private: //! Per-joint mapping from a FEETECH servo to a joint angle, mirroring LeRobot's calibration: - //! ``angle [rad] = sign * (ticks - home_ticks) * 2*pi / 4096``. + //! ``angle [rad] = sign * (clamp(ticks, range_min, range_max) - home_ticks) * 2*pi / 4096``. struct JointCalibration { uint8_t servo_id; double sign; // +1 / -1 (LeRobot drive_mode) int home_ticks; // raw tick at the joint's zero pose (LeRobot homing reference); 2048 = servo center + int range_min; // sweep min tick; reads are clamped to [range_min, range_max] + int range_max; // sweep max tick; default full range 0..4095 => clamp is a no-op }; //! Fill positions_ from the live servos (held last on a failed read). SEAM for other backends. @@ -69,8 +71,9 @@ class So101LeaderPlugin //! Synthetic smooth trajectory used when no serial device is attached. void read_synthetic(); void push_current_state(); - //! Parse a whitespace-separated calibration file: ``name servo_id sign home_ticks`` per line - //! (``#`` comments allowed). Unknown joint names are ignored; missing joints keep defaults. + //! Parse a whitespace-separated calibration file: ``name servo_id sign home_ticks [range_min + //! range_max]`` per line (``#`` comments allowed; range columns optional). Unknown joint names + //! are ignored; missing joints keep defaults. void load_calibration(const std::string& path); std::string device_path_; @@ -88,10 +91,13 @@ class So101LeaderPlugin core::SchemaPusher pusher_; }; -//! Calibration/dump helper: open @p device_path, back-drive-enable the servos, read the current -//! joint positions (hold the arm at its zero pose first), print them, and -- if @p output_path is -//! non-empty -- write a calibration file in the format ``load_calibration()`` consumes. Does not -//! create an OpenXR session. Returns a process exit code (0 = all servos read). +//! Calibration/dump helper: open @p device_path, back-drive-enable the servos, then (1) capture the +//! home tick with the arm held at the middle of its range, and (2) record each joint's min/max over +//! a range-of-motion sweep (move the arm, press ENTER to finish). Prints the result and -- if +//! @p output_path is non-empty -- writes a calibration file in the format ``load_calibration()`` +//! consumes (``name id sign home_ticks range_min range_max``). Also prints the gripper open/close +//! endpoints in radians for the retargeter. Does not create an OpenXR session. Returns a process +//! exit code (0 = all servos read). int run_calibration(const std::string& device_path, const std::string& output_path); } // namespace so101_leader From daa450f7b7f55439f1acab593073a37f001eac8f Mon Sep 17 00:00:00 2001 From: Rafael Wiltz Date: Thu, 11 Jun 2026 12:16:04 -0400 Subject: [PATCH 5/5] feat: LeRobot calibration interop for SO-101 leader Make the SO-101 leader calibration interchangeable with LeRobot's. A .json path is read/written in LeRobot's format ({id, drive_mode, homing_offset, range_min, range_max} per joint); any other path stays the plain-text format. Mapping: range_min/range_max -> our range, the range midpoint -> home_ticks (LeRobot's zero), drive_mode -> sign. LeRobot keeps each joint's homing_offset in the servo EEPROM (its runtime normalization ignores it), so on connect we read the live Homing_Offset (register 31, sign-magnitude) and shift home/range by file-minus-servo to reconcile the frames without writing to the servo. The calibrate subcommand emits LeRobot JSON when the output path ends in .json. Includes a small dependency-free JSON reader for the flat LeRobot schema. --- src/plugins/so101_leader/README.md | 32 +- src/plugins/so101_leader/feetech_bus.cpp | 26 ++ src/plugins/so101_leader/feetech_bus.hpp | 5 + .../so101_leader/so101_leader_plugin.cpp | 302 ++++++++++++++++-- .../so101_leader/so101_leader_plugin.hpp | 23 +- 5 files changed, 359 insertions(+), 29 deletions(-) diff --git a/src/plugins/so101_leader/README.md b/src/plugins/so101_leader/README.md index a57d3a6b3..20c9e4c09 100644 --- a/src/plugins/so101_leader/README.md +++ b/src/plugins/so101_leader/README.md @@ -46,6 +46,9 @@ sweep. # Two interactive steps: (1) hold the arm at mid-range, ENTER; (2) sweep every joint, ENTER: ./install/plugins/so101_leader/so101_leader_plugin calibrate /dev/ttyACM0 so101_leader.calib +# Write a LeRobot-format calibration instead (chosen by the .json extension): +./install/plugins/so101_leader/so101_leader_plugin calibrate /dev/ttyACM0 so101_leader.json + # Omit the output path to just print the measurements (a "dump" for inspection): ./install/plugins/so101_leader/so101_leader_plugin calibrate /dev/ttyACM0 ``` @@ -56,9 +59,29 @@ It disables torque, then: 2. **Range sweep** — while you move every joint through its full range, it tracks per-joint `range_min`/`range_max` until you press ENTER. -It writes the file below (all `sign` default to `+1` — flip any inverted joint by hand) and prints -the gripper's range endpoints in radians, which you drop into the retargeter's -`gripper_open`/`gripper_close`. +It writes the file (plain text, or LeRobot JSON if the path ends in `.json`; all `sign` default to +`+1` — flip any inverted joint by hand) and prints the gripper's range endpoints in radians, which +you drop into the retargeter's `gripper_open`/`gripper_close`. + +### LeRobot calibration interoperability + +The calibration files are interchangeable with [LeRobot](https://github.com/huggingface/lerobot) +(`SO101Leader`). A `.json` path is read/written in LeRobot's format +(`{ "joint": {"id", "drive_mode", "homing_offset", "range_min", "range_max"}, ... }`); anything else +is the plain-text format above. + +```bash +# Use an existing LeRobot calibration directly: +./install/plugins/so101_leader/so101_leader_plugin /dev/ttyACM0 so101_leader \ + ~/.cache/huggingface/lerobot/calibration/teleoperators/so101_leader/.json +``` + +Mapping: LeRobot's `range_min`/`range_max` → our range (reads are clamped to it), the range midpoint +→ `home_ticks` (LeRobot's zero), `drive_mode` → `sign`. LeRobot stores each joint's `homing_offset` +in the **servo EEPROM** (its runtime normalization doesn't use it), so on connect we read the +servo's live `Homing_Offset` (register 31) and shift home/range by +`homing_offset_file − homing_offset_servo` — reconciling the frames without writing to the servo. +The result reproduces LeRobot's joint angles (in radians rather than degrees). ### Hardware setup (per SO-ARM100 / LeRobot) @@ -71,7 +94,8 @@ the gripper's range endpoints in radians, which you drop into the retargeter's ### Calibration file (optional) -Whitespace-separated, one joint per line; `#` starts a comment. Columns: +The plain-text format (used unless the path ends in `.json`) is whitespace-separated, one joint per +line; `#` starts a comment. Columns: `name servo_id sign(+1/-1) home_ticks(0..4095) [range_min range_max]` (the two range columns are optional). The conversion is `angle [rad] = sign * (clamp(ticks, range_min, range_max) - home_ticks) * 2π / 4096`. diff --git a/src/plugins/so101_leader/feetech_bus.cpp b/src/plugins/so101_leader/feetech_bus.cpp index c359c5917..0a2698f7c 100644 --- a/src/plugins/so101_leader/feetech_bus.cpp +++ b/src/plugins/so101_leader/feetech_bus.cpp @@ -33,6 +33,8 @@ constexpr uint8_t kInstWrite = 0x03; constexpr uint8_t kInstSyncRead = 0x82; constexpr uint8_t kRegTorqueEnable = 40; // 1 byte constexpr uint8_t kRegPresentPosition = 56; // 2 bytes, little-endian (SMS/STS) +constexpr uint8_t kRegHomingOffset = 31; // 2 bytes, sign-magnitude (sign bit 11) on SMS/STS +constexpr int kHomingOffsetSignBit = 11; constexpr int kReadTimeoutMs = 20; // Map a numeric baud rate to the matching termios speed constant. Only the rates a FEETECH bus @@ -310,6 +312,26 @@ bool FeetechBus::sync_read_positions(const std::vector& ids, return true; } +bool FeetechBus::read_homing_offset(uint8_t id, int& offset_out) +{ + const uint8_t params[2] = { kRegHomingOffset, 0x02 }; + if (!write_packet(id, kInstRead, params, 2)) + { + return false; + } + uint8_t data[2] = { 0, 0 }; + if (!read_status(id, data, 2)) + { + return false; + } + const uint16_t raw = static_cast(data[0]) | static_cast(data[1] << 8); + // Sign-magnitude: bits [0, sign_bit) are the magnitude, bit sign_bit is the sign. + const uint16_t magnitude = raw & ((1u << kHomingOffsetSignBit) - 1u); + const bool negative = (raw >> kHomingOffsetSignBit) & 1u; + offset_out = negative ? -static_cast(magnitude) : static_cast(magnitude); + return true; +} + bool FeetechBus::disable_torque(uint8_t id) { const uint8_t params[2] = { kRegTorqueEnable, 0x00 }; @@ -364,6 +386,10 @@ bool FeetechBus::sync_read_positions(const std::vector&, std::vector& ids, std::vector& positions, std::vector& ok); + //! Read the signed ``Homing_Offset`` (register 31, sign-magnitude with sign bit 11) the servo + //! subtracts from its actual position (``Present_Position = Actual_Position - Homing_Offset``). + //! Used to reconcile our reads with a LeRobot calibration, whose offsets are stored in the servo. + bool read_homing_offset(uint8_t id, int& offset_out); + //! Write ``Torque_Enable = 0`` (register 40) so servo @p id goes limp and can be moved by hand. bool disable_torque(uint8_t id); diff --git a/src/plugins/so101_leader/so101_leader_plugin.cpp b/src/plugins/so101_leader/so101_leader_plugin.cpp index c20cd5470..39eb3a46d 100644 --- a/src/plugins/so101_leader/so101_leader_plugin.cpp +++ b/src/plugins/so101_leader/so101_leader_plugin.cpp @@ -13,11 +13,14 @@ #include #include #include +#include #include #include #include +#include #include #include +#include #include #include #include @@ -97,6 +100,10 @@ So101LeaderPlugin::So101LeaderPlugin(const std::string& device_path, << static_cast(calibration_[i].servo_id) << " (is it powered / on the bus?)" << std::endl; } } + + // A LeRobot calibration's homing offsets live in the servo EEPROM; reconcile now that the + // bus is up so our reads land in the same frame LeRobot uses. + compensate_homing(); } else { @@ -108,6 +115,12 @@ So101LeaderPlugin::~So101LeaderPlugin() = default; void So101LeaderPlugin::load_calibration(const std::string& path) { + if (path.ends_with(".json")) + { + load_lerobot_calibration(path); + return; + } + std::ifstream file(path); if (!file) { @@ -165,6 +178,85 @@ void So101LeaderPlugin::load_calibration(const std::string& path) } } +void So101LeaderPlugin::load_lerobot_calibration(const std::string& path) +{ + std::ifstream file(path); + if (!file) + { + std::cerr << "So101LeaderPlugin: warning: cannot open LeRobot calibration '" << path << "'; using defaults" + << std::endl; + return; + } + std::stringstream buffer; + buffer << file.rdbuf(); + const auto motors = parse_lerobot_calibration(buffer.str()); + if (motors.empty()) + { + std::cerr << "So101LeaderPlugin: warning: could not parse LeRobot calibration '" << path << "'; using defaults" + << std::endl; + return; + } + + lerobot_homing_.assign(kNumJoints, 0); + for (const auto& [name, fields] : motors) + { + int idx = -1; + for (int i = 0; i < kNumJoints; ++i) + { + if (name == kJointNames[i]) + { + idx = i; + break; + } + } + if (idx < 0) + { + std::cerr << "So101LeaderPlugin: warning: unknown joint '" << name << "' in " << path << std::endl; + continue; + } + + const auto get = [&fields](const char* key, long fallback) -> long + { + const auto it = fields.find(key); + return it != fields.end() ? it->second : fallback; + }; + const int servo_id = static_cast(get("id", idx + 1)); + const bool inverted = get("drive_mode", 0) != 0; + const int range_min = static_cast(get("range_min", 0)); + const int range_max = static_cast(get("range_max", 4095)); + // LeRobot's zero is the range midpoint (its DEGREES normalization centers on it); homing_offset + // lives in the servo and is reconciled in compensate_homing(). + const int home = (range_min + range_max) / 2; + calibration_[idx] = + JointCalibration{ static_cast(servo_id), inverted ? -1.0 : 1.0, home, range_min, range_max }; + lerobot_homing_[idx] = static_cast(get("homing_offset", 0)); + } +} + +void So101LeaderPlugin::compensate_homing() +{ + if (lerobot_homing_.empty() || !bus_) + { + return; // no LeRobot calibration loaded, or no live bus to read the servo offset + } + for (int i = 0; i < kNumJoints; ++i) + { + int servo_offset = 0; + if (!bus_->read_homing_offset(calibration_[i].servo_id, servo_offset)) + { + std::cerr << "So101LeaderPlugin: warning: could not read Homing_Offset of servo " + << static_cast(calibration_[i].servo_id) << "; assuming it matches the calibration file" + << std::endl; + continue; + } + // File offsets are in the servo's homed frame; shift into this servo's current frame. + const int delta = lerobot_homing_[i] - servo_offset; + calibration_[i].home_ticks += delta; + calibration_[i].range_min += delta; + calibration_[i].range_max += delta; + } +} + void So101LeaderPlugin::read_synthetic() { // Smooth, phase-shifted trajectory so the full device -> tracker -> retargeter path can run @@ -338,20 +430,6 @@ int run_calibration(const std::string& device_path, const std::string& output_pa } waiter.join(); - std::ofstream out; - const bool write_file = !output_path.empty(); - if (write_file) - { - out.open(output_path); - if (!out) - { - std::cerr << "calibrate: cannot write '" << output_path << "'" << std::endl; - return 2; - } - out << "# SO-101 leader calibration (generated by `so101_leader_plugin calibrate`)\n"; - out << "# name id sign home_ticks range_min range_max\n"; - } - bool all_ok = true; std::cout << "\nMeasured calibration (ticks; angle = sign * (ticks - home) * 2pi/4096):" << std::endl; for (int i = 0; i < kNumJoints; ++i) @@ -364,11 +442,6 @@ int run_calibration(const std::string& device_path, const std::string& output_pa } std::cout << " " << kJointNames[i] << " id=" << static_cast(ids[i]) << " home=" << home[i] << " range=[" << range_min[i] << ", " << range_max[i] << "]" << std::endl; - if (write_file) - { - out << kJointNames[i] << " " << static_cast(ids[i]) << " 1 " << home[i] << " " << range_min[i] << " " - << range_max[i] << "\n"; - } } // Gripper endpoints in radians (relative to home) for the retargeter's gripper_open/gripper_close. @@ -379,14 +452,199 @@ int run_calibration(const std::string& device_path, const std::string& output_pa << " .. " << grip_hi << "\n -> set JointStateRetargeterConfig.gripper_open / gripper_close to these " << "(whichever matches your open/closed convention)." << std::endl; - if (write_file) + if (!output_path.empty()) { - out.close(); - std::cout << "Wrote calibration to " << output_path << std::endl; + std::ofstream out(output_path); + if (!out) + { + std::cerr << "calibrate: cannot write '" << output_path << "'" << std::endl; + return 2; + } + if (output_path.ends_with(".json")) + { + // LeRobot calibration JSON. drive_mode 0 (the leader sweep keeps sign +1); homing_offset + // is read back from the servo so re-applying this file is a no-op and the recorded ranges + // stay in the servo's current frame. + out << "{\n"; + for (int i = 0; i < kNumJoints; ++i) + { + int homing = 0; + bus.read_homing_offset(ids[i], homing); + out << " \"" << kJointNames[i] << "\": {\n"; + out << " \"id\": " << static_cast(ids[i]) << ",\n"; + out << " \"drive_mode\": 0,\n"; + out << " \"homing_offset\": " << homing << ",\n"; + out << " \"range_min\": " << range_min[i] << ",\n"; + out << " \"range_max\": " << range_max[i] << "\n"; + out << " }" << (i + 1 < kNumJoints ? "," : "") << "\n"; + } + out << "}\n"; + } + else + { + out << "# SO-101 leader calibration (generated by `so101_leader_plugin calibrate`)\n"; + out << "# name id sign home_ticks range_min range_max\n"; + for (int i = 0; i < kNumJoints; ++i) + { + out << kJointNames[i] << " " << static_cast(ids[i]) << " 1 " << home[i] << " " << range_min[i] + << " " << range_max[i] << "\n"; + } + } + std::cout << "Wrote " << (output_path.ends_with(".json") ? "LeRobot " : "") << "calibration to " << output_path + << std::endl; } std::cout << "Set 'sign' to -1 for any joint that moves opposite the URDF convention." << std::endl; return all_ok ? 0 : 1; } +std::map> parse_lerobot_calibration(const std::string& json) +{ + std::map> result; + size_t i = 0; + const size_t n = json.size(); + + const auto skip_ws = [&]() + { + while (i < n && std::isspace(static_cast(json[i]))) + { + ++i; + } + }; + const auto parse_string = [&](std::string& out) -> bool + { + skip_ws(); + if (i >= n || json[i] != '"') + { + return false; + } + ++i; + out.clear(); + while (i < n && json[i] != '"') + { + if (json[i] == '\\' && i + 1 < n) + { + ++i; // take the escaped character literally (keys/fields here have no escapes) + } + out.push_back(json[i++]); + } + if (i >= n) + { + return false; + } + ++i; // closing quote + return true; + }; + const auto consume = [&](char c) -> bool + { + skip_ws(); + if (i < n && json[i] == c) + { + ++i; + return true; + } + return false; + }; + + if (!consume('{')) + { + return result; + } + skip_ws(); + if (i < n && json[i] == '}') + { + return result; // empty object + } + + while (i < n) + { + std::string motor; + if (!parse_string(motor) || !consume(':') || !consume('{')) + { + result.clear(); + return result; + } + + std::map fields; + skip_ws(); + if (i < n && json[i] == '}') + { + ++i; // empty motor object + } + else + { + while (i < n) + { + std::string key; + if (!parse_string(key) || !consume(':')) + { + result.clear(); + return result; + } + skip_ws(); + const size_t start = i; + if (i < n && (json[i] == '-' || json[i] == '+')) + { + ++i; + } + bool is_int = false; + while (i < n && std::isdigit(static_cast(json[i]))) + { + ++i; + is_int = true; + } + if (is_int && (i >= n || (json[i] != '.' && json[i] != 'e' && json[i] != 'E'))) + { + fields[key] = std::strtol(json.c_str() + start, nullptr, 10); + } + else + { + // Non-integer value (float / string / bool / null): skip it shallowly. + i = start; + if (i < n && json[i] == '"') + { + std::string tmp; + parse_string(tmp); + } + else + { + while (i < n && json[i] != ',' && json[i] != '}') + { + ++i; + } + } + } + skip_ws(); + if (i < n && json[i] == ',') + { + ++i; + continue; + } + if (i < n && json[i] == '}') + { + ++i; + break; + } + result.clear(); + return result; + } + } + + result[motor] = std::move(fields); + skip_ws(); + if (i < n && json[i] == ',') + { + ++i; + continue; + } + if (i < n && json[i] == '}') + { + ++i; + break; + } + break; + } + return result; +} + } // namespace so101_leader } // namespace plugins diff --git a/src/plugins/so101_leader/so101_leader_plugin.hpp b/src/plugins/so101_leader/so101_leader_plugin.hpp index 03f9f371f..3b9aa07af 100644 --- a/src/plugins/so101_leader/so101_leader_plugin.hpp +++ b/src/plugins/so101_leader/so101_leader_plugin.hpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -71,10 +72,19 @@ class So101LeaderPlugin //! Synthetic smooth trajectory used when no serial device is attached. void read_synthetic(); void push_current_state(); - //! Parse a whitespace-separated calibration file: ``name servo_id sign home_ticks [range_min - //! range_max]`` per line (``#`` comments allowed; range columns optional). Unknown joint names - //! are ignored; missing joints keep defaults. + //! Load calibration from @p path. A ``.json`` file is read as a LeRobot calibration (see + //! load_lerobot_calibration()); anything else is the plain-text format: ``name servo_id sign + //! home_ticks [range_min range_max]`` per line (``#`` comments allowed; range columns optional). + //! Unknown joint names are ignored; missing joints keep defaults. void load_calibration(const std::string& path); + //! Load a LeRobot calibration JSON. Maps ``range_min/range_max`` -> range, the range midpoint -> + //! ``home_ticks`` (LeRobot's zero), and ``drive_mode`` -> ``sign``. The per-joint + //! ``homing_offset`` is reconciled against the servo by compensate_homing(). + void load_lerobot_calibration(const std::string& path); + //! Reconcile a loaded LeRobot calibration with the live servos: LeRobot's offsets live in the + //! servo EEPROM, so shift home/range by ``homing_offset_file - homing_offset_servo`` (read live). + //! No-op without a LeRobot calibration or a connected bus. + void compensate_homing(); std::string device_path_; std::string collection_id_; @@ -86,6 +96,7 @@ class So101LeaderPlugin std::vector servo_ids_; // calibration_[*].servo_id in DOF order (sync-read request) std::vector read_ticks_; // sync-read scratch (reused each frame) std::vector read_ok_; // sync-read scratch: per-servo reply flag + std::vector lerobot_homing_; // per-DOF homing_offset from a loaded LeRobot JSON (else empty) std::shared_ptr session_; core::SchemaPusher pusher_; @@ -100,5 +111,11 @@ class So101LeaderPlugin //! exit code (0 = all servos read). int run_calibration(const std::string& device_path, const std::string& output_path); +//! Minimal reader for a LeRobot calibration JSON of the shape ``{ "joint": {"id": int, +//! "drive_mode": int, "homing_offset": int, "range_min": int, "range_max": int}, ... }``. Returns +//! ``joint -> {field -> integer}`` (non-integer values skipped). Not a general JSON parser; returns +//! an empty map on malformed input. +std::map> parse_lerobot_calibration(const std::string& json); + } // namespace so101_leader } // namespace plugins