[WIP]Full fine tune support#2020
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces 'Full Weight Gradient' mode to support full fine-tuning and hybrid training by implementing C++ kernels for base weight gradient accumulation and updating the Python autograd infrastructure. Feedback identifies critical bugs in the C++ implementation, specifically incorrect indexing of token positions and potential buffer overflows in the memory pool. High-severity issues were also raised regarding the lack of vectorization in gradient computation and the overwriting of gradients instead of accumulation. Further optimizations were suggested for GPU memory allocation and the removal of explicit garbage collection calls to prevent performance degradation.
| pos_start += cache.m_local_num_cache[cache.m_expert_id_map_cache[prev_id]]; | ||
| } | ||
|
|
||
| const auto& local_pos = cache.m_local_pos_cache[expert_idx]; |
There was a problem hiding this comment.
| float* acc_gate = static_cast<float*>(forward_pool_); // [I, H] | ||
| float* acc_up = acc_gate + (size_t)I * H; // [I, H] | ||
| float* acc_down = acc_up + (size_t)I * H; // [H, I] |
There was a problem hiding this comment.
Critical issue: forward_pool_ is used for large FP32 accumulators (acc_gate, acc_up, acc_down), but its size is only determined by LoRA working buffers in alloc_forward_buffers. For full fine-tuning, these accumulators require (2 * I * H + H * I) * sizeof(float) bytes, which can be hundreds of megabytes (e.g., ~175MB for I=2048, H=7168). This will cause a buffer overflow if forward_pool_ is not explicitly resized to accommodate these.
| // gate_proj grad: [I, H] += grad_gate_out[t]^T @ input[t] | ||
| for (int i = 0; i < I; i++) { | ||
| float gg = GGML_BF16_TO_FP32(gate_grad_row[i]); | ||
| float gu = GGML_BF16_TO_FP32(up_grad_row[i]); | ||
| for (int h = 0; h < H; h++) { | ||
| float inp = GGML_BF16_TO_FP32(input_row[h]); | ||
| acc_gate[i * H + h] += gg * inp; | ||
| acc_up[i * H + h] += gu * inp; | ||
| } | ||
| } | ||
|
|
||
| // down_proj grad: [H, I] += grad_output[t]^T @ intermediate[t] | ||
| for (int h = 0; h < H; h++) { | ||
| float go = GGML_BF16_TO_FP32(grad_out_row[h]); | ||
| for (int i = 0; i < I; i++) { | ||
| acc_down[h * I + i] += go * GGML_BF16_TO_FP32(inter_row[i]); | ||
| } | ||
| } |
There was a problem hiding this comment.
High severity performance issue: The base weight gradient computation uses a triple nested loop without vectorization or parallelization. For typical MoE sizes, this results in billions of operations executed sequentially on a single thread. This should be implemented using a vectorized outer product (rank-1 update) or a GEMM (
| ggp[(size_t)expert_idx * I * H + (size_t)i * H + h] = GGML_FP32_TO_BF16(acc_gate[i * H + h]); | ||
| gup_ptr[(size_t)expert_idx * I * H + (size_t)i * H + h] = GGML_FP32_TO_BF16(acc_up[i * H + h]); | ||
| } | ||
| } | ||
| for (int h = 0; h < H; h++) { | ||
| for (int i = 0; i < I; i++) { | ||
| gdp[(size_t)expert_idx * H * I + (size_t)h * I + i] = GGML_FP32_TO_BF16(acc_down[h * I + i]); | ||
| } |
There was a problem hiding this comment.
High severity issue: The gradients are being overwritten in the output buffer (ggp[...] = ...) instead of being accumulated. This will break gradient accumulation if multiple micro-batches are used before an optimizer step. While PyTorch usually handles accumulation, here the C++ kernel writes directly into a persistent buffer in the wrapper, so it must perform in-place accumulation if the buffer is reused.
| wrapper.wrapper.grad_down_proj_buf, | ||
| ): | ||
| if grad_buf is not None: | ||
| grad_gpu = grad_buf.cuda() |
There was a problem hiding this comment.
| import gc | ||
| gc.collect() |
There was a problem hiding this comment.
Medium severity performance issue: Explicitly calling gc.collect() during weight loading can cause significant pauses, especially if called frequently during training (e.g., in the update_base_weights fallback path). It is better to rely on standard reference counting or investigate why del old_moe is insufficient to release the C++ object.
What does this PR do?
support full finetune for KT
Fixes # (issue)
Before submitting