diff --git a/src/retargeters/dex_hand_retargeter.py b/src/retargeters/dex_hand_retargeter.py index b056de146..ad4d38db4 100644 --- a/src/retargeters/dex_hand_retargeter.py +++ b/src/retargeters/dex_hand_retargeter.py @@ -412,10 +412,16 @@ def _compute_hand(self, poses: Dict[str, np.ndarray]) -> np.ndarray: ref_value = target_pos[task_indices, :] - target_pos[origin_indices, :] # 4. Run optimizer + # ``dex_retargeting`` solves a QP-style optimization that does not + # require autograd. ``torch.no_grad()`` avoids the per-step + # grad-tracking overhead the previous ``enable_grad`` context paid for. + # ``torch.inference_mode(False)`` is preserved so callers running + # inside an outer ``torch.inference_mode()`` (where some in-place / + # view ops can error) still see the optimizer execute in normal mode. try: import torch # type: ignore - with torch.enable_grad(), torch.inference_mode(False): + with torch.inference_mode(False), torch.no_grad(): return self._dex_hand.retarget(ref_value) # type: ignore except Exception as e: logger.error(f"Error in retargeting: {e}")