Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aqt/jax/v2/flax/aqt_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,8 +540,8 @@ def dummy_tiled_dg(lhs_in, rhs_in):
dg_eqn = [eqn for eqn in tiled_dg_jaxpr.eqns if 'dot_general' in str(eqn)]
assert len(dg_eqn) == 1, 'Multiple dg calls are found in tiled dg jaxpr.'
lhs_invar, rhs_invar = dg_eqn[0].invars
tiled_lhs_shape = lhs_invar.aval.shape
tiled_rhs_shape = rhs_invar.aval.shape
tiled_lhs_shape = lhs_invar.aval.shape # type: ignore
tiled_rhs_shape = rhs_invar.aval.shape # type: ignore
tiled_dimension_numbers = dg_eqn[0].params['dimension_numbers']
# Use tiled input shapes and dimension numbers to create aqt_dg that
# will be injected to tiled_dot_general
Expand Down