Potential dtype inconsistency in distributed metric reduction #9550

@xiaofu2730

Checklist

  • I have searched existing issues, and this is a new bug report.

Bug Description

Description

While training a multimodal (text + audio) model with distributed training, I encountered a hang during metric synchronization.

After adding debug logs, I found that different ranks were creating reduction tensors with different dtypes immediately before dist.all_reduce(), which caused the collective operation to fail and block.

Debug screenshots are attached below.

[Image: debug log screenshot]

Investigation

After tracing the metric pipeline, the issue appears to involve the interaction between:

  • compute_acc() in metrics/acc.py
  • MeanMetric.compute() in metrics/utils.py

In sequence-level accuracy mode, compute_acc() appends the result of np.all(...) directly:

acc_list.append(np.all(preds[i, m] == labels[i, m]))

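np.all() returns numpy.bool_ rather than a Python bool. When these values are summed and added to a Python float, the result becomes numpy.float64 (see the minimal reproducer below), so NumPy scalar types can propagate through metric accumulation and change the runtime type of self.state inside MeanMetric.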
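
One way to stop the propagation at its source is to convert the result to a built-in bool before appending it. The snippet below is only a sketch: sequence_level_acc is a hypothetical stand-in for the sequence-level branch of compute_acc(), and the assumption that m is a per-row boolean mask is mine, not taken from the library code.

```python
import numpy as np


def sequence_level_acc(preds, labels, masks):
    """Hypothetical stand-in for the sequence-level branch of compute_acc()."""
    acc_list = []
    for i in range(len(preds)):
        m = masks[i]  # assumed: boolean mask selecting the positions to compare
        # bool(...) turns numpy.bool_ into a built-in bool at the source.
        acc_list.append(bool(np.all(preds[i, m] == labels[i, m])))
    return acc_list


preds = np.array([[1, 2, 3], [4, 5, 6]])
labels = np.array([[1, 2, 3], [4, 0, 6]])
masks = np.array([[True, True, True], [True, True, False]])

acc_list = sequence_level_acc(preds, labels, masks)

state = 0.0
state += sum(acc_list)  # stays a Python float because acc_list holds Python bools
print(acc_list, type(state))
```

With this conversion the accumulated state no longer depends on whether a given rank went through the NumPy code path.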

Later, MeanMetric.compute() constructs the reduction tensor via:

tensor = torch.tensor([self.state, self.count], device=self.device)

Because no dtype argument is passed, torch.tensor() infers the dtype from its inputs, so the resulting tensor dtype depends on the runtime type of self.state.
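
A defensive alternative is to make the dtype explicit when the reduction tensor is built. The following is a minimal sketch, not the actual MeanMetric code: build_reduction_tensor is a hypothetical helper, and torch.float64 is an arbitrary choice here. Any dtype should work as long as every rank uses the same one.

```python
import numpy as np
import torch


def build_reduction_tensor(state, count, device=None):
    """Hypothetical helper: build the [state, count] tensor with a fixed dtype."""
    # float(...) normalizes NumPy scalars to Python floats, and the explicit
    # dtype removes any dependence on dtype inference.
    return torch.tensor(
        [float(state), float(count)],
        dtype=torch.float64,  # assumption: float64 chosen for illustration
        device=device,
    )


from_python = build_reduction_tensor(1.0, 2)
from_numpy = build_reduction_tensor(np.float64(1.0), 2)
print(from_python.dtype, from_numpy.dtype)
```

Fixing the dtype at this point protects the collective even if some other code path reintroduces NumPy scalars into the state later.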

From my debugging logs, some ranks entered the collective operation with:

torch.float32

while others entered with:

torch.float64

which is consistent with the observed distributed synchronization failure.

Expected Behavior

Metric reduction should use a consistent tensor dtype across all ranks, regardless of whether intermediate metric values originate from Python scalars or NumPy scalars.

Additional Information

Attached screenshots show:

  1. The dtype mismatch observed across ranks before dist.all_reduce().
  2. The intermediate metric values involved in the dtype divergence.

I have investigated the issue locally and prepared a potential fix together with a small regression test. If the analysis above looks reasonable, I'd be happy to open a PR for discussion.

How to Reproduce

The issue was observed during distributed training of a multimodal (text + audio) model using multiple GPUs.

  1. Launch distributed training/evaluation with sequence-level accuracy metrics enabled.
  2. Run evaluation on a multimodal dataset including text and audio.
  3. Add debug logs before MeanMetric.compute() performs dist.all_reduce().
  4. Observe that some ranks construct reduction tensors with torch.float32 while others construct torch.float64.
  5. The subsequent dist.all_reduce() may hang or fail due to the dtype mismatch across ranks.
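
The debug check in steps 3 and 4 can be made explicit with a small helper that compares the dtype reported by each rank. This is a sketch only: find_dtype_outliers is a hypothetical name, and the gathered list is hard-coded here. In a real run, one option would be to collect str(tensor.dtype) from every rank with torch.distributed.all_gather_object() before calling dist.all_reduce().

```python
from collections import Counter


def find_dtype_outliers(dtypes_by_rank):
    """Return {rank: dtype} for ranks whose dtype differs from the most common one."""
    if not dtypes_by_rank:
        return {}
    majority, _ = Counter(dtypes_by_rank).most_common(1)[0]
    return {
        rank: dtype
        for rank, dtype in enumerate(dtypes_by_rank)
        if dtype != majority
    }


# Illustrative values only; index in the list is the rank.
gathered = ["torch.float32", "torch.float64", "torch.float32", "torch.float32"]
outliers = find_dtype_outliers(gathered)
if outliers:
    print(f"dtype mismatch before all_reduce: {outliers}")
```

Raising an error when the returned dict is non-empty would turn a silent hang into an immediate, readable failure.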

Minimal Reproducer

import numpy as np
import torch

state = 0.

acc_list = [
    np.bool_(True),
    np.bool_(False),
]

state += sum(acc_list)

print(type(state))
# <class 'numpy.float64'>

tensor = torch.tensor([state, 1])
print(tensor.dtype)
# torch.float64

In contrast, using Python bool values produces a different inferred dtype:

state = 0.

acc_list = [
    True,
    False,
]

state += sum(acc_list)

print(type(state))
# <class 'float'>

tensor = torch.tensor([state, 1])
print(tensor.dtype)
# torch.float32
