Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ Guidelines for modifications:
* Tsz Ki GAO
* Tyler Lum
* Victor Khaustov
* Vidur Vij
* Virgilio Gómez Lambo
* Vladimir Fokow
* Wei Yang
Expand Down
32 changes: 32 additions & 0 deletions docs/source/api/lab/isaaclab.actuators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
ActuatorNetMLPCfg
ActuatorNetLSTM
ActuatorNetLSTMCfg
ActuatorNetGRU
ActuatorNetGRUCfg
ActuatorNetGRUResidual
ActuatorNetGRUResidualCfg

Actuator Base
-------------
Expand Down Expand Up @@ -133,3 +137,31 @@ LSTM Network Actuator
:inherited-members:
:show-inheritance:
:exclude-members: __init__, class_type

GRU Network Actuator
--------------------

.. autoclass:: ActuatorNetGRU
:members:
:inherited-members:
:show-inheritance:

.. autoclass:: ActuatorNetGRUCfg
:members:
:inherited-members:
:show-inheritance:
:exclude-members: __init__, class_type

GRU Residual Network Actuator
-----------------------------

.. autoclass:: ActuatorNetGRUResidual
:members:
:inherited-members:
:show-inheritance:

.. autoclass:: ActuatorNetGRUResidualCfg
:members:
:inherited-members:
:show-inheritance:
:exclude-members: __init__, class_type
11 changes: 11 additions & 0 deletions source/isaaclab/changelog.d/vidurv-gru-actuators.minor.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Added
^^^^^

* Added :class:`~isaaclab.actuators.ActuatorNetGRU` and
:class:`~isaaclab.actuators.ActuatorNetGRUCfg`, an explicit actuator whose GRU
network predicts the total joint effort from the joint position, position error, and
velocity, with optional input and output normalization.
* Added :class:`~isaaclab.actuators.ActuatorNetGRUResidual` and
:class:`~isaaclab.actuators.ActuatorNetGRUResidualCfg`, an implicit-PD actuator that
adds a GRU-predicted residual feed-forward effort, with optional input and output
normalization.
8 changes: 6 additions & 2 deletions source/isaaclab/isaaclab/actuators/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
__all__ = [
"ActuatorBase",
"ActuatorBaseCfg",
"ActuatorNetGRU",
"ActuatorNetGRUResidual",
"ActuatorNetLSTM",
"ActuatorNetMLP",
"ActuatorNetGRUCfg",
"ActuatorNetGRUResidualCfg",
"ActuatorNetLSTMCfg",
"ActuatorNetMLPCfg",
"DCMotor",
Expand All @@ -24,8 +28,8 @@ __all__ = [

from .actuator_base import ActuatorBase
from .actuator_base_cfg import ActuatorBaseCfg
from .actuator_net import ActuatorNetLSTM, ActuatorNetMLP
from .actuator_net_cfg import ActuatorNetLSTMCfg, ActuatorNetMLPCfg
from .actuator_net import ActuatorNetGRU, ActuatorNetGRUResidual, ActuatorNetLSTM, ActuatorNetMLP
from .actuator_net_cfg import ActuatorNetGRUCfg, ActuatorNetGRUResidualCfg, ActuatorNetLSTMCfg, ActuatorNetMLPCfg
from .actuator_pd import (
DCMotor,
DelayedPDActuator,
Expand Down
249 changes: 247 additions & 2 deletions source/isaaclab/isaaclab/actuators/actuator_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

* Multi-Layer Perceptron (MLP)
* Long Short-Term Memory (LSTM)
* Gated Recurrent Unit (GRU), both explicit full-torque and implicit-PD residual variants

"""

from __future__ import annotations

import logging
from collections.abc import Sequence
from typing import TYPE_CHECKING

Expand All @@ -22,10 +24,17 @@
from isaaclab.utils.assets import read_file
from isaaclab.utils.types import ArticulationActions

from .actuator_pd import DCMotor
from .actuator_pd import DCMotor, IdealPDActuator, ImplicitActuator

if TYPE_CHECKING:
from .actuator_net_cfg import ActuatorNetLSTMCfg, ActuatorNetMLPCfg
from .actuator_net_cfg import (
ActuatorNetGRUCfg,
ActuatorNetGRUResidualCfg,
ActuatorNetLSTMCfg,
ActuatorNetMLPCfg,
)

logger = logging.getLogger(__name__)


class ActuatorNetLSTM(DCMotor):
Expand Down Expand Up @@ -98,6 +107,242 @@ def compute(
return control_action


class _GRUActuatorMixin:
"""Shared machinery for the GRU-based actuator models.

Loads the TorchScript GRU network, allocates the recurrent input and hidden-state buffers, and
runs inference. The network consumes a fixed input of joint position, position error, and
velocity. An optional ``(mean, std)`` normalization may be applied to each input and to the
output (``None`` selects the identity transform). The concrete actuator classes combine this
mixin with an explicit (:class:`IdealPDActuator`) or implicit (:class:`ImplicitActuator`) base
to define their effort semantics.
"""

# number of fixed network inputs: [position, position_error, velocity]
_NUM_INPUTS = 3
# standard-deviation floor used when normalizing to avoid division by tiny values
_GRU_STD_FLOOR = 1.0e-8

def _init_gru_runtime(self) -> None:
"""Load the network and allocate the GRU buffers and normalization statistics.

Raises:
ValueError: If the TorchScript module does not expose a ``.gru`` submodule, or if its
input dimension is not 3 (joint position, position error, and velocity).
"""
# load the TorchScript network
file_bytes = read_file(self.cfg.network_file)
self.network = torch.jit.load(file_bytes, map_location=self._device).eval()
if not hasattr(self.network, "gru"):
raise ValueError(f"The network file '{self.cfg.network_file}' must expose a TorchScript '.gru' submodule.")

# infer dimensions from the GRU weights (the input is [position, position_error, velocity])
gru_state = self.network.gru.state_dict()
if any("reverse" in key for key in gru_state):
raise ValueError(
f"The network file '{self.cfg.network_file}' uses a bidirectional GRU, which is not supported."
)
input_dim = int(gru_state["weight_ih_l0"].shape[1])
hidden_dim = int(gru_state["weight_hh_l0"].shape[1])
num_layers = sum(1 for key in gru_state if key.startswith("weight_ih_l") and "reverse" not in key)
if input_dim != self._NUM_INPUTS:
raise ValueError(
f"The network file '{self.cfg.network_file}' must take {self._NUM_INPUTS} inputs (joint position,"
f" position error, and velocity), but its GRU expects {input_dim}."
)

# resolve (mean, std) normalization for the inputs and output (identity when unset)
self._position_norm = self._resolve_normalization(self.cfg.position_normalization, "position_normalization")
self._pos_error_norm = self._resolve_normalization(self.cfg.pos_error_normalization, "pos_error_normalization")
self._vel_norm = self._resolve_normalization(self.cfg.vel_normalization, "vel_normalization")
self._output_norm = self._resolve_normalization(self.cfg.output_normalization, "output_normalization")

# recurrent input and hidden-state buffers
batch = self._num_envs * self.num_joints
self.sea_input = torch.zeros(batch, 1, self._NUM_INPUTS, device=self._device)
self.sea_hidden_state = torch.zeros(num_layers, batch, hidden_dim, device=self._device)
# per-env view for resets (shares storage)
self.sea_hidden_state_per_env = self.sea_hidden_state.view(
num_layers, self._num_envs, self.num_joints, hidden_dim
)

def _resolve_normalization(self, stats: tuple[float, float] | None, name: str) -> tuple[float, float]:
"""Return the ``(mean, std)`` to apply, defaulting to identity and flooring the std.

Args:
stats: The ``(mean, std)`` pair, or None for the identity transform.
name: The configuration field name, used for the warning message.

Returns:
The resolved ``(mean, std)`` with the std floored to avoid division by tiny values.
"""
if stats is None:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Warning: If std is passed as 0.0 (e.g., from a misconfigured normalization config), this silently floors to 1e-8, which can make the normalized inputs explode to ~O(1e8). Consider logging a warning when the input std is below the floor.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 63903b9_resolve_normalization now logs a warning when a configured std is below the floor (and identifies which normalization field), since a near-zero std amplifies the normalized values.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, the warning with field identification is exactly what I had in mind. LGTM ✅

return 0.0, 1.0
mean, std = float(stats[0]), float(stats[1])
if std < 0.0:
raise ValueError(
f"Actuator '{self.cfg.network_file}' has {name} std={std}; the standard deviation must be"
" non-negative. Check the (mean, std) ordering."
)
if std < self._GRU_STD_FLOOR:
logger.warning(
"Actuator '%s' has %s std=%s below the floor %s; flooring it, which can amplify the"
" normalized values. Set a larger std or leave the field unset for identity.",
self.cfg.network_file,
name,
std,
self._GRU_STD_FLOOR,
)
return mean, max(std, self._GRU_STD_FLOOR)

def _reset_gru_state(self, env_ids: Sequence[int]):
"""Zero the GRU hidden state for the specified environments.

Args:
env_ids: The environment indices whose hidden state should be reset.
"""
with torch.no_grad():
self.sea_hidden_state_per_env[:, env_ids] = 0.0

def _predict_gru_effort(
self, control_action: ArticulationActions, joint_pos: torch.Tensor, joint_vel: torch.Tensor
) -> torch.Tensor:
"""Assemble the network input, run inference, and return the denormalized effort.

Args:
control_action: The joint action instance holding the desired joint positions.
joint_pos: The current joint positions. Shape is (num_envs, num_joints).
joint_vel: The current joint velocities. Shape is (num_envs, num_joints).

Returns:
The predicted effort [N·m or N, depending on joint type]. Shape is
(num_envs, num_joints).

Raises:
ValueError: If ``control_action.joint_positions`` is None.
"""
if control_action.joint_positions is None:
raise ValueError("GRU actuator input requires control_action.joint_positions to be set.")
# normalized [position, position_error, velocity] inputs
position = joint_pos.flatten()
pos_error = (control_action.joint_positions - joint_pos).flatten()
velocity = joint_vel.flatten()
self.sea_input[:, 0, 0] = (position - self._position_norm[0]) / self._position_norm[1]
self.sea_input[:, 0, 1] = (pos_error - self._pos_error_norm[0]) / self._pos_error_norm[1]
self.sea_input[:, 0, 2] = (velocity - self._vel_norm[0]) / self._vel_norm[1]

# run inference, then denormalize and guard against a non-finite output
with torch.inference_mode():
output, self.sea_hidden_state[:] = self.network(self.sea_input, self.sea_hidden_state)
Comment on lines +211 to +236

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Negative std silently corrected to 1e-8 instead of raising an error

max(float(std), self._GRU_STD_FLOOR) will replace any negative std (e.g., -2.0) with 1e-8 — effectively dividing by nearly zero — while the warning message says "flooring it, which can amplify the normalized values". A user who accidentally swaps the (mean, std) tuple order (passing (std, mean)) or passes a negative value gets input amplification of up to 1 / 1e-8 = 1e8 with a warning that does not clearly signal the sign correction. Only positive-but-small std values belong in the "floor" category; a negative std should raise ValueError to force the misconfiguration to be fixed.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in e1dea0f_resolve_normalization now raises ValueError for a negative std (only 0 ≤ std < 1e-8 is floored, with the warning). Added a regression test (test_actuator_net_gru_negative_std_raises).

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent — raising on negative std while flooring the near-zero case is the right split. Regression test seals it. 👍

output = output * self._output_norm[1] + self._output_norm[0]
# a non-finite prediction carries no usable actuation, so command zero effort this step
output = torch.nan_to_num(output, nan=0.0, posinf=0.0, neginf=0.0)
return output.reshape(self._num_envs, self.num_joints)


class ActuatorNetGRU(_GRUActuatorMixin, IdealPDActuator):
"""Explicit actuator model based on a recurrent neural network (GRU).

The GRU network predicts the *total* joint effort [N·m or N, depending on joint type] from the
joint position, position error, and velocity. Unlike the analytical models, no PD gains are
applied; the hidden state of the recurrent network captures the actuator history. The predicted
effort is clipped to the actuator's effort limit via :meth:`~isaaclab.actuators.ActuatorBase._clip_effort`.

This model derives from :class:`IdealPDActuator`, whose simple symmetric ``±effort_limit``
saturation matches a learned total-torque source without requiring the velocity-dependent
torque-speed parameters of a DC motor.

Note:
The recurrent hidden state encodes the actuator history and is only cleared by
:meth:`reset`. Callers must reset the relevant environments on episode boundaries
(and after any control gap, e.g. a hardware reconnect) so the first post-reset effort is
not computed against stale temporal context.
"""

cfg: ActuatorNetGRUCfg
"""The configuration of the actuator model."""

def __init__(self, cfg: ActuatorNetGRUCfg, *args, **kwargs):
super().__init__(cfg, *args, **kwargs)
self._init_gru_runtime()

"""
Operations.
"""

def reset(self, env_ids: Sequence[int]):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔵 Suggestion: Consider calling super().reset(env_ids) here for consistency with ActuatorNetGRUResidual.reset() and to future-proof against base class changes.

Suggested change
def reset(self, env_ids: Sequence[int]):
def reset(self, env_ids: Sequence[int]):
super().reset(env_ids)
self._reset_gru_state(env_ids)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 63903b9ActuatorNetGRU.reset() now calls super().reset(env_ids) before resetting the hidden state, matching ActuatorNetGRUResidual.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix! 👍

super().reset(env_ids)
self._reset_gru_state(env_ids)

Comment on lines +271 to +276

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Denormalization applied to an inference-mode tensor outside the context block

output returned by the network inside torch.inference_mode() is an inference tensor. The denormalization output * self._output_norm[1] + self._output_norm[0] and the subsequent nan_to_num are executed outside the with block, so they run with gradient tracking enabled (if a grad context is active). In most deployment scenarios this is harmless, but it creates unnecessary ops when called from within a torch.no_grad() or outer inference_mode() scope. Consider moving the denormalization and sanitization inside the with torch.inference_mode(): block for consistency and to avoid any surprise interactions with autograd.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in e1dea0f — the denormalization and nan_to_num now run inside the with torch.inference_mode() block.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix! Confirmed — having everything under the inference_mode() context is cleaner and correct. 👍

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the quick fix! Verified that the denormalization and nan_to_num are now properly scoped inside the inference_mode() block. 👍

def compute(
self, control_action: ArticulationActions, joint_pos: torch.Tensor, joint_vel: torch.Tensor
) -> ArticulationActions:
self.computed_effort = self._predict_gru_effort(control_action, joint_pos, joint_vel)
# clip the computed effort based on the motor limits
self.applied_effort = self._clip_effort(self.computed_effort)
control_action.joint_efforts = self.applied_effort
control_action.joint_positions = None
control_action.joint_velocities = None
return control_action


class ActuatorNetGRUResidual(_GRUActuatorMixin, ImplicitActuator):
"""Implicit-PD actuator model with an added recurrent (GRU) residual effort.

This model behaves like an :class:`ImplicitActuator` -- the physics engine applies the PD
control using the configured stiffness and damping -- but augments the feed-forward effort
term with a *residual* effort [N·m or N, depending on joint type] predicted by a recurrent
(GRU) network. The residual is added to any existing feed-forward effort, and the approximate
total effort is stored for reward computation while the desired joint positions and velocities
are preserved so the engine can compute the PD term.

Note:
As with any :class:`ImplicitActuator`, the effort actually applied by the engine is the
feed-forward effort plus the engine-side PD term, and it is bounded by the simulation
effort limit (``effort_limit_sim``) rather than by :meth:`~isaaclab.actuators.ActuatorBase._clip_effort`
(which only populates the reported :attr:`applied_effort`). Set ``effort_limit_sim`` to a
finite value to bound the residual feed-forward. The hidden state is cleared only by
:meth:`reset`; reset the relevant environments on episode boundaries (and after any control
gap) to avoid stale recurrent context.
"""

cfg: ActuatorNetGRUResidualCfg
"""The configuration of the actuator model."""

def __init__(self, cfg: ActuatorNetGRUResidualCfg, *args, **kwargs):
super().__init__(cfg, *args, **kwargs)
self._init_gru_runtime()

"""
Operations.
"""

def reset(self, env_ids: Sequence[int]):
super().reset(env_ids)
self._reset_gru_state(env_ids)

def compute(
self, control_action: ArticulationActions, joint_pos: torch.Tensor, joint_vel: torch.Tensor
) -> ArticulationActions:
# add the GRU residual to the feed-forward effort
residual = self._predict_gru_effort(control_action, joint_pos, joint_vel)
if control_action.joint_efforts is None:
control_action.joint_efforts = residual
else:
control_action.joint_efforts = control_action.joint_efforts + residual

# approximate total effort for reward telemetry (engine applies the PD term)
error_pos = control_action.joint_positions - joint_pos
if control_action.joint_velocities is not None:
error_vel = control_action.joint_velocities - joint_vel
else:
error_vel = -joint_vel
self.computed_effort = self.stiffness * error_pos + self.damping * error_vel + control_action.joint_efforts
self.applied_effort = self._clip_effort(self.computed_effort)
# positions/velocities are preserved so the engine computes the PD term
return control_action


class ActuatorNetMLP(DCMotor):
"""Actuator model based on multi-layer perceptron and joint history.

Expand Down
Loading
Loading