perf(dex-hand): run retargeting optimizer under torch.no_grad #478
Conversation
DexHandRetargeter._compute_hand wrapped the per-frame call to ``self._dex_hand.retarget(...)`` in ``torch.enable_grad()`` and ``torch.inference_mode(False)`` — every step paid for autograd bookkeeping the dex_retargeting QP solver does not consume. Switch to ``torch.no_grad()`` so forward ops in the optimizer skip grad tracking entirely. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
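Not from the PR, but a minimal micro-benchmark sketch of the kind of overhead being removed; ``solver_like_step`` is a hypothetical stand-in for the solver's tensor work, not the real dex_retargeting objective:

```python
import time

import torch

def solver_like_step(x):
    # Hypothetical stand-in for the dense tensor math inside a
    # retargeting objective; NOT the real dex_retargeting solver.
    return torch.relu(x @ x.T).sum()

base = torch.randn(64, 64)

for label, ctx, track in [
    ("enable_grad", torch.enable_grad, True),
    ("no_grad", torch.no_grad, False),
]:
    x = base.clone().requires_grad_(track)
    start = time.perf_counter()
    with ctx():
        for _ in range(1000):
            solver_like_step(x)
    print(f"{label}: {time.perf_counter() - start:.4f}s")
```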
Pull request overview
This PR optimizes the dex-hand retargeting hot path by disabling PyTorch autograd tracking around the dex_retargeting solver call, aiming to reduce per-frame overhead in DexHandRetargeter.
Changes:
- Replace ``torch.enable_grad()`` + ``torch.inference_mode(False)`` with ``torch.no_grad()`` around ``self._dex_hand.retarget(...)``.
- Add an explanatory comment documenting the rationale for disabling grad tracking.
```diff
+            # ``dex_retargeting`` solves a QP-style optimization that does not
+            # require autograd; running under ``torch.no_grad()`` avoids the
+            # per-step grad-tracking overhead the previous ``enable_grad`` /
+            # ``inference_mode(False)`` context was incurring on every frame.
             try:
                 import torch  # type: ignore

-                with torch.enable_grad(), torch.inference_mode(False):
+                with torch.no_grad():
                     return self._dex_hand.retarget(ref_value)  # type: ignore
```
Copilot review on #478 flagged that switching the per-frame context from ``torch.enable_grad(), torch.inference_mode(False)`` to ``torch.no_grad()`` alone drops the explicit opt-out from any outer ``torch.inference_mode()``. Some in-place / view ops in dex_retargeting can error under inference mode, so a caller wrapping the session in ``torch.inference_mode()`` would now break. Combine both: ``torch.inference_mode(False), torch.no_grad()`` keeps the escape from inference mode (preserving prior behaviour) while still skipping autograd bookkeeping (the actual perf win). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
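A minimal sketch of the suggested combination, plus one concrete failure mode the ``inference_mode(False)`` escape guards against; ``retarget_locked`` is a hypothetical wrapper, not code from this repo:

```python
import torch

def retarget_locked(dex_hand, ref_value):
    # Opt out of any caller-level inference_mode first (so the solver's
    # in-place / view ops keep working), then skip grad tracking.
    with torch.inference_mode(False), torch.no_grad():
        return dex_hand.retarget(ref_value)

# Failure mode without the escape: tensors created under inference_mode
# are inference tensors and reject in-place updates outside that mode.
with torch.inference_mode():
    t = torch.ones(3)
try:
    t.add_(1)
except RuntimeError as err:
    print(err)  # "Inplace update to inference tensor outside InferenceMode ..."
```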
Summary
``DexHandRetargeter._compute_hand`` was wrapping every call to ``self._dex_hand.retarget(...)`` in ``torch.enable_grad()`` + ``torch.inference_mode(False)``. The dex_retargeting QP solver does not consume those gradients, so each frame paid for autograd bookkeeping that was immediately discarded. Switch to ``torch.no_grad()`` so forward tensor ops inside the optimizer skip grad tracking entirely. The optimizer is unchanged; only the surrounding context flips.

Why this matters
At 90 Hz with two ``DexHandRetargeter`` instances (bimanual), the per-frame cost adds up. Disabling grad tracking on the hot path is a free win unless dex_retargeting internally requires autograd, which it does not: it's an NLopt-style optimization.
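For scale (illustrative numbers, not measurements from this PR): two hands at 90 Hz means 180 ``retarget(...)`` calls per second, so even 50 µs of grad-tracking overhead per call wastes roughly 9 ms of CPU time every second, per process.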
Risk

The previous explicit ``enable_grad()`` + ``inference_mode(False)`` reads as a defensive escape from an outer no-grad context, but dex_retargeting's ``retarget()`` does not differentiate through any tensor operation, so running under ``no_grad`` is safe. If any downstream code actually relies on graph capture from ``retarget(...)`` (none in this repo), CI will flag it.
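A quick way to sanity-check that assumption in a test (hypothetical harness; assumes the ``retarget`` result is tensor-convertible):

```python
import torch

def assert_detached(dex_hand, ref_value):
    with torch.no_grad():
        out = dex_hand.retarget(ref_value)
    t = torch.as_tensor(out)
    # Under no_grad the result must carry no autograd graph; downstream
    # code that needs a grad_fn from retarget would fail this check.
    assert t.grad_fn is None and not t.requires_grad
    return out
```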
Test plan

- ``ctest`` retargeting suite green

🤖 Generated with Claude Code