From 5437b08753bff548a75be55e15204da62825928c Mon Sep 17 00:00:00 2001 From: Haozhe Zhang Date: Fri, 12 Jun 2026 09:52:24 -0700 Subject: [PATCH 1/2] [bugfix] pin MeanMetric all_reduce tensor to float64 to avoid cross-rank dtype mismatch MeanMetric.compute() built the reduction tensor with an inferred dtype: `torch.tensor([self.state, self.count], device=...)`. When sequence-level `compute_acc()` accumulates `np.all(...)` (a numpy.bool_), it promotes self.state from python float to numpy.float64 on those ranks, so torch infers float64 there but float32 on ranks that stayed python float. The ranks then enter `dist.all_reduce()` with mismatched dtypes and the collective hangs (#9550). Pin the tensor to float64 so every rank reduces with the same dtype regardless of self.state's runtime type; float64 also keeps count exact for large totals. --- swift/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/metrics/utils.py b/swift/metrics/utils.py index e2e0f53c32..f39b6d1380 100644 --- a/swift/metrics/utils.py +++ b/swift/metrics/utils.py @@ -101,7 +101,7 @@ def update(self, state: torch.Tensor): def compute(self): if dist.is_initialized(): - tensor = torch.tensor([self.state, self.count], device=self.device) + tensor = torch.tensor([self.state, self.count], dtype=torch.float64, device=self.device) dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=self.group) self.state, self.count = tensor[0].item(), int(tensor[1].item()) if self.count == 0: From 9823a8f8387a858ba7cc4d40187ed5fc8bad729d Mon Sep 17 00:00:00 2001 From: Hz_Zhang <47402297+HaozheZhang6@users.noreply.github.com> Date: Sat, 13 Jun 2026 01:24:02 -0700 Subject: [PATCH 2/2] use float32 for all_reduce tensor (some backends e.g. 
NPU/HCCL don't support float64) Trade-off: float32 represents integers exactly only up to 2**24 (16,777,216), so unlike the float64 version the reduced count is no longer guaranteed to be exact for very large totals; pinning a single explicit dtype still prevents the cross-rank dtype mismatch. --- swift/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/metrics/utils.py b/swift/metrics/utils.py index f39b6d1380..4893f515b0 100644 --- a/swift/metrics/utils.py +++ b/swift/metrics/utils.py @@ -101,7 +101,7 @@ def update(self, state: torch.Tensor): def compute(self): if dist.is_initialized(): - tensor = torch.tensor([self.state, self.count], dtype=torch.float64, device=self.device) + tensor = torch.tensor([self.state, self.count], dtype=torch.float32, device=self.device) dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=self.group) self.state, self.count = tensor[0].item(), int(tensor[1].item()) if self.count == 0: