[WIP] Full fine tune support #2020

Open
poryfly wants to merge 2 commits into kvcache-ai:main from poryfly:full-fine-tune-support

poryfly (Contributor) commented May 22, 2026

What does this PR do?

Adds full fine-tuning support for KT.
Fixes # (issue)

Before submitting

gemini-code-assist (bot) left a comment

Code Review

This pull request introduces 'Full Weight Gradient' mode to support full fine-tuning and hybrid training by implementing C++ kernels for base weight gradient accumulation and updating the Python autograd infrastructure. Feedback identifies critical bugs in the C++ implementation, specifically incorrect indexing of token positions and potential buffer overflows in the memory pool. High-severity issues were also raised regarding the lack of vectorization in gradient computation and the overwriting of gradients instead of accumulation. Further optimizations were suggested for GPU memory allocation and the removal of explicit garbage collection calls to prevent performance degradation.

pos_start += cache.m_local_num_cache[cache.m_expert_id_map_cache[prev_id]];
}

const auto& local_pos = cache.m_local_pos_cache[expert_idx];

critical

Critical bug: cache.m_local_pos_cache is indexed by token ID, not expert index. Accessing it with expert_idx will lead to out-of-bounds access or incorrect data. You need a mapping from expert to the original token indices to correctly retrieve input_row and grad_out_row from the global buffers.
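
One way to obtain the expert-to-token mapping the comment asks for is to invert the routing table once per step. The sketch below is illustrative only: `build_expert_to_tokens`, the `[num_tokens, top_k]` layout of `expert_ids`, and the rule that out-of-range ids are skipped are all assumptions, not names or behavior taken from the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not from the PR): for each expert, collect the token
// rows routed to it, so input_row / grad_out_row can be fetched from the
// global buffers by token index rather than by expert index.
// Assumed layout: expert_ids is [num_tokens, top_k], row-major.
std::vector<std::vector<int>> build_expert_to_tokens(
    const std::vector<int>& expert_ids, int num_tokens, int top_k,
    int num_experts) {
  assert(expert_ids.size() == static_cast<std::size_t>(num_tokens) * top_k);
  std::vector<std::vector<int>> expert_to_tokens(num_experts);
  for (int t = 0; t < num_tokens; ++t) {
    for (int k = 0; k < top_k; ++k) {
      const int e = expert_ids[static_cast<std::size_t>(t) * top_k + k];
      // Assumption: ids outside [0, num_experts) mean "not handled here".
      if (e >= 0 && e < num_experts) {
        expert_to_tokens[e].push_back(t);
      }
    }
  }
  return expert_to_tokens;
}
```

With such a table, the per-expert loop would iterate over `expert_to_tokens[expert_idx]` and use each stored token index to address the global input and gradient rows.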

Comment on lines +1943 to +1945
float* acc_gate = static_cast<float*>(forward_pool_); // [I, H]
float* acc_up = acc_gate + (size_t)I * H; // [I, H]
float* acc_down = acc_up + (size_t)I * H; // [H, I]

critical

Critical issue: forward_pool_ is used for large FP32 accumulators (acc_gate, acc_up, acc_down), but its size is only determined by LoRA working buffers in alloc_forward_buffers. For full fine-tuning, these accumulators require (2 * I * H + H * I) * sizeof(float) = 3 * I * H * sizeof(float) bytes, which can run to hundreds of megabytes (e.g., 176,160,768 bytes, about 168 MiB, for I=2048, H=7168). This will cause a buffer overflow if forward_pool_ is not explicitly resized to accommodate these.
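
The sizing arithmetic can be checked with a small compile-time helper. This is a sketch under stated assumptions: `full_grad_accum_bytes` and `required_pool_bytes` are hypothetical names, and adding the accumulator bytes on top of the LoRA working set (rather than taking the maximum) assumes both may be live in the pool at the same time, which the PR does not confirm.

```cpp
#include <cstddef>

// Bytes for acc_gate [I, H] + acc_up [I, H] + acc_down [H, I] in FP32.
constexpr std::size_t full_grad_accum_bytes(std::size_t I, std::size_t H) {
  return 3 * I * H * sizeof(float);
}

// Hypothetical sizing rule (an assumption, not the PR's logic): reserve the
// accumulators in addition to the existing LoRA working buffers.
constexpr std::size_t required_pool_bytes(std::size_t lora_bytes,
                                          std::size_t I, std::size_t H,
                                          bool full_grad_mode) {
  return lora_bytes + (full_grad_mode ? full_grad_accum_bytes(I, H) : 0);
}

// Holds on platforms where float is 4 bytes: 3 * 2048 * 7168 * 4.
static_assert(sizeof(float) != 4 ||
                  full_grad_accum_bytes(2048, 7168) == 176160768u,
              "accumulator size for I=2048, H=7168");
```

A check of this kind in the allocation path (compare the pool's current capacity against `required_pool_bytes` before carving out `acc_gate`) would turn a silent overflow into an explicit failure.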

Comment on lines +1959 to +1976
// gate_proj grad: [I, H] += grad_gate_out[t]^T @ input[t]
for (int i = 0; i < I; i++) {
  float gg = GGML_BF16_TO_FP32(gate_grad_row[i]);
  float gu = GGML_BF16_TO_FP32(up_grad_row[i]);
  for (int h = 0; h < H; h++) {
    float inp = GGML_BF16_TO_FP32(input_row[h]);
    acc_gate[i * H + h] += gg * inp;
    acc_up[i * H + h] += gu * inp;
  }
}

// down_proj grad: [H, I] += grad_output[t]^T @ intermediate[t]
for (int h = 0; h < H; h++) {
  float go = GGML_BF16_TO_FP32(grad_out_row[h]);
  for (int i = 0; i < I; i++) {
    acc_down[h * I + i] += go * GGML_BF16_TO_FP32(inter_row[i]);
  }
}

high

High severity performance issue: The base weight gradient computation uses a triple nested loop without vectorization or parallelization. For typical MoE sizes, this results in billions of operations executed sequentially on a single thread. This should be implemented using a vectorized outer product (rank-1 update) or a GEMM ($\mathrm{Grad} = \mathrm{GradOut}^{T} \times \mathrm{Input}$), and parallelized across experts using the worker pool.
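
A minimal sketch of the GEMM-shaped form follows. It assumes the BF16 rows have already been converted to FP32 once per token (so no conversion sits in the inner loop), and `accumulate_weight_grad` with its row-range parameters is a hypothetical helper, not the PR's kernel or a tuned BLAS call. The inner loop is a unit-stride multiply-add that compilers can typically auto-vectorize, and disjoint row ranges write to disjoint parts of `acc`, so they can be handed to separate workers without locking.

```cpp
#include <cstddef>

// acc[rows, cols] += grad_out[T, rows]^T @ input[T, cols], restricted to
// rows in [row_begin, row_end). All buffers are row-major FP32.
// Hypothetical helper for illustration; assumes inputs were converted from
// BF16 to FP32 beforehand.
void accumulate_weight_grad(float* acc, const float* grad_out,
                            const float* input, std::size_t T,
                            std::size_t rows, std::size_t cols,
                            std::size_t row_begin, std::size_t row_end) {
  for (std::size_t r = row_begin; r < row_end; ++r) {
    float* acc_row = acc + r * cols;
    for (std::size_t t = 0; t < T; ++t) {
      const float g = grad_out[t * rows + r];  // scalar for this (token, row)
      const float* in_row = input + t * cols;
      // Unit-stride axpy: acc_row += g * in_row.
      for (std::size_t c = 0; c < cols; ++c) {
        acc_row[c] += g * in_row[c];
      }
    }
  }
}
```

For the gate and up projections this would be called with `rows = I, cols = H`, and for the down projection with `rows = H, cols = I`. In practice an optimized GEMM routine would likely outperform this loop; the sketch only shows the loop order and the partitioning that make vectorization and parallelization possible.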

Comment on lines +1982 to +1989
    ggp[(size_t)expert_idx * I * H + (size_t)i * H + h] = GGML_FP32_TO_BF16(acc_gate[i * H + h]);
    gup_ptr[(size_t)expert_idx * I * H + (size_t)i * H + h] = GGML_FP32_TO_BF16(acc_up[i * H + h]);
  }
}
for (int h = 0; h < H; h++) {
  for (int i = 0; i < I; i++) {
    gdp[(size_t)expert_idx * H * I + (size_t)h * I + i] = GGML_FP32_TO_BF16(acc_down[h * I + i]);
  }

high

High severity issue: The gradients are being overwritten in the output buffer (ggp[...] = ...) instead of being accumulated. This will break gradient accumulation if multiple micro-batches are used before an optimizer step. While PyTorch usually handles accumulation, here the C++ kernel writes directly into a persistent buffer in the wrapper, so it must perform in-place accumulation if the buffer is reused.
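
The difference between overwriting and accumulating can be sketched as below. The BF16 helpers are simplified stand-ins for `GGML_BF16_TO_FP32` / `GGML_FP32_TO_BF16` (round-to-nearest-even, no NaN handling), and `store_grad` with its `accumulate` flag is a hypothetical name, not code from the PR.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Simplified stand-in for GGML_BF16_TO_FP32: widen the 16 stored bits.
inline float bf16_to_fp32(std::uint16_t v) {
  const std::uint32_t bits = static_cast<std::uint32_t>(v) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Simplified stand-in for GGML_FP32_TO_BF16: round to nearest even.
// NaN inputs are not handled specially in this sketch.
inline std::uint16_t fp32_to_bf16(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<std::uint16_t>(bits >> 16);
}

// Write FP32 accumulators into a persistent BF16 gradient buffer.
// accumulate == false reproduces the overwrite behavior flagged above;
// accumulate == true adds to whatever earlier micro-batches left there.
void store_grad(std::uint16_t* dst, const float* acc, std::size_t n,
                bool accumulate) {
  for (std::size_t i = 0; i < n; ++i) {
    float v = acc[i];
    if (accumulate) {
      v += bf16_to_fp32(dst[i]);
    }
    dst[i] = fp32_to_bf16(v);
  }
}
```

With in-place accumulation the caller has to zero the buffer after each optimizer step. Repeated round trips through BF16 also lose precision as micro-batches add up, so keeping the persistent buffer in FP32 may be worth considering.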

wrapper.wrapper.grad_down_proj_buf,
):
if grad_buf is not None:
grad_gpu = grad_buf.cuda()

medium

Medium severity performance issue: grad_buf.cuda() creates a new GPU allocation and performs a synchronous copy every time sync_kt_lora_gradients is called. This should be optimized by using pre-allocated GPU buffers to avoid memory fragmentation and allocation overhead in the training loop.

Comment on lines +251 to +252
import gc
gc.collect()

medium

Medium severity performance issue: Explicitly calling gc.collect() during weight loading can cause significant pauses, especially if called frequently during training (e.g., in the update_base_weights fallback path). It is better to rely on standard reference counting or investigate why del old_moe is insufficient to release the C++ object.
