diff --git a/pyproject.toml b/pyproject.toml index 8f3dd61..91e396c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "ComfyUI-QuantOps" description = "Extended quantization layouts for ComfyUI (INT8, row/block-wise FP8)" -version = "1.7.2" +version = "1.8.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/unified_ops.py b/unified_ops.py index bfdcf21..29586a9 100644 --- a/unified_ops.py +++ b/unified_ops.py @@ -173,11 +173,14 @@ def _load_from_state_dict( if is_tensorwise and _HAS_TENSORWISE_INT8_LAYOUT: self.layout_type = "TensorWiseINT8Layout" + _orig_dtype_str = layer_conf.get("orig_dtype", "torch.bfloat16") if layer_conf else "torch.bfloat16" + _DTYPE_MAP = {"torch.bfloat16": torch.bfloat16, "torch.float16": torch.float16, "torch.float32": torch.float32} + _orig_dtype = _DTYPE_MAP.get(_orig_dtype_str, torch.bfloat16) layout_params = TensorWiseINT8Layout.Params( scale=scale.to(torch.float32) if scale is not None else None, - orig_dtype=torch.bfloat16, + orig_dtype=_orig_dtype, orig_shape=tuple(weight_tensor.shape), is_weight=True, ) @@ -205,6 +208,14 @@ def _load_from_state_dict( requires_grad=False, ) else: + # TODO (#2 — medium severity, low risk): this branch fires when + # is_tensorwise=True but _HAS_TENSORWISE_INT8_LAYOUT=False (ck absent). + # Result: raw int8 tensor stored with is_quantized=True, layout_type=None. + # That is a broken state — forward() will hit F.linear with raw int8 weight. + # Fix: degrade to BlockWiseINT8Layout if _HAS_INT8_LAYOUT, else set + # is_quantized=False and log a warning. Not patching now because ck is + # effectively required for tensorwise; if ck import failed the checkpoint + # is already unrunnable regardless. self.weight = torch.nn.Parameter( weight_tensor, requires_grad=False ) @@ -463,7 +474,19 @@ def forward_comfy_cast_weights(self, input): ) else: - # Default trigger for QuantizedTensor dispatch -> layout-specific handler + # Default trigger for QuantizedTensor dispatch -> layout-specific handler. + # TensorWiseINT8Layout and BlockWiseINT8Layout land here — aten.linear + # dispatch in comfy_kitchen handles the actual matmul. + # + # TODO (#3 — low-medium severity, medium risk): this else branch has no 3D + # input reshape guard, unlike all the explicit elif branches above. ComfyUI + # transformer attention layers pass [batch, seq, hidden] (3D). F.linear + # handles 3D natively so it works, but ck dispatch handlers may not. If + # tensorwise inference produces wrong shapes on 3D inputs, add the standard + # tensor_3d guard here (reshape -1,hidden before linear, reshape back after). + # Not patching now — risk of breaking currently-working layouts that fall + # through to this branch (e.g. RowWiseFP8, BlockWiseFP8 if aten dispatch + # handles them here too). out = torch.nn.functional.linear(input, weight, bias) else: diff --git a/utils/eager_quantization.py b/utils/eager_quantization.py index cb5b6bb..7cd3d87 100644 --- a/utils/eager_quantization.py +++ b/utils/eager_quantization.py @@ -35,13 +35,40 @@ def int8_linear( bias: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, ) -> torch.Tensor: - """INT8 linear layer using torch.int8_mm for direct quantized matmul. - - Uses native torch.int8_mm which avoids materializing large float32 intermediates - and handles scaling more efficiently than manual int32 -> float32 conversion. - - Ported from comfy-kitchen eager backend with OOM fixes. + """INT8 linear layer. Delegates to comfy_kitchen.int8_linear (triton->eager) + when available, falls back to local torch.int8_mm chunked path. + + ck.int8_linear signature matches exactly: + (x, weight, weight_scale, bias=None, out_dtype=None) + weight: [N, K] int8, weight_scale: scalar float32, out_dtype defaults bfloat16. """ + # Prefer comfy_kitchen dispatch (triton -> eager via registry). + # ck.int8_linear routes through torch.ops.comfy_kitchen.int8_linear which + # goes through the registry with priority ["cuda", "triton", "eager"]. + # cuda backend has no int8_linear, so triton wins if available, else eager. + try: + import comfy_kitchen as ck + return ck.int8_linear(x, weight, weight_scale, bias, out_dtype) + except ImportError: + pass + except Exception as e: + import logging + logging.warning(f"ComfyUI-QuantOps: ck.int8_linear failed, falling back to local path: {e}") + + # --- Local fallback: chunked torch.int8_mm path (OOM-safe) --- + # Unwrap QuantizedTensor if weight arrived still wrapped (defensive). + try: + from comfy.quant_ops import QuantizedTensor + if isinstance(weight, QuantizedTensor): + weight_scale = weight._params.scale + weight = weight._qdata + except ImportError: + pass + + # Ensure weight is raw int8 and contiguous before torch.int8_mm. + if not weight.is_contiguous(): + weight = weight.contiguous() + orig_shape = x.shape x_2d = x.reshape(-1, x.shape[-1])