From ba24c56563042c01f5d486bd2f47d47a81647ce2 Mon Sep 17 00:00:00 2001 From: Cerebra Catalyst Team Date: Wed, 25 Sep 2024 17:43:54 -0700 Subject: [PATCH] Fix dequant logic with zero scaling factors PiperOrigin-RevId: 678913500 --- aqt/jax/v2/aqt_tensor.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/aqt/jax/v2/aqt_tensor.py b/aqt/jax/v2/aqt_tensor.py index ba82194a..ad7de7f2 100644 --- a/aqt/jax/v2/aqt_tensor.py +++ b/aqt/jax/v2/aqt_tensor.py @@ -137,15 +137,17 @@ def quant(self, x): for b in self.bias: qvalue += b + new_scale = [] for s in self.scale: # TODO(lew): We could store s_inv for faster activation quantization. - s_inv = jax.lax.reciprocal(s) - s_inv = jnp.where(jnp.isinf(s_inv), jnp.ones_like(s_inv), s_inv) + orig_inv = jax.lax.reciprocal(s) + s_inv = jnp.where(jnp.isinf(orig_inv), jnp.ones_like(orig_inv), orig_inv) + new_scale.append(jnp.where(jnp.isinf(orig_inv), jnp.ones_like(s), s)) qvalue *= s_inv # TODO(lew): We should apply numerics here, so that 'quant' function # Can be considered a part of API. - return self.replace(qvalue=qvalue) # pytype: disable=attribute-error + return self.replace(qvalue=qvalue, scale=new_scale) # pytype: disable=attribute-error def dequant(self) -> jnp.ndarray: """Dequantizes the QTensor."""