[bugfix] pin MeanMetric all_reduce tensor to float64 to avoid cross-rank dtype mismatch#9551
Conversation
…ank dtype mismatch MeanMetric.compute() built the reduction tensor with an inferred dtype: `torch.tensor([self.state, self.count], device=...)`. When sequence-level `compute_acc()` accumulates `np.all(...)` (a numpy.bool_), it promotes self.state from python float to numpy.float64 on those ranks, so torch infers float64 there but float32 on ranks that stayed python float. The ranks then enter `dist.all_reduce()` with mismatched dtypes and the collective hangs (modelscope#9550). Pin the tensor to float64 so every rank reduces with the same dtype regardless of self.state's runtime type; float64 also keeps count exact for large totals.
There was a problem hiding this comment.
Code Review
This pull request updates the tensor creation in swift/metrics/utils.py to use torch.float64 during distributed all-reduce operations. Feedback suggests changing this to torch.float32 to prevent runtime crashes on hardware backends like NPUs (using HCCL) that do not support double precision for collective operations.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def compute(self): | ||
| if dist.is_initialized(): | ||
| tensor = torch.tensor([self.state, self.count], device=self.device) | ||
| tensor = torch.tensor([self.state, self.count], dtype=torch.float64, device=self.device) |
There was a problem hiding this comment.
Using torch.float64 for dist.all_reduce can cause runtime errors on certain accelerators and backends. Specifically, Huawei's HCCL (used for NPU execution, which this repository supports) does not support float64 (double) data types for collective operations.\n\nTo ensure compatibility across all supported hardware backends (including NPU/HCCL, GPU/NCCL, and CPU/Gloo), it is safer to use torch.float32. While float32 has a precision limit of 2^24 (approx. 16.7M) for exact integers, this is typically more than sufficient for evaluation sample/token counts, and any potential precision loss beyond that is negligible compared to a hard runtime crash on NPU.
| tensor = torch.tensor([self.state, self.count], dtype=torch.float64, device=self.device) | |
| tensor = torch.tensor([self.state, self.count], dtype=torch.float32, device=self.device) |
PR type
PR information
Fixes #9550.
MeanMetric.compute()(swift/metrics/utils.py) builds the all-reduce tensor with an inferred dtype:In sequence-level accuracy,
compute_acc()appendsnp.all(...)(anumpy.bool_). Accumulating that intoself.statepromotes it from a pythonfloattonumpy.float64on the ranks that take that path.torch.tensor([...])then infers float64 on those ranks but float32 on ranks whereself.statestayed a python float. The ranks enterdist.all_reduce()with mismatched dtypes and the collective hangs (matching the debug screenshots in #9550).Pinning the tensor to
float64makes every rank reduce with the same dtype regardless ofself.state's runtime type. float64 also keepscountexact for large totals (float32 loses integer precision above 2^24).Experiment results
Reproducing the hang needs a multi-rank run; minimal demo of the dtype inference (one rank's state promoted to
numpy.float64vianp.bool_, the other left a python float):