diff --git a/aqt/jax/v2/aqt_quantizer.py b/aqt/jax/v2/aqt_quantizer.py index 63433a59..1651a394 100644 --- a/aqt/jax/v2/aqt_quantizer.py +++ b/aqt/jax/v2/aqt_quantizer.py @@ -54,16 +54,18 @@ def quant( ) -> tuple[aqt_tensor.QTensor, aqt_tensor.GradientFn]: """The core quantizing function.""" qt = self.calibrate(x, calibration_axes=calibration_axes) - qt, quant_grad = self.calculate_qvalue(x, qt) - return qt, quant_grad + return self.calculate_qvalue(x, qt) def calibrate(self, x, *, calibration_axes) -> aqt_tensor.QTensor: """Create incomplete QTensor with only quantization parameters.""" if isinstance(self.numerics, no_numerics.NoNumerics): - qt = aqt_tensor.QTensor( - qvalue=x, scale=[], scale_t=None, dequant_dtype=x.dtype + return aqt_tensor.QTensor( + qvalue=x, + scale=[], + scale_t=None, + dequant_dtype=x.dtype, + numerics=self.numerics, ) - return qt dequant_dtype = x.dtype # TODO(lew): We should cast earlier. xhs_q should be in cfg.xhs.dtype @@ -94,6 +96,7 @@ def calibrate(self, x, *, calibration_axes) -> aqt_tensor.QTensor: scale=[scale], scale_t=None, dequant_dtype=dequant_dtype, + numerics=self.numerics, ) return qt @@ -103,20 +106,7 @@ def calculate_qvalue( qt: aqt_tensor.QTensor ) -> tuple[aqt_tensor.QTensor, aqt_tensor.GradientFn]: """Uses the quantization parameters in qt to quantize x.""" - if isinstance(self.numerics, no_numerics.NoNumerics): - return qt, None - - # TODO: b/333984742 - make numeric as a member of QTensor, and put - # numerics-related logics into the QTensor. - qt = qt.quant(x) - - # TODO(lew): A logical thing would be if this call was part of - # QTensor.quant. - x_q, res = self.numerics.vjp_fwd(qt.qvalue, self.context) - quant_grad = jax.tree_util.Partial(self.numerics.vjp_bwd, res) - - qt = qt.replace(qvalue=x_q) - return qt, quant_grad + return qt.quant(x, self.context) def quantizer_make( diff --git a/aqt/jax/v2/aqt_tensor.py b/aqt/jax/v2/aqt_tensor.py index b3a30808..416fc2a2 100644 --- a/aqt/jax/v2/aqt_tensor.py +++ b/aqt/jax/v2/aqt_tensor.py @@ -25,6 +25,9 @@ import typing from typing import Any, Callable, Optional, Sequence, TypeAlias from aqt.jax.v2 import utils +from aqt.jax.v2.numerics import int_numerics +from aqt.jax.v2.numerics import no_numerics +from aqt.jax.v2.numerics import numerics import flax.cursor import flax.struct import jax @@ -32,6 +35,7 @@ import jax.typing as jax_typing from typing_extensions import Self # for python version < 3.11 +AbstractAqtNumerics = numerics.AqtNumerics GradientFn = Callable[..., Any] | None # None when there is no numerics _MSG_NO_QVALUE = ( 'QTensor does not have qvalue, but it is asked to access the qvalue.' @@ -75,6 +79,9 @@ class QTensor: pytree_node=False, default=None ) + # Numerics of the QTensor. + numerics: AbstractAqtNumerics = utils.static_field(default=None) + @property def dtype(self) -> jnp.dtype | None: return self.dequant_dtype @@ -89,8 +96,12 @@ def without_qvalue(self) -> Self: def astype(self, dtype: jnp.dtype) -> Self: return self.replace(dequant_dtype=dtype) # pytype: disable=attribute-error - def quant(self, x): - """Quantizes the QTensor.""" + def quant(self, x, context: utils.Context) -> tuple[Self, GradientFn]: + """Uses the quantization parameters in qt to quantize x.""" + assert self.numerics is not None, 'Missing numerics used for quantization.' + if isinstance(self.numerics, no_numerics.NoNumerics): + return self, None + assert not self.is_full(), 'Already quantized QTensor.' assert self.scale is not None, 'Missing scales to be used for quantization.' @@ -101,9 +112,10 @@ def quant(self, x): s_inv = jnp.where(jnp.isinf(s_inv), jnp.ones_like(s_inv), s_inv) qvalue = qvalue * s_inv - # TODO(lew): We should apply numerics here, so that 'quant' function - # Can be considered a part of API. - return self.replace(qvalue=qvalue) # pytype: disable=attribute-error + x_q, res = self.numerics.vjp_fwd(qvalue, context) + quant_grad = jax.tree_util.Partial(self.numerics.vjp_bwd, res) + + return self.replace(qvalue=x_q), quant_grad # pytype: disable=attribute-error def dequant(self) -> jnp.ndarray: """Dequantizes the QTensor.""" @@ -152,26 +164,14 @@ def __len__(self) -> int: def zeros( - shape: Sequence[int], - *, - container_dtype: jnp.dtype, - dequant_dtype: jnp.dtype = jnp.bfloat16, -) -> QTensor: - return QTensor( - qvalue=jnp.zeros(shape, dtype=container_dtype), - scale=[], - scale_t=None, - dequant_dtype=dequant_dtype, - ) - - -def zeros_with_scale( shape: Sequence[int], calibration_axis: Sequence[utils.AxisIdx], *, container_dtype: jnp.dtype, scale_dtype: jnp.dtype | None = None, dequant_dtype: jnp.dtype = jnp.bfloat16, + n_bits: int | None = None, + preserve_max_val: bool = False, ) -> QTensor: """Initializes a QTensor with empty qvalue along with empty scale value.""" scale_shape = list(shape) @@ -186,6 +186,7 @@ def zeros_with_scale( scale=[jnp.ones(scale_shape, dtype=scale_dtype)], scale_t=None, dequant_dtype=dequant_dtype, + numerics=_get_numerics(n_bits, preserve_max_val), ) @@ -193,6 +194,8 @@ def partition_spec( partitions: Sequence[Any], calibration_axis: Sequence[utils.AxisIdx], dtype: jnp.dtype, + n_bits: int | None, + preserve_max_val: bool = False, ) -> QTensor: """Returns a QTensor filled with partition specs.""" scale_partitions = list(partitions) @@ -203,6 +206,26 @@ def partition_spec( scale=[jax.sharding.PartitionSpec(*scale_partitions)], scale_t=None, dequant_dtype=dtype, + numerics=_get_numerics(n_bits, preserve_max_val), + ) + + +def _get_numerics( + n_bits: int | None, preserve_max_val: bool = False +) -> numerics.AqtNumerics: + if n_bits is None: + return no_numerics.NoNumerics() + pz = False if n_bits == 1 else True + dtype = utils.infer_dtype_from_bits(n_bits) if pz else None + return int_numerics.IntNumerics( + bits=n_bits, + preserve_zero=pz, + preserve_max_val=preserve_max_val, + clip=True, + round=True, + noise_fn=None, + clip_gradient=False, # This can be disabled when using abs-max scaling. + dtype=dtype, ) @@ -242,6 +265,7 @@ def get_sliced_scales(scale): scale=[get_sliced_scales(s) for s in operand.scale], scale_t=None, dequant_dtype=operand.dequant_dtype, + numerics=operand.numerics, ) @@ -290,6 +314,7 @@ def dynamic_update_slice( scale=scales, scale_t=None, dequant_dtype=operand.dequant_dtype, + numerics=operand.numerics, ) @@ -306,4 +331,5 @@ def update_frame(operand: QTensor, frame: int, update: QTensor) -> QTensor: ], scale_t=None, dequant_dtype=operand.dequant_dtype, + numerics=operand.numerics, ) diff --git a/aqt/jax/v2/aqt_tensor_test.py b/aqt/jax/v2/aqt_tensor_test.py index 820cda6d..a62937d9 100644 --- a/aqt/jax/v2/aqt_tensor_test.py +++ b/aqt/jax/v2/aqt_tensor_test.py @@ -89,7 +89,10 @@ def test_dynamic_update(self): def test_dtype(self): qt = aqt_tensor.zeros( - shape=(1,), container_dtype=jnp.int8, dequant_dtype=jnp.float32 + shape=(1,), + calibration_axis=(), + container_dtype=jnp.int8, + dequant_dtype=jnp.float32, ) self.assertEqual(qt.dtype, jnp.float32) self.assertEqual(qt.dequant_dtype, jnp.float32) diff --git a/aqt/jax/v2/examples/flax_e2e_model_test.py b/aqt/jax/v2/examples/flax_e2e_model_test.py index e3e6128a..dede6f5d 100644 --- a/aqt/jax/v2/examples/flax_e2e_model_test.py +++ b/aqt/jax/v2/examples/flax_e2e_model_test.py @@ -24,6 +24,7 @@ from aqt.jax.v2 import utils from aqt.jax.v2.examples import flax_e2e_model from aqt.jax.v2.flax import aqt_flax_calibration +from aqt.jax.v2.numerics import int_numerics import jax import jax.numpy as jnp import numpy as np @@ -38,6 +39,19 @@ def _dummy_dataset(ds_size, image_rng, label_rng): } +def _get_numerics(bits): + return int_numerics.IntNumerics( + bits=bits, + preserve_zero=True, + preserve_max_val=False, + clip=True, + round=True, + noise_fn=None, + clip_gradient=False, + dtype=utils.infer_dtype_from_bits(bits), + ) + + class MnistTest(parameterized.TestCase): # Unable to use config_v4() in parameters since it needs jax.device info. @@ -129,6 +143,7 @@ def forward(model, apply_fn): dtype = jnp.dtype expected_dtype = jnp.int4 if bits == 4 else jnp.int8 + expected_numerics = _get_numerics(bits) expected_aqt_pytree = { "aqt": { "AqtEinsum_0": { @@ -138,7 +153,8 @@ def forward(model, apply_fn): qvalue=(expected_dtype, (1, 2, 5, 1, 10)), scale=[(dtype("float32"), (1, 2, 1, 1, 10))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) } } @@ -158,7 +174,8 @@ def forward(model, apply_fn): # After tiling the scale shape is (1, 2, 1, 1, 256), # then transposed to (2, 1, 1, 1, 256). scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) } } @@ -178,11 +195,12 @@ def forward(model, apply_fn): # After tiling the scale shape is (1, 2, 1, 1, 10), # then transposed to (2, 1, 1, 1, 10). scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) } } - } + }, }, "batch_stats": { "BatchNorm_0": { @@ -394,6 +412,7 @@ def assert_array_not_equal(x, y): ) dtype = jnp.dtype expected_dtype = dtype("int4") if bits == 4 else dtype("int8") + expected_numerics = _get_numerics(bits) expected_aqt_pytree = { "AqtEinsum_0": { "AqtDotGeneral_0": { @@ -402,7 +421,8 @@ def assert_array_not_equal(x, y): qvalue=(expected_dtype, (1, 2, 5, 1, 10)), scale=[(dtype("float32"), (1, 2, 1, 1, 10))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) }, "qrhs": { @@ -410,9 +430,10 @@ def assert_array_not_equal(x, y): qvalue=None, scale=[(dtype("float32"), (1, 1, 1, 1, 1))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) - } + }, } }, "Dense_0": { @@ -422,7 +443,8 @@ def assert_array_not_equal(x, y): qvalue=None, scale=[(dtype("float32"), (1, 1, 1, 1, 1))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) }, "qrhs": { @@ -430,9 +452,10 @@ def assert_array_not_equal(x, y): qvalue=(expected_dtype, (1, 2, 1568, 1, 256)), scale=[(dtype("float32"), (1, 2, 1, 1, 256))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) - } + }, } }, "Dense_1": { @@ -442,7 +465,8 @@ def assert_array_not_equal(x, y): qvalue=None, scale=[(dtype("float32"), (1, 1, 1, 1, 1))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) }, "qrhs": { @@ -450,11 +474,12 @@ def assert_array_not_equal(x, y): qvalue=(expected_dtype, (1, 2, 128, 1, 10)), scale=[(dtype("float32"), (1, 2, 1, 1, 10))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) - } + }, } - } + }, } serving_pytree = jax.tree_util.tree_map( @@ -605,6 +630,7 @@ def forward(model, apply_fn): ) dtype = jnp.dtype expected_dtype = dtype("int4") if bits == 4 else dtype("int8") + expected_numerics = _get_numerics(bits) expected_aqt_pytree = { "AqtEinsum_0": { "AqtDotGeneral_0": { @@ -613,7 +639,8 @@ def forward(model, apply_fn): qvalue=(expected_dtype, (1, 2, 5, 1, 10)), scale=[(dtype("float32"), (1, 2, 1, 1, 10))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) }, "qrhs": { @@ -621,9 +648,10 @@ def forward(model, apply_fn): qvalue=None, scale=[(dtype("float32"), (1, 1, 1, 1, 1))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) - } + }, } }, "Dense_0": { @@ -633,7 +661,8 @@ def forward(model, apply_fn): qvalue=None, scale=[(dtype("float32"), (1, 1, 1, 1, 1))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) }, "qrhs": { @@ -641,9 +670,10 @@ def forward(model, apply_fn): qvalue=(expected_dtype, (1, 2, 1568, 1, 256)), scale=[(dtype("float32"), (1, 2, 1, 1, 256))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) - } + }, } }, "Dense_1": { @@ -653,7 +683,8 @@ def forward(model, apply_fn): qvalue=None, scale=[(dtype("float32"), (1, 1, 1, 1, 1))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) }, "qrhs": { @@ -661,11 +692,12 @@ def forward(model, apply_fn): qvalue=(expected_dtype, (1, 2, 128, 1, 10)), scale=[(dtype("float32"), (1, 2, 1, 1, 10))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) - } + }, } - } + }, } serving_pytree = jax.tree_util.tree_map( diff --git a/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model_test.py b/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model_test.py index bf112d23..e0c34627 100644 --- a/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model_test.py +++ b/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model_test.py @@ -22,6 +22,7 @@ from aqt.jax.v2 import config from aqt.jax.v2 import utils from aqt.jax.v2.extensions.gptq.examples import gptq_flax_e2e_model +from aqt.jax.v2.numerics import int_numerics import jax import jax.numpy as jnp @@ -35,6 +36,19 @@ def _dummy_dataset(ds_size, image_rng, label_rng): } +def _get_numerics(bits): + return int_numerics.IntNumerics( + bits=bits, + preserve_zero=True, + preserve_max_val=False, + clip=True, + round=True, + noise_fn=None, + clip_gradient=False, + dtype=utils.infer_dtype_from_bits(bits), + ) + + class GptqTest(parameterized.TestCase): def test_gptq(self): @@ -121,6 +135,7 @@ def test_gptq(self): serve_fn, model_serving = gptq_flax_e2e_model.serving_conversion(state) dtype = jnp.dtype expected_dtype = dtype("int8") + expected_numerics = _get_numerics(8) expected_aqt_pytree = { "AqtEinsum_0": { "AqtDotGeneral_0": { @@ -129,7 +144,8 @@ def test_gptq(self): qvalue=(expected_dtype, (1, 2, 5, 1, 10)), scale=[(dtype("float32"), (1, 2, 1, 1, 10))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) }, } @@ -141,7 +157,8 @@ def test_gptq(self): qvalue=(expected_dtype, (1, 2, 1568, 1, 256)), scale=[(dtype("float32"), (1, 2, 1, 1, 256))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) } } @@ -153,11 +170,12 @@ def test_gptq(self): qvalue=(expected_dtype, (1, 2, 128, 1, 10)), scale=[(dtype("float32"), (1, 2, 1, 1, 10))], scale_t=None, - dequant_dtype=dtype("float32") + dequant_dtype=dtype("float32"), + numerics=expected_numerics, ) } } - } + }, } serving_pytree = jax.tree_util.tree_map(