diff --git a/aqt/jax/v2/flax/aqt_flax.py b/aqt/jax/v2/flax/aqt_flax.py index 6a0989fd..fe58fe3b 100644 --- a/aqt/jax/v2/flax/aqt_flax.py +++ b/aqt/jax/v2/flax/aqt_flax.py @@ -540,8 +540,8 @@ def dummy_tiled_dg(lhs_in, rhs_in): dg_eqn = [eqn for eqn in tiled_dg_jaxpr.eqns if 'dot_general' in str(eqn)] assert len(dg_eqn) == 1, 'Multiple dg calls are found in tiled dg jaxpr.' lhs_invar, rhs_invar = dg_eqn[0].invars - tiled_lhs_shape = lhs_invar.aval.shape - tiled_rhs_shape = rhs_invar.aval.shape + tiled_lhs_shape = lhs_invar.aval.shape # type: ignore + tiled_rhs_shape = rhs_invar.aval.shape # type: ignore tiled_dimension_numbers = dg_eqn[0].params['dimension_numbers'] # Use tiled input shapes and dimension numbers to create aqt_dg that # will be injected to tiled_dot_general