diff --git a/swift/metrics/utils.py b/swift/metrics/utils.py index e2e0f53c32..4893f515b0 100644 --- a/swift/metrics/utils.py +++ b/swift/metrics/utils.py @@ -101,7 +101,7 @@ def update(self, state: torch.Tensor): def compute(self): if dist.is_initialized(): - tensor = torch.tensor([self.state, self.count], device=self.device) + tensor = torch.tensor([self.state, self.count], dtype=torch.float32, device=self.device) dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=self.group) self.state, self.count = tensor[0].item(), int(tensor[1].item()) if self.count == 0: