// Patch summary (commentary only; placed ahead of the first `diff --git` header so the
// patch payload below is left byte-for-byte unchanged).
//
// Files touched: mlx/primitives.cpp, mlx/primitives.h, python/tests/test_compile.py.
//
// mlx/primitives.cpp
//   * Adds GatherQMM::output_shapes(inputs). It returns a single shape built from
//     lhs_indices.shape() followed by x.shape(-2) and w_outer, where
//       x = inputs[0], w = inputs[1],
//       lhs_indices = inputs[4] when mode_ == QuantizationMode::Affine, otherwise inputs[3],
//       w_outer = w.shape(-2) when transpose_ is set, otherwise w.shape(-1) * 32 / bits_.
//     NOTE(review): the index shift between modes presumably reflects an extra biases
//     input in affine mode; confirm against the gather_qmm op that builds this primitive.
//   * Adds GatherMM::output_shapes(inputs). It returns a single shape built from
//     lhs_indices.shape() followed by a.shape(-2) and b.shape(-1), where
//       a = inputs[0], b = inputs[1], lhs_indices = inputs[2].
//   * Neither function checks inputs.size() or the rank of its inputs before indexing.
//
// mlx/primitives.h
//   * Declares the output_shapes override on both the GatherMM and GatherQMM classes.
//
// python/tests/test_compile.py
//   * Adds test_shapeless_compile_gather_qmm and test_shapeless_compile_gather_mm. Each
//     wraps the op in mx.compile(fn, shapeless=True) and asserts that the compiled output
//     shape equals the uncompiled output shape for inputs of shape (num_experts, 4, K)
//     and (num_experts, 8, K). Only shapes are compared, not values.
//
// NOTE(review): this copy of the patch has its line breaks collapsed, and template
// argument lists appear to be missing (e.g. `std::vector` with no element type,
// `static_cast(other)`). Presumably this is extraction loss; restore from the original
// patch before attempting to apply it.
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index f3acec574b..69ad813a2c 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3776,6 +3776,18 @@ bool GatherQMM::is_equivalent(const Primitive& other) const { mode_ == qm_other.mode_ && transpose_ == qm_other.transpose_; } +std::vector GatherQMM::output_shapes(const std::vector& inputs) { + const auto& x = inputs[0]; + const auto& w = inputs[1]; + const auto& lhs_indices = + (mode_ == QuantizationMode::Affine) ? inputs[4] : inputs[3]; + int w_outer = transpose_ ? w.shape(-2) : w.shape(-1) * 32 / bits_; + auto out_shape = lhs_indices.shape(); + out_shape.push_back(x.shape(-2)); + out_shape.push_back(w_outer); + return {out_shape}; +} + std::pair, std::vector> RandomBits::vmap( const std::vector& inputs, const std::vector& axes) { @@ -5882,6 +5894,16 @@ bool GatherMM::is_equivalent(const Primitive& other) const { right_sorted_ == g_other.right_sorted_; } +std::vector GatherMM::output_shapes(const std::vector& inputs) { + const auto& a = inputs[0]; + const auto& b = inputs[1]; + const auto& lhs_indices = inputs[2]; + auto out_shape = lhs_indices.shape(); + out_shape.push_back(a.shape(-2)); + out_shape.push_back(b.shape(-1)); + return {out_shape}; +} + bool BlockMaskedMM::is_equivalent(const Primitive& other) const { const BlockMaskedMM& a_other = static_cast(other); return (block_size_ == a_other.block_size_); diff --git a/mlx/primitives.h b/mlx/primitives.h index 75fb978dce..5b8517c56d 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -543,6 +543,7 @@ class GatherMM : public UnaryPrimitive { DEFINE_NAME(GatherMM) bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; auto state() const { return std::make_pair(left_sorted_, right_sorted_); } @@ -1700,6 +1701,7 @@ class GatherQMM : public UnaryPrimitive { DEFINE_GRADS() DEFINE_NAME(GatherQMM) bool is_equivalent(const Primitive& other) const override; + std::vector 
output_shapes(const std::vector& inputs) override; auto state() const { return std::make_tuple( group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_); diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 20f1145223..d39ddf3f18 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -504,6 +504,45 @@ def ones_fun(x): self.assertEqual(compiled_zero_like(y).shape, y_shape) self.assertEqual(compiled_ones_like(y).shape, y_shape) + def test_shapeless_compile_gather_qmm(self): + K, N, num_experts = 64, 32, 4 + + w = mx.random.normal((num_experts, N, K)) + qw, s, b = mx.quantize(w) + mx.eval(qw, s, b) + + idx = mx.array([0, 1, 2, 3]) + x4 = mx.ones((num_experts, 4, K)) + x8 = mx.ones((num_experts, 8, K)) + + def fn(x): + return mx.gather_qmm( + x, qw, s, b, lhs_indices=idx, rhs_indices=idx, transpose=True + ) + + cfn = mx.compile(fn, shapeless=True) + + self.assertEqual(cfn(x4).shape, fn(x4).shape) + self.assertEqual(cfn(x8).shape, fn(x8).shape) + + def test_shapeless_compile_gather_mm(self): + K, N, num_experts = 64, 32, 4 + + idx = mx.array([0, 1, 2, 3]) + b = mx.random.normal((num_experts, K, N)) + mx.eval(b) + + x4 = mx.ones((num_experts, 4, K)) + x8 = mx.ones((num_experts, 8, K)) + + def fn(x): + return mx.gather_mm(x, b, lhs_indices=idx, rhs_indices=idx) + + cfn = mx.compile(fn, shapeless=True) + + self.assertEqual(cfn(x4).shape, fn(x4).shape) + self.assertEqual(cfn(x8).shape, fn(x8).shape) + def test_compile_with_constant(self): # Test float @partial(mx.compile)