[bugfix] pin MeanMetric all_reduce tensor to float64 to avoid cross-rank dtype mismatch#9551

Open
HaozheZhang6 wants to merge 1 commit into
modelscope:mainfrom
HaozheZhang6:fix/meanmetric-allreduce-dtype

@HaozheZhang6

PR type

  • Bug Fix

PR information

Fixes #9550.

MeanMetric.compute() (swift/metrics/utils.py) builds the all-reduce tensor with an inferred dtype:

tensor = torch.tensor([self.state, self.count], device=self.device)

In sequence-level accuracy, compute_acc() appends np.all(...) (a numpy.bool_). Accumulating that value into self.state promotes self.state from a Python float to numpy.float64 on the ranks that take that path. torch.tensor([...]) then infers float64 on those ranks but float32 on ranks where self.state stayed a Python float. The ranks enter dist.all_reduce() with mismatched dtypes and the collective hangs (matching the debug screenshots in #9550).
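The promotion step can be checked without a distributed run. The following is a minimal sketch that needs only NumPy; `accumulate`, `state_a`, and `state_b` are illustrative names and not part of swift. It shows only how the accumulator's Python type diverges between the two code paths, not the later torch dtype inference.

```python
import numpy as np


def accumulate(state, value):
    """Add one accuracy sample to the running sum, as a metric update might."""
    return state + value


# A rank that only adds built-in Python numbers keeps a built-in float.
state_a = accumulate(0.0, 1.0)

# A rank on the sequence-level path adds the result of np.all(...),
# which is a numpy.bool_ rather than a built-in bool.
seq_correct = np.all(np.array([3, 7, 9]) == np.array([3, 7, 9]))
state_b = accumulate(0.0, seq_correct)

print(type(state_a))  # built-in float
print(type(state_b))  # numpy.float64
```

Because the two accumulators hold the same value but different types, any later step that infers a dtype from the Python object can disagree across ranks.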

Pinning the tensor to float64 makes every rank reduce with the same dtype regardless of self.state's runtime type. float64 also keeps count exact for large totals (float32 loses integer precision above 2^24).
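The precision claim can be illustrated with the standard library alone. This sketch assumes nothing about swift; `to_float32` is a hypothetical helper that rounds a number to single precision by round-tripping it through `struct`.

```python
import struct


def to_float32(value):
    """Round a Python number to the nearest IEEE 754 single-precision value."""
    return struct.unpack("f", struct.pack("f", value))[0]


limit = 2 ** 24  # 16_777_216

print(to_float32(limit) == limit)          # True: still exact in float32
print(to_float32(limit + 1) == limit + 1)  # False: 16_777_217 is not representable
print(float(limit + 1) == limit + 1)       # True: float64 is exact up to 2**53
```

A count stored in float32 therefore stops incrementing by one past 16,777,216, while a float64 count stays exact far beyond any realistic sample total.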

Experiment results

Reproducing the hang needs a multi-rank run. The output below is from a minimal single-process demo of the dtype inference, with one rank's state promoted to numpy.float64 via np.bool_ and the other left as a Python float:

rank-A state type: float | rank-B state type: float64
CURRENT  -> rankA=torch.float32  rankB=torch.float64  all_reduce MATCH=False
FIX      -> rankA=torch.float64  rankB=torch.float64  all_reduce MATCH=True

…ank dtype mismatch

MeanMetric.compute() built the reduction tensor with an inferred dtype:
`torch.tensor([self.state, self.count], device=...)`. When sequence-level
`compute_acc()` accumulates `np.all(...)` (a numpy.bool_), it promotes
self.state from python float to numpy.float64 on those ranks, so torch infers
float64 there but float32 on ranks that stayed python float. The ranks then
enter `dist.all_reduce()` with mismatched dtypes and the collective hangs
(modelscope#9550).

Pin the tensor to float64 so every rank reduces with the same dtype regardless
of self.state's runtime type; float64 also keeps count exact for large totals.

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request updates the tensor creation in swift/metrics/utils.py to use torch.float64 during distributed all-reduce operations. Feedback suggests changing this to torch.float32 to prevent runtime crashes on hardware backends like NPUs (using HCCL) that do not support double precision for collective operations.

Comment thread swift/metrics/utils.py
 def compute(self):
     if dist.is_initialized():
-        tensor = torch.tensor([self.state, self.count], device=self.device)
+        tensor = torch.tensor([self.state, self.count], dtype=torch.float64, device=self.device)

high

Using torch.float64 for dist.all_reduce can cause runtime errors on certain accelerators and backends. Specifically, Huawei's HCCL (used for NPU execution, which this repository supports) does not support float64 (double) data types for collective operations.

To ensure compatibility across all supported hardware backends (including NPU/HCCL, GPU/NCCL, and CPU/Gloo), it is safer to use torch.float32. While float32 has a precision limit of 2^24 (approx. 16.7M) for exact integers, this is typically more than sufficient for evaluation sample/token counts, and any potential precision loss beyond that is negligible compared to a hard runtime crash on NPU.

Suggested change
-        tensor = torch.tensor([self.state, self.count], dtype=torch.float64, device=self.device)
+        tensor = torch.tensor([self.state, self.count], dtype=torch.float32, device=self.device)

Development

Successfully merging this pull request may close these issues.

Potential dtype inconsistency in distributed metric reduction
