mx.fast.scaled_dot_product_attention produces incorrect attention scores (MLX 0.31.2) #3585

@coldbugs

Environment

  • MLX 0.31.2
  • macOS, Apple M4
  • Python 3.12

Summary

mx.fast.scaled_dot_product_attention returns results that differ from a manual attention implementation, and the model fails completely (0% accuracy). The manual fallback produces correct results (99.5% accuracy) and matches a Candle/Rust implementation that uses identical weights.

Model config

  • seq_len=82 (81 Sudoku tokens + 1 puzzle embedding), hidden=512, heads=8, head_dim=64
  • 4 H-layers + 4 L-layers, H_cycles=2, L_cycles=2
  • RoPE (base=10000) with pre-computed cos/sin tables of shape [82, 64]
  • No GQA (num_heads == num_kv_heads)
  • Weights: sapientinc/HRM-checkpoint-sudoku-extreme
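
For reference, cos/sin tables with the shape listed above can be built as in the NumPy sketch below. It assumes the common "rotate-half" RoPE convention, where the 32 per-pair angles are duplicated across both halves of the head dimension. The checkpoint's own code may lay the table out differently (for example interleaved), and rope_tables is a hypothetical helper name.

```python
import numpy as np

def rope_tables(seq_len=82, head_dim=64, base=10000.0):
    """Build RoPE cos/sin tables of shape [seq_len, head_dim].

    Assumption: rotate-half layout, i.e. the head_dim/2 angles are
    concatenated with themselves along the last axis.
    """
    # One inverse frequency per pair of channels: shape [head_dim / 2]
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2, dtype=np.float64) / head_dim))
    # Angle for every (position, frequency) pair: shape [seq_len, head_dim / 2]
    angles = np.outer(np.arange(seq_len, dtype=np.float64), inv_freq)
    # Duplicate to cover the full head dimension: shape [seq_len, head_dim]
    emb = np.concatenate([angles, angles], axis=-1)
    return np.cos(emb).astype(np.float32), np.sin(emb).astype(np.float32)

cos, sin = rope_tables()
print(cos.shape, sin.shape)
```

Position 0 has angle zero at every frequency, so the first row of cos is all ones and the first row of sin is all zeros. That makes a quick sanity check when comparing tables between implementations.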

Behavior

Fused kernel (mx.fast.scaled_dot_product_attention): model produces garbage output, 0% accuracy on the Sudoku given cells. All 16 ACT steps run (q-values never converge).

Manual attention: model produces correct Sudoku solutions, 99.5% accuracy (214/215 given cells across 10 puzzles), halts in 2-7 steps.
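
The "given cells" figure can be computed as in the sketch below. It assumes blank cells are encoded as 0 and that accuracy is the fraction of given cells the model reproduces unchanged. given_cell_accuracy is a hypothetical helper, and the arrays are toy data rather than puzzles from the checkpoint's dataset.

```python
import numpy as np

def given_cell_accuracy(puzzles, predictions, blank=0):
    """Return (correct, total, ratio) over cells that were given in the puzzle."""
    puzzles = np.asarray(puzzles)
    predictions = np.asarray(predictions)
    given = puzzles != blank                          # mask of pre-filled cells
    correct = int(((predictions == puzzles) & given).sum())
    total = int(given.sum())
    return correct, total, correct / total

# Toy example: two given cells (5 and 3), the model reproduces only the first.
puzzle = np.array([[5, 0, 3, 0]])
pred = np.array([[5, 1, 4, 2]])
print(given_cell_accuracy(puzzle, pred))

# The figure reported above: 214 of 215 given cells.
print(round(100 * 214 / 215, 1))
```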

Reproduction

The fix was replacing the fused call with the manual fallback that already exists in the code:

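import mlx.core as mx

# Broken: fused kernel, called on query/key/value exactly as they arrive
return mx.fast.scaled_dot_product_attention(query, key, value, scale=scale)

# Fixed: manual attention
# Swap axes 1 and 2 of each input before computing scores
q = mx.transpose(query, (0, 2, 1, 3))
k = mx.transpose(key, (0, 2, 1, 3))
v = mx.transpose(value, (0, 2, 1, 3))
# Scaled scores over the last two axes, then softmax across keys
scores = (q @ mx.transpose(k, (0, 1, 3, 2))) * scale
attn_weights = mx.softmax(scores, axis=-1)
out = attn_weights @ v
# Swap axes 1 and 2 back so the output layout matches the input layout
return mx.transpose(out, (0, 2, 1, 3))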
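
One thing worth ruling out before blaming the kernel is an input layout mismatch. The MLX documentation gives the expected shape of q, k and v for mx.fast.scaled_dot_product_attention as [B, N_heads, T, D]. The manual path swaps axes 1 and 2 before computing scores, while the fused call receives the tensors untransposed. If query/key/value are laid out as [B, L, H, D] here (an assumption inferred from those transposes, not confirmed in this report), the two paths would attend over different axes. The NumPy sketch below illustrates the effect. sdpa_reference is a hypothetical stand-in that attends over the second-to-last axis; it is not MLX's kernel.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def sdpa_reference(q, k, v, scale):
    """Attention over axis -2, i.e. inputs are read as [B, heads, seq, head_dim]."""
    scores = (q @ np.swapaxes(k, -1, -2)) * scale
    return softmax(scores, axis=-1) @ v

rng = np.random.default_rng(0)
B, L, H, D = 1, 82, 8, 64            # sizes from the model config above
scale = D ** -0.5
# Assumed input layout: [B, L, H, D]
q, k, v = (rng.standard_normal((B, L, H, D)).astype(np.float32) for _ in range(3))

# Manual path from the report: transpose to [B, H, L, D], attend, transpose back.
qt, kt, vt = (np.transpose(x, (0, 2, 1, 3)) for x in (q, k, v))
manual = np.transpose(sdpa_reference(qt, kt, vt, scale), (0, 2, 1, 3))

# Same reference on untransposed inputs: axis 1 (length 82) is read as "heads"
# and axis 2 (length 8) as "sequence", so attention mixes heads, not positions.
untransposed = sdpa_reference(q, k, v, scale)

print(manual.shape, untransposed.shape)          # identical shapes, no error raised
print(np.abs(manual - untransposed).max())       # but the values differ
```

Both calls return an array of the same shape, so a layout mismatch of this kind would not raise an error. If that is what happens here, passing the transposed q, k, v to the fused call and transposing its output back should agree with the manual path up to floating-point tolerance.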

Full investigation log with side-by-side comparison against Candle/Rust available on request.

Workaround

Use manual attention unconditionally. With this model configuration the fused kernel call produces wrong attention weights, which cascades into zero accuracy.
