Skip to content

Update ops & Add Tests#7

Open
chenliangjyj wants to merge 11 commits into
mainfrom
sync_0427
Open

Update ops & Add Tests#7
chenliangjyj wants to merge 11 commits into
mainfrom
sync_0427

Conversation

@chenliangjyj

Copy link
Copy Markdown
Collaborator
  • Optimize kernels
  • Add ops for MXFP8
  • Add kernels for inference.

zheyishine and others added 10 commits January 14, 2026 16:26
Merge branch core of git@code.alipay.com:pia/linghe.git into main
https://code.alipay.com/pia/linghe/pull_requests/5
* add mxfp8 quant
* add mxfp8 quant
* add mxfp8 quant
* add mxfp8 quant
* add mxfp8 quant
* add mxfp8 quant
* add mxfp8 quant
* revise for colwise
* opt mxfp8 quant
* compatible with triton 3.2.0
* add arg silu for rope
* fix norm
* add fused block permute & unpermute
* add fused permute/unpermute mxfp8 quantize
* refine testcase
* add batch scale & batch clip kernels
* finetune params
* add mla_rope kernel
* add rope inferface & use float32 update in gemm
* add quantization interface
* add varlen rope
* add cu_seqlens_kv for mla rope
* add is_contiguous assert for rope
* refine gemm testcase
* use async h2d
* add rope with cp
* add rope with cp
* add rope with cp
* refine testcases
* add support for mini v3
* add cp rope in facade
* remove megatron code in test_rope
* use two kernels for block_rms_norm
* use two kernels for block_rms_norm
* refine save_for_backward
* use raw input instead of view in saved_tensor
* fix bench norm bug
* remove ctx.rms in norm
* use parameter instead of parameter.data in op
* add silu arg in rope
* opt inplace impl in ce loss
* refine arg names
* add more barrier in ce loss
* add assert in ce loss
* use meaningful shape for batch quantization
* support transpose in mla
* support transpose in mla
* fix arg num bug in mla rope
* linghe 0.1.0
* add approvers in aci
* support more num_experts in fp32_gemm
* update aci
* update
* refine testcases
* remove is_contiguous assertion in batch norm
* use fp32 dw in multiple kernels
* use 0.1.2 version
* recover test image
* use int64 in clip and scale and fp32 in rope
* add is_contiguous assertion and use rsqrt instead of 1/sqrt
* mxfp8 deepep permute
* mxfp8 unpermute func
* support v3 dim
* assert
* scatter add for v3 funny hidden
* use numerical stable impl for ce
* use float64 in ce
* refine testcases
* refine testcases
* refine testcases
* use fp32 in silu and split batch silu backward
* fix bug in topk
* revise&refine test
* tune params
* add arg reuse  in rope
* adapt for triton 3.2.0
* support tp in rope
* support tp in rope
* use int64 in batch ops
* use int64 in batch ops
* use mask instead of max to accelerate clip
* mv mxfp8 silu to linghe
* mv mxfp8 rms to linghe
* gradient fix
* refine args in mxfp8
* use int64 in smooth/mxfp8 batch kernels
* fix transpose in mla
* fix transpose in mla
* fix transpose in mla
* fix transpose in mla
* fix transpose in mla
* fix transpose in mla
* fix transpose in mla
* fix mla bug with bs=1 and strided k_pos_emb
* use padding batch instead of 128 in varlen rope
* do not transpose mla rope output with thd layout
* support more shapes
* refine test.sh
* remove rms none assertion in rms norm
* add embedding and mla
* refine mla impl
* add for hybridep
* add varlen mla
* use 0.2.5
* use 0.2.5
* return zero when grad tensors is empty
* define inf/inf=1
* add numel>0 assert in batch ops
* add experiment op&reduce confusion
* WIP dist ce loss
* refine testcases
* add debug log
* add debug log
* support ignore_index in ce
* add parallel ce
* add parallel ce
* version 0.2.7
* remove inf_or_nan
* add log for labels
* add barrier in ce backward
* refine testcase
* add tp_group in ce
* add tp_group in ce
* use rsqrt instead of 1/sqrt
* use 0.2.8
* refine condition of ce parallel
* support stride for grad of embedding
* use 0.2.9
* return grad for dummy tensor in embedding
* support bf16 in batch ops
* use fast impl for embedding
* add ptx util & revise count zeros
* fix typo in triton_batch_count_zero
* use fast and accurate impl for embedding
* refactor embedding ops
* refactor embedding ops
* revise for liuyu request
* PullRequest: 2 合并计算通信融合算子
* PullRequest: 3 reformat
* refine for public
add infer ops

WIP: update test 1

WIP: update test 2
@chenliangjyj chenliangjyj requested a review from zheyishine April 27, 2026 10:00

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request significantly expands the library by introducing Triton-based kernels and PyTorch facades for Multi-Latent Attention (MLA), advanced quantization techniques (MXFP8, SmoothQuant, blockwise), and optimized GEMM implementations, alongside a comprehensive set of inference kernels and a refactored pytest suite. While the additions are extensive, the review uncovered several critical issues: the use of non-existent Triton APIs like tl.split, missing attributes in autograd ctx objects causing potential AttributeError or NameError, and incorrect return counts in backward methods. Furthermore, the feedback identifies opportunities to clean up the code by removing unused imports, eliminating redundant logic, and fixing typos in conditional checks within the linear and quantization layers.


acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
acc = tl.permute(acc, (0, 2, 1))
acc0, acc1 = tl.split(acc)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The function tl.split does not exist in the Triton language API. This will cause a compilation error. It seems you are trying to split the accumulator into two halves. You can achieve this using slicing.

Suggested change
acc0, acc1 = tl.split(acc)
acc0, acc1 = acc[..., 0], acc[..., 1]

Comment thread linghe/facade/gate.py Outdated
triton_batch_transpose_smooth_fused_permute_with_indices(
grad_output._rowwise_data,
grad_output._rowwise_scale_inv,
grad_smooth_scale,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The variable grad_smooth_scale is used here but it is not defined in the backward method of _SmoothFusedUnpermute. This will cause a NameError.

@staticmethod
def backward(ctx, grad_output):
if hasattr(grad_output, "_quantizer") or ctx.grad_quantizer is None:
return (grad_output,)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The backward method should return a tuple with the same number of elements as the inputs to the forward method. The forward method has 4 inputs (hidden_states, quantizer, grad_quantizer, cls), so backward should return 4 values. Returning a tuple with a single element (grad_output,) will cause a runtime error.

Suggested change
return (grad_output,)
return grad_output, None, None, None

Comment thread linghe/attn/mla.py
# c1 = tl.min(tl.where(cu <= c0, 2 ** 24, cu), 0)

c0 = 0
c1 = 1048576

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The initial value for c1 is 1048576 (which is 2**20), but inside the loop for calculating c1, you are using 2**24 as the large value for elements that don't satisfy the condition. This is inconsistent. To avoid potential issues where the minimum value could be larger than the initial c1, it's better to use a consistent large value. I suggest initializing c1 with 2**24 as well.

Suggested change
c1 = 1048576
c1 = 2**24

import torch
import triton
import triton.language as tl
from triton import Config

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The triton.Config import is unused in this file and can be removed.

BLOCK_N // 2])

def grid(META):
nonlocal a_desc, b_desc, c_desc

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The nonlocal a_desc, b_desc, c_desc statement is unnecessary here. These variables are not being reassigned within the grid function, so they can be accessed from the enclosing scope without the nonlocal keyword.

Comment thread linghe/facade/gemm.py Outdated
Comment thread linghe/facade/linear.py
wt_scale if ctx.input_requires_grad else None,
(
hadamard_matrix
if ctx.weight_requires_grad or ctx.weight_requires_grad

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There seems to be a typo in this condition. ctx.weight_requires_grad is checked twice. It should probably be ctx.weight_requires_grad or ctx.input_requires_grad.

Suggested change
if ctx.weight_requires_grad or ctx.weight_requires_grad
if ctx.weight_requires_grad or ctx.input_requires_grad

Comment thread linghe/facade/linear.py
w_scale if ctx.input_requires_grad else None,
(
smooth_scale
if ctx.weight_requires_grad or ctx.weight_requires_grad

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There seems to be a typo in this condition. ctx.weight_requires_grad is checked twice. It should probably be ctx.weight_requires_grad or ctx.input_requires_grad.

Suggested change
if ctx.weight_requires_grad or ctx.weight_requires_grad
if ctx.weight_requires_grad or ctx.input_requires_grad

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants