diff --git a/aqt/jax/v2/config.py b/aqt/jax/v2/config.py index c590b987..25fe487c 100644 --- a/aqt/jax/v2/config.py +++ b/aqt/jax/v2/config.py @@ -668,9 +668,9 @@ def config_v3( rng_type: str = 'jax.uniform', # 'custom-1' dlhs_local_aqt: None | LocalAqt = None, drhs_local_aqt: None | LocalAqt = None, - fwd_accumulator_dtype: ... = jnp.int32, - dlhs_accumulator_dtype: ... = jnp.int32, - drhs_accumulator_dtype: ... = None, + fwd_accumulator_dtype=jnp.int32, + dlhs_accumulator_dtype=jnp.int32, + drhs_accumulator_dtype=None, ) -> DotGeneral: """Fully Quantized Training.""" fwd = dot_general_raw_make(