diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu index 3185a9a86b231..9db229555ecbc 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.cu @@ -200,204 +200,6 @@ bool tryLaunchMoeGemvIntSymmetricInterleavedSwiGLU( } } -/** - * Takes the input maps and prepares the expanded maps for min latency - * @param num_active_experts_per_node: Number of active experts on current node - * @param experts_to_token_scores: The score of each token for each activated expert. 0 if the expert is not chosen by - * the token. Only the first num_active_experts_per_ rows are valid - * @param active_expert_global_ids: The global expert id for each activated expert - * Only the first num_active_experts_per_ values are valid - * @param expert_first_token_offset: Store the first token offset for each expert - */ -template -__device__ __forceinline__ void initTensor(T* value, int const tid, int const total_num, T const init_value) { - for (int i = tid; i < total_num; i += BLOCK_SIZE) { - value[i] = init_value; - } -} - -template -__device__ __forceinline__ void setLocalExperts(int* s_local_experts, T const* token_selected_experts, - int const total_num_experts, int const tid, int const start_expert, int const end_expert) { - for (int i = tid; i < total_num_experts; i += BLOCK_SIZE) { - int const expert = token_selected_experts[i]; - - // If expert is in the current node, subtract start_expert to shift the range to [0, num_experts_per_node) - bool is_valid_expert = expert >= start_expert && expert < end_expert; - if (is_valid_expert) { - int local_expert_id = expert - start_expert; - if (s_local_experts[local_expert_id] == 0) { - s_local_experts[local_expert_id] = 1; // @TODO: Make sure that we allow duplicated write here - } - } - } - __syncthreads(); -} - -template -__device__ __forceinline__ void prefixSum(T* out, T* in, int const num, int const tid) { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage tempStorage; - - T threadData = 0; - if (tid < num) { - threadData = in[tid]; - } - - BlockScan(tempStorage).InclusiveSum(threadData, threadData); - __syncthreads(); - - if (tid < num) { - out[tid] = threadData; - } - __syncthreads(); -} - -__device__ __forceinline__ void setActiveNum(int& num_active, int& num_active_offset_start, int& num_active_offset_end, - int const cluster_size, int const cluster_rank) { - int num_remainder = num_active % cluster_size; - int num_active_per_node = max(0, num_active - 1) / cluster_size; // num_active_per_node shouldn't be neg - if (cluster_rank < num_remainder) { - num_active = num_active_per_node + 1; - num_active_offset_start = cluster_rank * num_active; - } else { - num_active = num_active_per_node; - num_active_offset_start = cluster_rank * num_active_per_node + num_remainder; - } - num_active_offset_end = num_active_offset_start + num_active; -} - -template -__global__ void buildMinLatencyActiveExpertMapsKernel(int* num_active_experts_per_node, float* experts_to_token_scores, - int* active_expert_global_ids, int64_t* expert_first_token_offset, int const* token_selected_experts, - float const* token_final_scales, int64_t const num_tokens, int const num_experts_per_token, int const start_expert, - int const end_expert, int const num_experts_per_node, bool const smart_routing, int const cluster_rank, - int const cluster_size, int const num_experts_smem) { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); -#endif - // Use one block to process the min latency case - int tid = threadIdx.x; - // 0. init the global memory experts_to_token_scores [num_experts_per_node, num_token] - int const total_local_scales = num_experts_per_node * num_tokens; - initTensor(experts_to_token_scores, tid, total_local_scales, 0.0f); - initTensor(active_expert_global_ids, tid, num_experts_per_node, -1); - - __threadfence(); //@Todo: check do I need this fence for previous zero setting - - // 1. mask for the active expert: 1 stands for active - extern __shared__ int s_local_experts[]; - int* s_store_experts = s_local_experts + num_experts_smem; - initTensor(s_local_experts, tid, num_experts_smem, 0); - __syncthreads(); - - // 2. set the shared array s_local_experts[] - int const total_num_experts = num_tokens * num_experts_per_token; - setLocalExperts( - s_local_experts, token_selected_experts, total_num_experts, tid, start_expert, end_expert); - - // 3. perform prefix sum to acquire the store position and total active experts - //@TODO: Use cub first, might need to change it to self-defined api - prefixSum(s_store_experts, s_local_experts, num_experts_smem, tid); - - // 4. store the num of active experts - int num_active = s_store_experts[num_experts_smem - 1]; - int num_active_offset_start = 0; - int num_active_offset_end = 0; - - if (smart_routing) { - setActiveNum(num_active, num_active_offset_start, num_active_offset_end, cluster_size, cluster_rank); - } - - if (tid == 0) { - *num_active_experts_per_node = num_active; - } - - // 5. store the global expert id for each expert - if (smart_routing) { - for (int i = tid; i < num_experts_smem; i += BLOCK_SIZE) { - if (s_local_experts[i]) { - int offset = s_store_experts[i] - 1; - if (offset >= num_active_offset_start && offset < num_active_offset_end) { - active_expert_global_ids[offset - num_active_offset_start] = i; - } else { - s_local_experts[i] = 0; - } - } - } - __syncthreads(); // Need sync to update the s_local_experts - } else { - for (int i = tid; i < num_experts_smem; i += BLOCK_SIZE) { - if (s_local_experts[i]) { - int offset = s_store_experts[i] - 1; - active_expert_global_ids[offset] = i + start_expert; - } - } - } - - // 6. store the scale values - __threadfence(); //@Todo: check do I need this fence for previous zero setting - for (int i = tid; i < total_num_experts; i += BLOCK_SIZE) { - int const expert = token_selected_experts[i]; - - // If expert is not in the current node, set it to num_experts_per_node - // If expert is in the current node, subtract start_expert to shift the range to [0, num_experts_per_node) - bool is_valid_expert = smart_routing ? s_local_experts[expert] : (expert >= start_expert && expert < end_expert); - - if (is_valid_expert) { - int token = i / num_experts_per_token; - float const scale = token_final_scales[i]; - int offset = s_store_experts[expert - start_expert] - 1 - num_active_offset_start; - experts_to_token_scores[offset * num_tokens + token] = scale; - } - } - // 7. set default value for redundant memory - for (int i_exp = num_active + tid; i_exp < num_experts_per_node; i_exp += BLOCK_SIZE) { - active_expert_global_ids[i_exp] = -1; - } - // 8. set expert_first_token_offset - for (int i_exp = tid; i_exp < num_experts_per_node + 1; i_exp += BLOCK_SIZE) { - if (i_exp < num_active) { - expert_first_token_offset[i_exp] = i_exp * num_tokens; - } else { - expert_first_token_offset[i_exp] = num_active * num_tokens; - } - } -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); -#endif -} - -void buildMinLatencyActiveExpertMaps(int* num_active_experts_per_node, float* experts_to_token_scores, - int* active_expert_global_ids, int64_t* expert_first_token_offset, int const* token_selected_experts, - float const* token_final_scales, int64_t const num_tokens, int const experts_per_token, int const start_expert, - int const end_expert, int const num_experts_per_node, int const cluster_rank, int const cluster_size, - int const num_experts_smem, cudaStream_t const stream) { - ORT_ENFORCE(num_experts_per_node == (end_expert - start_expert), - "num_experts_per_node must be equal to end_expert - start_expert"); - - ORT_ENFORCE(num_experts_per_node <= 256, "don't support num_experts_per_node > 256 cases"); - - int const threads = 256; - int const blocks = 1; - bool const smart_routing = cluster_size > 1; - - cudaLaunchConfig_t config; - config.gridDim = blocks; - config.blockDim = threads; - config.dynamicSmemBytes = num_experts_smem * sizeof(int) * 2; - config.stream = stream; - cudaLaunchAttribute attrs[1]; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = onnxruntime::llm::common::getEnvEnablePDL(); - config.numAttrs = 1; - config.attrs = attrs; - cudaLaunchKernelEx(&config, buildMinLatencyActiveExpertMapsKernel, num_active_experts_per_node, - experts_to_token_scores, active_expert_global_ids, expert_first_token_offset, token_selected_experts, - token_final_scales, num_tokens, experts_per_token, start_expert, end_expert, num_experts_per_node, - smart_routing, cluster_rank, cluster_size, num_experts_smem); -} - template __global__ void fusedBuildExpertMapsSortFirstTokenKernel(int const* const token_selected_experts, int* const permuted_row_to_unpermuted_row, int* const unpermuted_row_to_permuted_row, @@ -1168,123 +970,6 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir #endif } -template -__global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int64_t const num_experts_per_node, T const* in1, T const* in2, - WeightType const* weights1, WeightType const* weights2, float const* alpha_scale_flat1, - float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, - ScaleBiasType const* bias1, ScaleBiasType const* bias2, OutputType* output1, OutputType* output2, - int const* num_active_experts_per, int const* active_expert_global_ids, int start_expert) { - // First, compute the global tid. We only need 1 thread per expert. - int const expert = blockIdx.x * blockDim.x + threadIdx.x; - - if (expert >= num_experts_per_node) { - return; - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); -#endif - - // Note: expert is used to calculate the offset of the input and output - // local_expert is used to calculate the offset of the weight - auto const num_tokens_before_expert = expert * num_tokens; - bool const is_active_expert = expert < *num_active_experts_per; - int const local_expert = is_active_expert ? active_expert_global_ids[expert] - start_expert : -1; - auto const gemm_m = is_active_expert ? num_tokens : 0; - - // M and N transposed since we are using the #tokens as the N dimension - layout_info1.shape_info.problem_shapes[expert] = TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm1_n, gemm_m, gemm1_k); - layout_info2.shape_info.problem_shapes[expert] = TmaWarpSpecializedGroupedGemmInput::ProblemShape::UnderlyingProblemShape(gemm2_n, gemm_m, gemm2_k); - - if (alpha_scale_flat1) { - assert(alpha_scale_flat2); - if (is_active_expert) { - layout_info1.alpha_scale_ptr_array[expert] = alpha_scale_flat1 + local_expert; - layout_info2.alpha_scale_ptr_array[expert] = alpha_scale_flat2 + local_expert; - } else { - layout_info1.alpha_scale_ptr_array[expert] = nullptr; - layout_info2.alpha_scale_ptr_array[expert] = nullptr; - } - } - - if (quant_params.fp4.fc1.weight_block_scale) { - setupFP4BlockScalingFactors(layout_info1, expert, - gemm_m, gemm1_n, gemm1_k, fp4_act_flat1, quant_params.fp4.fc1.weight_block_scale, num_tokens_before_expert); - - // Override the scaling factors, fc1 uses the same A input for all experts and the scaling factor B offsets from - // the local expert index - if (is_active_expert) { - layout_info1.fpX_block_scaling_factors_A[expert] = fp4_act_flat1; - layout_info1.fpX_block_scaling_factors_B[expert] = quant_params.fp4.fc1.weight_block_scale + getOffsetWeightSF( - local_expert, gemm1_n, gemm1_k, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); - } else { - layout_info1.fpX_block_scaling_factors_A[expert] = nullptr; - layout_info1.fpX_block_scaling_factors_B[expert] = nullptr; - } - } - - if (quant_params.fp4.fc2.weight_block_scale) { - setupFP4BlockScalingFactors(layout_info2, expert, - gemm_m, gemm2_n, gemm2_k, fp4_act_flat2, quant_params.fp4.fc2.weight_block_scale, num_tokens_before_expert); - - // Override the scaling factors, fc2 scaling factor B offsets by the local expert index - if (is_active_expert) { - layout_info2.fpX_block_scaling_factors_B[expert] = quant_params.fp4.fc2.weight_block_scale + getOffsetWeightSF( - local_expert, gemm2_n, gemm2_k, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); - } else { - layout_info2.fpX_block_scaling_factors_A[expert] = nullptr; - layout_info2.fpX_block_scaling_factors_B[expert] = nullptr; - } - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); -#endif - - assert(gemm_m <= INT32_MAX); - assert(gemm1_n > 0 && gemm1_n <= INT32_MAX); - assert(gemm1_k > 0 && gemm1_k <= INT32_MAX); - assert(gemm2_n > 0 && gemm2_n <= INT32_MAX); - assert(gemm2_k > 0 && gemm2_k <= INT32_MAX); - computeTmaWarpSpecializedInputStrides(layout_info1, gemm_m, gemm1_n, gemm1_k, expert, - cutlass::gemm::collective::detail::int4_group_size); - computeTmaWarpSpecializedInputStrides(layout_info2, gemm_m, gemm2_n, gemm2_k, expert, - cutlass::gemm::collective::detail::int4_group_size); - - if (is_active_expert) { - // Note: under low latency mode, we use the same input for all experts - // so for gemm1, the inputs are the same, - // for gemm2, we use the input generated by gemm1 - layout_info1.ptr_a[expert] = in1; - layout_info2.ptr_a[expert] = safe_inc_ptr(in2, expert * num_tokens * gemm2_k); - - // Each expert's weight matrix is a constant size NxK, get the matrix at index `expert` - layout_info1.ptr_b[expert] = safe_inc_ptr(weights1, local_expert * (gemm1_n * gemm2_k)); - layout_info2.ptr_b[expert] = safe_inc_ptr(weights2, local_expert * (gemm1_n * gemm2_k)); - - assert(layout_info1.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE); - layout_info1.default_epilogue.ptr_d[expert] = safe_inc_ptr(output1, expert * num_tokens * gemm1_n); - - if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - // The output prior to this contains N elements per token, with `num_tokens` tokens - layout_info2.default_epilogue.ptr_d[expert] = safe_inc_ptr(output2, expert * num_tokens * gemm2_n); - } - } else { - layout_info1.ptr_a[expert] = nullptr; - layout_info2.ptr_a[expert] = nullptr; - layout_info1.ptr_b[expert] = nullptr; - layout_info2.ptr_b[expert] = nullptr; - - layout_info1.default_epilogue.ptr_d[expert] = nullptr; - if (layout_info2.fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE) { - layout_info2.default_epilogue.ptr_d[expert] = nullptr; - } - } -} - // ========================== Permutation things ======================================= // Duplicated and permutes rows for MoE. In addition, reverse the permutation map to help with finalizing routing. @@ -2886,68 +2571,6 @@ CutlassMoeFCRunner: return std::make_pair(layout_info1, layout_info2); } -template -std::pair -CutlassMoeFCRunner::computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int const num_experts, T const* input1, T const* input2, - WeightType const* weights1, WeightType const* weights2, float const* fp8_dequant1, float const* fp8_dequant2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, - ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* output1, - UnfusedGemmOutputType* output2, int const* num_active_experts_per, int const* active_expert_global_ids, - int start_expert, cudaStream_t stream) { - ORT_ENFORCE(!use_w4afp8, "W4AFP8 is not supported in low latency mode"); - - // Always nullptr - layout_info1.ptr_c = nullptr; - layout_info1.stride_c = nullptr; - layout_info2.ptr_c = nullptr; - layout_info2.stride_c = nullptr; - - auto alpha_scale_flat1 = use_fp4 ? quant_params.fp4.fc1.global_scale - : use_wfp4afp8 ? (quant_params.fp8_mxfp4.fc1.global_scale - ? quant_params.fp8_mxfp4.fc1.global_scale - : quant_params.mxfp8_mxfp4.fc1.global_scale) - : fp8_dequant1; - auto alpha_scale_flat2 = use_fp4 ? quant_params.fp4.fc2.global_scale - : use_wfp4afp8 ? (quant_params.fp8_mxfp4.fc2.global_scale - ? quant_params.fp8_mxfp4.fc2.global_scale - : quant_params.mxfp8_mxfp4.fc2.global_scale) - : fp8_dequant2; - if (!alpha_scale_flat1) { - layout_info1.alpha_scale_ptr_array = nullptr; - } - if (!alpha_scale_flat2) { - layout_info2.alpha_scale_ptr_array = nullptr; - } - - layout_info1.int4_groupwise_params.enabled = false; - layout_info2.int4_groupwise_params.enabled = false; - - int const threads = std::min(1024, num_experts); - int const blocks = (num_experts + threads - 1) / threads; - - cudaLaunchConfig_t config; - config.gridDim = blocks; - config.blockDim = threads; - config.dynamicSmemBytes = 0; - config.stream = stream; - cudaLaunchAttribute attrs[1]; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = onnxruntime::llm::common::getEnvEnablePDL(); - config.numAttrs = 1; - config.attrs = attrs; - cudaLaunchKernelEx(&config, - computeStridesTmaWarpSpecializedLowLatencyKernel, layout_info1, - layout_info2, num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts, input1, input2, weights1, weights2, - alpha_scale_flat1, alpha_scale_flat2, fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, bias1, bias2, output1, - output2, num_active_experts_per, active_expert_global_ids, start_expert); - - return std::make_pair(layout_info1, layout_info2); -} - template std::pair CutlassMoeFCRunner::setupTmaWarpSpecializedInputs( @@ -3083,78 +2706,6 @@ __global__ void populateRandomBufferKernel(void* buffer_void, size_t size) { buffer[tid * elem_per_thread + i] = curand4(&state); } -template -__global__ void prepareMinLatencyBuffer(int* num_active_experts_per_node, int* active_expert_global_ids, - int64_t* expert_first_token_offset, int const num_tokens, int const num_experts_per_token, - int const num_experts_per_node) { - int tid = threadIdx.x; - int bid = blockIdx.x; - - // 0. set offset - num_active_experts_per_node += bid; - active_expert_global_ids += bid * num_experts_per_node; - expert_first_token_offset += bid * (num_experts_per_node + 1); - - // 1. set the num_active_experts_per_node - int num_active = max(1, (int)(bid * ((float)num_experts_per_node / NUM_ROUTING_SAMPLES))); - *num_active_experts_per_node = num_active; - - // 2. generate random active experts - extern __shared__ float s_buf[]; - float* expert_refs = s_buf; - int* expert_refs_idx = reinterpret_cast(expert_refs + num_experts_per_node); - - curandState_t local_state; - curand_init(bid, tid, 0, &local_state); - for (int i = tid; i < num_experts_per_node; i += BLOCK_SIZE) { - expert_refs[i] = (float)curand_uniform(&local_state); - expert_refs_idx[i] = (int)i; - } - __syncthreads(); - - float thread_key[1]; - int thread_value[1]; - thread_key[0] = std::numeric_limits::max(); - thread_value[0] = num_experts_per_node; - - if (tid < num_experts_per_node) { - thread_key[0] = expert_refs[tid]; - thread_value[0] = expert_refs_idx[tid]; - } - - using BlockRadixSort = cub::BlockRadixSort; - using BlockRadixSortValue = cub::BlockRadixSort; - - union TempStorage { - typename BlockRadixSort::TempStorage key_value; - typename BlockRadixSortValue::TempStorage value; - }; - __shared__ union TempStorage temp_storage; - - BlockRadixSort(temp_storage.key_value).Sort(thread_key, thread_value); - __syncthreads(); - - if (tid > num_active) { - thread_value[0] = std::numeric_limits::max(); - } - BlockRadixSortValue(temp_storage.value).Sort(thread_value); - __syncthreads(); - - // 3. set the active_expert_global_ids and expert_first_token_offset - for (int i = tid; i < num_experts_per_node; i += BLOCK_SIZE) { - if (i < num_active) { - active_expert_global_ids[i] = thread_value[0]; - expert_first_token_offset[i] = i * num_tokens; - } else { - active_expert_global_ids[i] = -1; - expert_first_token_offset[i] = num_active * num_tokens; - } - } - if (tid == 0) { - expert_first_token_offset[num_experts_per_node] = num_active * num_tokens; - } -} - void populateRandomBuffer(void* buffer_void, size_t size, cudaStream_t stream) { // Each thread initialises 128 bytes ORT_ENFORCE(size % 128 == 0, "Unexpected size alignment"); @@ -3292,10 +2843,6 @@ std::map> GemmProfilerBackend::getProfile size_t const blocked_expert_counts_cumsum_size = blocked_expert_counts_size; size_t const blocked_row_to_unpermuted_row_size = num_experts_per_node * maxM * sizeof(int); - // The follow buffers are used in min_latency_mode - size_t num_active_experts_per_node_size = 0; - size_t active_expert_global_ids_size = 0; - size_t map_offset = 0; std::map> out_map; @@ -3316,8 +2863,6 @@ std::map> GemmProfilerBackend::getProfile ADD(blocked_expert_counts_cumsum); ADD(blocked_row_to_unpermuted_row); ADD(token_topk_unpermuted_scales); - ADD(num_active_experts_per_node); - ADD(active_expert_global_ids); ADD(input); ADD(output); ADD(intermediate); @@ -3358,8 +2903,6 @@ void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_cha GET_WS_PTR(int*, blocked_expert_counts); GET_WS_PTR(int*, blocked_expert_counts_cumsum); GET_WS_PTR(int*, blocked_row_to_unpermuted_row); - GET_WS_PTR(int*, num_active_experts_per_node); - GET_WS_PTR(int*, active_expert_global_ids); #undef GET_WS_PTR_BASE #undef GET_WS_PTR @@ -3473,8 +3016,6 @@ void GemmProfilerBackend::prepareTmaWsInputs( GET_WS_PTR(void*, gemm_workspace); GET_WS_PTR(float*, alpha_scale_ptr_array); GET_WS_PTR(TmaWarpSpecializedGroupedGemmInput::ElementSF*, fp4_act_scale_flat); - GET_WS_PTR(int*, num_active_experts_per_node); - GET_WS_PTR(int*, active_expert_global_ids); #undef GET_WS_PTR @@ -3573,8 +3114,6 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac GET_WS_PTR(float const*, token_topk_unpermuted_scales); auto const* token_topk_permuted_scales = token_topk_unpermuted_scales; - GET_WS_PTR_OFFSET(int*, num_active_experts_per_node, mSampleIndex); - GET_WS_PTR_OFFSET(int*, active_expert_global_ids, (mSampleIndex * mNumExpertsPerNode)); GET_WS_PTR(void const*, input); GET_WS_PTR(void*, output); GET_WS_PTR(void*, intermediate); diff --git a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.h b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.h index 9f780c313e1fd..bf9ab2e684c5a 100644 --- a/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.h +++ b/onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_kernels.h @@ -305,16 +305,6 @@ class CutlassMoeFCRunnerInterface { TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) = 0; - virtual std::pair - computeStridesTmaWarpSpecializedLowLatencyDispatch(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int const num_experts, void const* input1, void const* input2, - void const* weights1, void const* weights2, float const* fp8_dequant1, float const* fp8_dequant2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, - void const* bias1, void const* bias2, void* output1, void* output2, int const* num_active_experts_per, - int const* active_expert_global_ids, int start_expert, cudaStream_t stream) = 0; - virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0; bool is_profiler = false; @@ -521,25 +511,6 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { reinterpret_cast(gemm2_output), stream); } - std::pair - computeStridesTmaWarpSpecializedLowLatencyDispatch(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int const num_experts, void const* input1, void const* input2, - void const* weights1, void const* weights2, float const* fp8_dequant1, float const* fp8_dequant2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, - void const* bias1, void const* bias2, void* output1, void* output2, int const* num_active_experts_per, - int const* active_expert_global_ids, int start_expert, cudaStream_t stream) override { - return Self::computeStridesTmaWarpSpecializedLowLatency(layout_info1, layout_info2, num_tokens, gemm1_n, - gemm1_k, gemm2_n, gemm2_k, num_experts, reinterpret_cast(input1), - reinterpret_cast(input2), reinterpret_cast(weights1), - reinterpret_cast(weights2), fp8_dequant1, fp8_dequant2, fc1_fp4_act_flat, - fc2_fp4_act_flat, quant_params, reinterpret_cast(bias1), - reinterpret_cast(bias2), reinterpret_cast(output1), - reinterpret_cast(output2), num_active_experts_per, active_expert_global_ids, - start_expert, stream); - } - private: std::pair setupTmaWarpSpecializedInputs( int64_t num_rows, int64_t expanded_num_rows, ActivationType fc1_activation_type, bool use_fused_gated_activation, @@ -560,16 +531,6 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, UnfusedGemmOutputType* gemm2_output, cudaStream_t stream); - static std::pair - computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1, - TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k, - int64_t gemm2_n, int64_t gemm2_k, int const num_experts, T const* input1, T const* input2, - WeightType const* weights1, WeightType const* weights2, float const* fp8_dequant1, float const* fp8_dequant2, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, - ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* output1, - UnfusedGemmOutputType* output2, int const* num_active_experts_per, int const* active_expert_global_ids, - int start_expert, cudaStream_t stream); std::map> getWorkspaceDeviceBufferSizes(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int const experts_per_token, ActivationType activation_type,