-
Notifications
You must be signed in to change notification settings - Fork 3.6k
feat(actuators): add GRU and GRU-residual network actuator models #6083
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| Added | ||
| ^^^^^ | ||
|
|
||
| * Added :class:`~isaaclab.actuators.ActuatorNetGRU` and | ||
| :class:`~isaaclab.actuators.ActuatorNetGRUCfg`, an explicit actuator whose GRU | ||
| network predicts the total joint effort from the joint position, position error, and | ||
| velocity, with optional input and output normalization. | ||
| * Added :class:`~isaaclab.actuators.ActuatorNetGRUResidual` and | ||
| :class:`~isaaclab.actuators.ActuatorNetGRUResidualCfg`, an implicit-PD actuator that | ||
| adds a GRU-predicted residual feed-forward effort, with optional input and output | ||
| normalization. |
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -9,11 +9,13 @@ | |||||||||
|
|
||||||||||
| * Multi-Layer Perceptron (MLP) | ||||||||||
| * Long Short-Term Memory (LSTM) | ||||||||||
| * Gated Recurrent Unit (GRU), both explicit full-torque and implicit-PD residual variants | ||||||||||
|
|
||||||||||
| """ | ||||||||||
|
|
||||||||||
| from __future__ import annotations | ||||||||||
|
|
||||||||||
| import logging | ||||||||||
| from collections.abc import Sequence | ||||||||||
| from typing import TYPE_CHECKING | ||||||||||
|
|
||||||||||
|
|
@@ -22,10 +24,17 @@ | |||||||||
| from isaaclab.utils.assets import read_file | ||||||||||
| from isaaclab.utils.types import ArticulationActions | ||||||||||
|
|
||||||||||
| from .actuator_pd import DCMotor | ||||||||||
| from .actuator_pd import DCMotor, IdealPDActuator, ImplicitActuator | ||||||||||
|
|
||||||||||
| if TYPE_CHECKING: | ||||||||||
| from .actuator_net_cfg import ActuatorNetLSTMCfg, ActuatorNetMLPCfg | ||||||||||
| from .actuator_net_cfg import ( | ||||||||||
| ActuatorNetGRUCfg, | ||||||||||
| ActuatorNetGRUResidualCfg, | ||||||||||
| ActuatorNetLSTMCfg, | ||||||||||
| ActuatorNetMLPCfg, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| logger = logging.getLogger(__name__) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class ActuatorNetLSTM(DCMotor): | ||||||||||
|
|
@@ -98,6 +107,242 @@ def compute( | |||||||||
| return control_action | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class _GRUActuatorMixin: | ||||||||||
| """Shared machinery for the GRU-based actuator models. | ||||||||||
|
|
||||||||||
| Loads the TorchScript GRU network, allocates the recurrent input and hidden-state buffers, and | ||||||||||
| runs inference. The network consumes a fixed input of joint position, position error, and | ||||||||||
| velocity. An optional ``(mean, std)`` normalization may be applied to each input and to the | ||||||||||
| output (``None`` selects the identity transform). The concrete actuator classes combine this | ||||||||||
| mixin with an explicit (:class:`IdealPDActuator`) or implicit (:class:`ImplicitActuator`) base | ||||||||||
| to define their effort semantics. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| # number of fixed network inputs: [position, position_error, velocity] | ||||||||||
| _NUM_INPUTS = 3 | ||||||||||
| # standard-deviation floor used when normalizing to avoid division by tiny values | ||||||||||
| _GRU_STD_FLOOR = 1.0e-8 | ||||||||||
|
|
||||||||||
| def _init_gru_runtime(self) -> None: | ||||||||||
| """Load the network and allocate the GRU buffers and normalization statistics. | ||||||||||
|
|
||||||||||
| Raises: | ||||||||||
| ValueError: If the TorchScript module does not expose a ``.gru`` submodule, or if its | ||||||||||
| input dimension is not 3 (joint position, position error, and velocity). | ||||||||||
| """ | ||||||||||
| # load the TorchScript network | ||||||||||
| file_bytes = read_file(self.cfg.network_file) | ||||||||||
| self.network = torch.jit.load(file_bytes, map_location=self._device).eval() | ||||||||||
| if not hasattr(self.network, "gru"): | ||||||||||
| raise ValueError(f"The network file '{self.cfg.network_file}' must expose a TorchScript '.gru' submodule.") | ||||||||||
|
|
||||||||||
| # infer dimensions from the GRU weights (the input is [position, position_error, velocity]) | ||||||||||
| gru_state = self.network.gru.state_dict() | ||||||||||
| if any("reverse" in key for key in gru_state): | ||||||||||
| raise ValueError( | ||||||||||
| f"The network file '{self.cfg.network_file}' uses a bidirectional GRU, which is not supported." | ||||||||||
| ) | ||||||||||
| input_dim = int(gru_state["weight_ih_l0"].shape[1]) | ||||||||||
| hidden_dim = int(gru_state["weight_hh_l0"].shape[1]) | ||||||||||
| num_layers = sum(1 for key in gru_state if key.startswith("weight_ih_l") and "reverse" not in key) | ||||||||||
| if input_dim != self._NUM_INPUTS: | ||||||||||
| raise ValueError( | ||||||||||
| f"The network file '{self.cfg.network_file}' must take {self._NUM_INPUTS} inputs (joint position," | ||||||||||
| f" position error, and velocity), but its GRU expects {input_dim}." | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| # resolve (mean, std) normalization for the inputs and output (identity when unset) | ||||||||||
| self._position_norm = self._resolve_normalization(self.cfg.position_normalization, "position_normalization") | ||||||||||
| self._pos_error_norm = self._resolve_normalization(self.cfg.pos_error_normalization, "pos_error_normalization") | ||||||||||
| self._vel_norm = self._resolve_normalization(self.cfg.vel_normalization, "vel_normalization") | ||||||||||
| self._output_norm = self._resolve_normalization(self.cfg.output_normalization, "output_normalization") | ||||||||||
|
|
||||||||||
| # recurrent input and hidden-state buffers | ||||||||||
| batch = self._num_envs * self.num_joints | ||||||||||
| self.sea_input = torch.zeros(batch, 1, self._NUM_INPUTS, device=self._device) | ||||||||||
| self.sea_hidden_state = torch.zeros(num_layers, batch, hidden_dim, device=self._device) | ||||||||||
| # per-env view for resets (shares storage) | ||||||||||
| self.sea_hidden_state_per_env = self.sea_hidden_state.view( | ||||||||||
| num_layers, self._num_envs, self.num_joints, hidden_dim | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| def _resolve_normalization(self, stats: tuple[float, float] | None, name: str) -> tuple[float, float]: | ||||||||||
| """Return the ``(mean, std)`` to apply, defaulting to identity and flooring the std. | ||||||||||
|
|
||||||||||
| Args: | ||||||||||
| stats: The ``(mean, std)`` pair, or None for the identity transform. | ||||||||||
| name: The configuration field name, used for the warning message. | ||||||||||
|
|
||||||||||
| Returns: | ||||||||||
| The resolved ``(mean, std)`` with the std floored to avoid division by tiny values. | ||||||||||
| """ | ||||||||||
| if stats is None: | ||||||||||
| return 0.0, 1.0 | ||||||||||
| mean, std = float(stats[0]), float(stats[1]) | ||||||||||
| if std < 0.0: | ||||||||||
| raise ValueError( | ||||||||||
| f"Actuator '{self.cfg.network_file}' has {name} std={std}; the standard deviation must be" | ||||||||||
| " non-negative. Check the (mean, std) ordering." | ||||||||||
| ) | ||||||||||
| if std < self._GRU_STD_FLOOR: | ||||||||||
| logger.warning( | ||||||||||
| "Actuator '%s' has %s std=%s below the floor %s; flooring it, which can amplify the" | ||||||||||
| " normalized values. Set a larger std or leave the field unset for identity.", | ||||||||||
| self.cfg.network_file, | ||||||||||
| name, | ||||||||||
| std, | ||||||||||
| self._GRU_STD_FLOOR, | ||||||||||
| ) | ||||||||||
| return mean, max(std, self._GRU_STD_FLOOR) | ||||||||||
|
|
||||||||||
| def _reset_gru_state(self, env_ids: Sequence[int]): | ||||||||||
| """Zero the GRU hidden state for the specified environments. | ||||||||||
|
|
||||||||||
| Args: | ||||||||||
| env_ids: The environment indices whose hidden state should be reset. | ||||||||||
| """ | ||||||||||
| with torch.no_grad(): | ||||||||||
| self.sea_hidden_state_per_env[:, env_ids] = 0.0 | ||||||||||
|
|
||||||||||
| def _predict_gru_effort( | ||||||||||
| self, control_action: ArticulationActions, joint_pos: torch.Tensor, joint_vel: torch.Tensor | ||||||||||
| ) -> torch.Tensor: | ||||||||||
| """Assemble the network input, run inference, and return the denormalized effort. | ||||||||||
|
|
||||||||||
| Args: | ||||||||||
| control_action: The joint action instance holding the desired joint positions. | ||||||||||
| joint_pos: The current joint positions. Shape is (num_envs, num_joints). | ||||||||||
| joint_vel: The current joint velocities. Shape is (num_envs, num_joints). | ||||||||||
|
|
||||||||||
| Returns: | ||||||||||
| The predicted effort [N·m or N, depending on joint type]. Shape is | ||||||||||
| (num_envs, num_joints). | ||||||||||
|
|
||||||||||
| Raises: | ||||||||||
| ValueError: If ``control_action.joint_positions`` is None. | ||||||||||
| """ | ||||||||||
| if control_action.joint_positions is None: | ||||||||||
| raise ValueError("GRU actuator input requires control_action.joint_positions to be set.") | ||||||||||
| # normalized [position, position_error, velocity] inputs | ||||||||||
| position = joint_pos.flatten() | ||||||||||
| pos_error = (control_action.joint_positions - joint_pos).flatten() | ||||||||||
| velocity = joint_vel.flatten() | ||||||||||
| self.sea_input[:, 0, 0] = (position - self._position_norm[0]) / self._position_norm[1] | ||||||||||
| self.sea_input[:, 0, 1] = (pos_error - self._pos_error_norm[0]) / self._pos_error_norm[1] | ||||||||||
| self.sea_input[:, 0, 2] = (velocity - self._vel_norm[0]) / self._vel_norm[1] | ||||||||||
|
|
||||||||||
| # run inference, then denormalize and guard against a non-finite output | ||||||||||
| with torch.inference_mode(): | ||||||||||
| output, self.sea_hidden_state[:] = self.network(self.sea_input, self.sea_hidden_state) | ||||||||||
|
Comment on lines
+211
to
+236
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in e1dea0f — There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Excellent — raising on negative std while flooring the near-zero case is the right split. Regression test seals it. 👍 |
||||||||||
| output = output * self._output_norm[1] + self._output_norm[0] | ||||||||||
| # a non-finite prediction carries no usable actuation, so command zero effort this step | ||||||||||
| output = torch.nan_to_num(output, nan=0.0, posinf=0.0, neginf=0.0) | ||||||||||
| return output.reshape(self._num_envs, self.num_joints) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class ActuatorNetGRU(_GRUActuatorMixin, IdealPDActuator): | ||||||||||
| """Explicit actuator model based on a recurrent neural network (GRU). | ||||||||||
|
|
||||||||||
| The GRU network predicts the *total* joint effort [N·m or N, depending on joint type] from the | ||||||||||
| joint position, position error, and velocity. Unlike the analytical models, no PD gains are | ||||||||||
| applied; the hidden state of the recurrent network captures the actuator history. The predicted | ||||||||||
| effort is clipped to the actuator's effort limit via :meth:`~isaaclab.actuators.ActuatorBase._clip_effort`. | ||||||||||
|
|
||||||||||
| This model derives from :class:`IdealPDActuator`, whose simple symmetric ``±effort_limit`` | ||||||||||
| saturation matches a learned total-torque source without requiring the velocity-dependent | ||||||||||
| torque-speed parameters of a DC motor. | ||||||||||
|
|
||||||||||
| Note: | ||||||||||
| The recurrent hidden state encodes the actuator history and is only cleared by | ||||||||||
| :meth:`reset`. Callers must reset the relevant environments on episode boundaries | ||||||||||
| (and after any control gap, e.g. a hardware reconnect) so the first post-reset effort is | ||||||||||
| not computed against stale temporal context. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| cfg: ActuatorNetGRUCfg | ||||||||||
| """The configuration of the actuator model.""" | ||||||||||
|
|
||||||||||
| def __init__(self, cfg: ActuatorNetGRUCfg, *args, **kwargs): | ||||||||||
| super().__init__(cfg, *args, **kwargs) | ||||||||||
| self._init_gru_runtime() | ||||||||||
|
|
||||||||||
| """ | ||||||||||
| Operations. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| def reset(self, env_ids: Sequence[int]): | ||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔵 Suggestion: Consider calling
Suggested change
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in 63903b9 — There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the fix! 👍 |
||||||||||
| super().reset(env_ids) | ||||||||||
| self._reset_gru_state(env_ids) | ||||||||||
|
|
||||||||||
|
Comment on lines
+271
to
+276
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in e1dea0f — the denormalization and There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the fix! Confirmed — having everything under the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the quick fix! Verified that the denormalization and |
||||||||||
| def compute( | ||||||||||
| self, control_action: ArticulationActions, joint_pos: torch.Tensor, joint_vel: torch.Tensor | ||||||||||
| ) -> ArticulationActions: | ||||||||||
| self.computed_effort = self._predict_gru_effort(control_action, joint_pos, joint_vel) | ||||||||||
| # clip the computed effort based on the motor limits | ||||||||||
| self.applied_effort = self._clip_effort(self.computed_effort) | ||||||||||
| control_action.joint_efforts = self.applied_effort | ||||||||||
| control_action.joint_positions = None | ||||||||||
| control_action.joint_velocities = None | ||||||||||
| return control_action | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class ActuatorNetGRUResidual(_GRUActuatorMixin, ImplicitActuator): | ||||||||||
| """Implicit-PD actuator model with an added recurrent (GRU) residual effort. | ||||||||||
|
|
||||||||||
| This model behaves like an :class:`ImplicitActuator` -- the physics engine applies the PD | ||||||||||
| control using the configured stiffness and damping -- but augments the feed-forward effort | ||||||||||
| term with a *residual* effort [N·m or N, depending on joint type] predicted by a recurrent | ||||||||||
| (GRU) network. The residual is added to any existing feed-forward effort, and the approximate | ||||||||||
| total effort is stored for reward computation while the desired joint positions and velocities | ||||||||||
| are preserved so the engine can compute the PD term. | ||||||||||
|
|
||||||||||
| Note: | ||||||||||
| As with any :class:`ImplicitActuator`, the effort actually applied by the engine is the | ||||||||||
| feed-forward effort plus the engine-side PD term, and it is bounded by the simulation | ||||||||||
| effort limit (``effort_limit_sim``) rather than by :meth:`~isaaclab.actuators.ActuatorBase._clip_effort` | ||||||||||
| (which only populates the reported :attr:`applied_effort`). Set ``effort_limit_sim`` to a | ||||||||||
| finite value to bound the residual feed-forward. The hidden state is cleared only by | ||||||||||
| :meth:`reset`; reset the relevant environments on episode boundaries (and after any control | ||||||||||
| gap) to avoid stale recurrent context. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| cfg: ActuatorNetGRUResidualCfg | ||||||||||
| """The configuration of the actuator model.""" | ||||||||||
|
|
||||||||||
| def __init__(self, cfg: ActuatorNetGRUResidualCfg, *args, **kwargs): | ||||||||||
| super().__init__(cfg, *args, **kwargs) | ||||||||||
| self._init_gru_runtime() | ||||||||||
|
|
||||||||||
| """ | ||||||||||
| Operations. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| def reset(self, env_ids: Sequence[int]): | ||||||||||
| super().reset(env_ids) | ||||||||||
| self._reset_gru_state(env_ids) | ||||||||||
|
|
||||||||||
| def compute( | ||||||||||
| self, control_action: ArticulationActions, joint_pos: torch.Tensor, joint_vel: torch.Tensor | ||||||||||
| ) -> ArticulationActions: | ||||||||||
| # add the GRU residual to the feed-forward effort | ||||||||||
| residual = self._predict_gru_effort(control_action, joint_pos, joint_vel) | ||||||||||
| if control_action.joint_efforts is None: | ||||||||||
| control_action.joint_efforts = residual | ||||||||||
| else: | ||||||||||
| control_action.joint_efforts = control_action.joint_efforts + residual | ||||||||||
|
|
||||||||||
| # approximate total effort for reward telemetry (engine applies the PD term) | ||||||||||
| error_pos = control_action.joint_positions - joint_pos | ||||||||||
| if control_action.joint_velocities is not None: | ||||||||||
| error_vel = control_action.joint_velocities - joint_vel | ||||||||||
| else: | ||||||||||
| error_vel = -joint_vel | ||||||||||
| self.computed_effort = self.stiffness * error_pos + self.damping * error_vel + control_action.joint_efforts | ||||||||||
| self.applied_effort = self._clip_effort(self.computed_effort) | ||||||||||
| # positions/velocities are preserved so the engine computes the PD term | ||||||||||
| return control_action | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class ActuatorNetMLP(DCMotor): | ||||||||||
| """Actuator model based on multi-layer perceptron and joint history. | ||||||||||
|
|
||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟡 Warning: If
stdis passed as0.0(e.g., from a misconfigured normalization config), this silently floors to1e-8, which can make the normalized inputs explode to ~O(1e8). Consider logging a warning when the inputstdis below the floor.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in 63903b9 —
_resolve_normalizationnow logs a warning when a configuredstdis below the floor (and identifies which normalization field), since a near-zero std amplifies the normalized values.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, the warning with field identification is exactly what I had in mind. LGTM ✅