diff --git a/aqt/jax/v2/aqt_tensor.py b/aqt/jax/v2/aqt_tensor.py index ba82194a..ad7de7f2 100644 --- a/aqt/jax/v2/aqt_tensor.py +++ b/aqt/jax/v2/aqt_tensor.py @@ -137,15 +137,17 @@ def quant(self, x): for b in self.bias: qvalue += b + new_scale = [] for s in self.scale: # TODO(lew): We could store s_inv for faster activation quantization. - s_inv = jax.lax.reciprocal(s) - s_inv = jnp.where(jnp.isinf(s_inv), jnp.ones_like(s_inv), s_inv) + orig_inv = jax.lax.reciprocal(s) + s_inv = jnp.where(jnp.isinf(orig_inv), jnp.ones_like(orig_inv), orig_inv) + new_scale.append(jnp.where(jnp.isinf(orig_inv), jnp.ones_like(s), s)) qvalue *= s_inv # TODO(lew): We should apply numerics here, so that 'quant' function # Can be considered a part of API. - return self.replace(qvalue=qvalue) # pytype: disable=attribute-error + return self.replace(qvalue=qvalue, scale=new_scale) # pytype: disable=attribute-error def dequant(self) -> jnp.ndarray: """Dequantizes the QTensor."""