Skip to content
This repository was archived by the owner on Feb 24, 2026. It is now read-only.
This repository was archived by the owner on Feb 24, 2026. It is now read-only.

When the A_dtype is bfloat16, errors occur. #295

@efsotr

Description

@efsotr
model = bitblas.Linear(
    in_features=1024,
    out_features=1024,
    bias=False,
    A_dtype="bfloat16",  # activation A dtype
    W_dtype="uint2",  # weight W dtype
    accum_dtype="float32",  # accumulation dtype
    out_dtype="float16",  # output dtype
    # configs for weight only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=True,  # setting for scaling factor
    with_zeros=True,  # setting for zeros
    zeros_mode="original",  # setting for how to calculate zeros
    fast_decoding=False,
    # Target optimization var for dynamic symbolic.
    # For detailed information, please check out docs/PythonAPI.md
    # By default, the optimization var is [1, 16, 32, 64, 128, 256, 512]
    opt_M=[1, 16, 32, 64, 128],
)
2025-03-10 15:32:38 [BitBLAS:ERROR]: An exception occurred for hint {block_M=16,block_N=256,warp_M=16,warp_N=64,block_K=32,threads=128,num_stages=0,enable_rasterization=True,split_k_factor=16}: #include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>

extern "C" __global__ void __launch_bounds__(128) general_dequant_matmul_kernel(bfloat16_t* __restrict__ A, signed char* __restrict__ B, half_t* __restrict__ C, bfloat16_t* __restrict__ Scale, bfloat16_t* __restrict__ Zeros, int m) {
  extern __shared__ __align__(1	...	u(36): error: identifier "__pack_nv_bfloat162" is undefined
          condval = make_uint4(__pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)));
                               ^

1 error detected in the compilation of "/tmp/tmpa36zlrof/tvm_kernels.cu".

2025-03-10 15:32:38 [BitBLAS:ERROR]: An exception occurred for hint {block_M=16,block_N=32,warp_M=16,warp_N=16,block_K=256,threads=64,num_stages=0,enable_rasterization=True,split_k_factor=2}: #include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>

extern "C" __global__ void __launch_bounds__(128) general_dequant_matmul_kernel(bfloat16_t* __restrict__ A, signed char* __restrict__ B, half_t* __restrict__ C, bfloat16_t* __restrict__ Scale, bfloat16_t* __restrict__ Zeros, int m) {
  extern __shared__ __align__(1	...	ls.cu(35): error: identifier "__pack_nv_bfloat162" is undefined
        condval = make_uint4(__pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)), __pack_nv_bfloat162(bfloat16_t(0.000000e+00f), bfloat16_t(0.000000e+00f)));
                             ^

1 error detected in the compilation of "/tmp/tmpesebmis9/tvm_kernels.cu".

@LeiWang1999

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions