Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 9 additions & 19 deletions aqt/jax/v2/aqt_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,18 @@ def quant(
) -> tuple[aqt_tensor.QTensor, aqt_tensor.GradientFn]:
"""The core quantizing function."""
qt = self.calibrate(x, calibration_axes=calibration_axes)
qt, quant_grad = self.calculate_qvalue(x, qt)
return qt, quant_grad
return self.calculate_qvalue(x, qt)

def calibrate(self, x, *, calibration_axes) -> aqt_tensor.QTensor:
"""Create incomplete QTensor with only quantization parameters."""
if isinstance(self.numerics, no_numerics.NoNumerics):
qt = aqt_tensor.QTensor(
qvalue=x, scale=[], scale_t=None, dequant_dtype=x.dtype
return aqt_tensor.QTensor(
qvalue=x,
scale=[],
scale_t=None,
dequant_dtype=x.dtype,
numerics=self.numerics,
)
return qt

dequant_dtype = x.dtype
# TODO(lew): We should cast earlier. xhs_q should be in cfg.xhs.dtype
Expand Down Expand Up @@ -94,6 +96,7 @@ def calibrate(self, x, *, calibration_axes) -> aqt_tensor.QTensor:
scale=[scale],
scale_t=None,
dequant_dtype=dequant_dtype,
numerics=self.numerics,
)
return qt

Expand All @@ -103,20 +106,7 @@ def calculate_qvalue(
qt: aqt_tensor.QTensor
) -> tuple[aqt_tensor.QTensor, aqt_tensor.GradientFn]:
"""Uses the quantization parameters in qt to quantize x."""
if isinstance(self.numerics, no_numerics.NoNumerics):
return qt, None

# TODO: b/333984742 - make numeric as a member of QTensor, and put
# numerics-related logics into the QTensor.
qt = qt.quant(x)

# TODO(lew): A logical thing would be if this call was part of
# QTensor.quant.
x_q, res = self.numerics.vjp_fwd(qt.qvalue, self.context)
quant_grad = jax.tree_util.Partial(self.numerics.vjp_bwd, res)

qt = qt.replace(qvalue=x_q)
return qt, quant_grad
return qt.quant(x, self.context)


def quantizer_make(
Expand Down
64 changes: 45 additions & 19 deletions aqt/jax/v2/aqt_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@
import typing
from typing import Any, Callable, Optional, Sequence, TypeAlias
from aqt.jax.v2 import utils
from aqt.jax.v2.numerics import int_numerics
from aqt.jax.v2.numerics import no_numerics
from aqt.jax.v2.numerics import numerics
import flax.cursor
import flax.struct
import jax
import jax.numpy as jnp
import jax.typing as jax_typing
from typing_extensions import Self # for python version < 3.11

AbstractAqtNumerics = numerics.AqtNumerics
GradientFn = Callable[..., Any] | None # None when there is no numerics
_MSG_NO_QVALUE = (
'QTensor does not have qvalue, but it is asked to access the qvalue.'
Expand Down Expand Up @@ -75,6 +79,9 @@ class QTensor:
pytree_node=False, default=None
)

# Numerics of the QTensor.
numerics: AbstractAqtNumerics = utils.static_field(default=None)

@property
def dtype(self) -> jnp.dtype | None:
return self.dequant_dtype
Expand All @@ -89,8 +96,12 @@ def without_qvalue(self) -> Self:
def astype(self, dtype: jnp.dtype) -> Self:
return self.replace(dequant_dtype=dtype) # pytype: disable=attribute-error

def quant(self, x):
"""Quantizes the QTensor."""
def quant(self, x, context: utils.Context) -> tuple[Self, GradientFn]:
"""Uses the quantization parameters in qt to quantize x."""
assert self.numerics is not None, 'Missing numerics used for quantization.'
if isinstance(self.numerics, no_numerics.NoNumerics):
return self, None

assert not self.is_full(), 'Already quantized QTensor.'
assert self.scale is not None, 'Missing scales to be used for quantization.'

Expand All @@ -101,9 +112,10 @@ def quant(self, x):
s_inv = jnp.where(jnp.isinf(s_inv), jnp.ones_like(s_inv), s_inv)
qvalue = qvalue * s_inv

# TODO(lew): We should apply numerics here, so that 'quant' function
# Can be considered a part of API.
return self.replace(qvalue=qvalue) # pytype: disable=attribute-error
x_q, res = self.numerics.vjp_fwd(qvalue, context)
quant_grad = jax.tree_util.Partial(self.numerics.vjp_bwd, res)

return self.replace(qvalue=x_q), quant_grad # pytype: disable=attribute-error

def dequant(self) -> jnp.ndarray:
"""Dequantizes the QTensor."""
Expand Down Expand Up @@ -152,26 +164,14 @@ def __len__(self) -> int:


def zeros(
shape: Sequence[int],
*,
container_dtype: jnp.dtype,
dequant_dtype: jnp.dtype = jnp.bfloat16,
) -> QTensor:
return QTensor(
qvalue=jnp.zeros(shape, dtype=container_dtype),
scale=[],
scale_t=None,
dequant_dtype=dequant_dtype,
)


def zeros_with_scale(
shape: Sequence[int],
calibration_axis: Sequence[utils.AxisIdx],
*,
container_dtype: jnp.dtype,
scale_dtype: jnp.dtype | None = None,
dequant_dtype: jnp.dtype = jnp.bfloat16,
n_bits: int | None = None,
preserve_max_val: bool = False,
) -> QTensor:
"""Initializes a QTensor with empty qvalue along with empty scale value."""
scale_shape = list(shape)
Expand All @@ -186,13 +186,16 @@ def zeros_with_scale(
scale=[jnp.ones(scale_shape, dtype=scale_dtype)],
scale_t=None,
dequant_dtype=dequant_dtype,
numerics=_get_numerics(n_bits, preserve_max_val),
)


def partition_spec(
partitions: Sequence[Any],
calibration_axis: Sequence[utils.AxisIdx],
dtype: jnp.dtype,
n_bits: int | None,
preserve_max_val: bool = False,
) -> QTensor:
"""Returns a QTensor filled with partition specs."""
scale_partitions = list(partitions)
Expand All @@ -203,6 +206,26 @@ def partition_spec(
scale=[jax.sharding.PartitionSpec(*scale_partitions)],
scale_t=None,
dequant_dtype=dtype,
numerics=_get_numerics(n_bits, preserve_max_val),
)


def _get_numerics(
n_bits: int | None, preserve_max_val: bool = False
) -> numerics.AqtNumerics:
if n_bits is None:
return no_numerics.NoNumerics()
pz = False if n_bits == 1 else True
dtype = utils.infer_dtype_from_bits(n_bits) if pz else None
return int_numerics.IntNumerics(
bits=n_bits,
preserve_zero=pz,
preserve_max_val=preserve_max_val,
clip=True,
round=True,
noise_fn=None,
clip_gradient=False, # This can be disabled when using abs-max scaling.
dtype=dtype,
)


Expand Down Expand Up @@ -242,6 +265,7 @@ def get_sliced_scales(scale):
scale=[get_sliced_scales(s) for s in operand.scale],
scale_t=None,
dequant_dtype=operand.dequant_dtype,
numerics=operand.numerics,
)


Expand Down Expand Up @@ -290,6 +314,7 @@ def dynamic_update_slice(
scale=scales,
scale_t=None,
dequant_dtype=operand.dequant_dtype,
numerics=operand.numerics,
)


Expand All @@ -306,4 +331,5 @@ def update_frame(operand: QTensor, frame: int, update: QTensor) -> QTensor:
],
scale_t=None,
dequant_dtype=operand.dequant_dtype,
numerics=operand.numerics,
)
5 changes: 4 additions & 1 deletion aqt/jax/v2/aqt_tensor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ def test_dynamic_update(self):

def test_dtype(self):
qt = aqt_tensor.zeros(
shape=(1,), container_dtype=jnp.int8, dequant_dtype=jnp.float32
shape=(1,),
calibration_axis=(),
container_dtype=jnp.int8,
dequant_dtype=jnp.float32,
)
self.assertEqual(qt.dtype, jnp.float32)
self.assertEqual(qt.dequant_dtype, jnp.float32)
Expand Down
Loading