[sgl-kernel][5/N] Support Expert Specialization Grouped GEMM #12666
@@ -1,17 +1,18 @@
 """CUTLASS based Fused MoE kernels."""
 
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_sm90_supported
 
 _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import (
         apply_shuffle_mul_sum,
         cutlass_fp4_group_mm,
+        es_fp8_blockwise_scaled_grouped_mm,
         fp8_blockwise_scaled_grouped_mm,
         prepare_moe_input,
         scaled_fp4_experts_quant,
@@ -43,6 +44,7 @@ def cutlass_fused_experts_fp8(
     problem_sizes2: torch.Tensor,
     use_fp8_blockscale: bool = True,
     output: Optional[torch.Tensor] = None,
+    enable_es: Tuple[bool, bool] = (False, False),
 ) -> torch.Tensor:
     """Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.
 
@@ -98,6 +100,7 @@ def cutlass_fused_experts_fp8(
         use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with
             block scaling. Currently, only `True` is supported. Defaults to `True`.
         output (torch.Tensor, optional): Output tensor. If not provided, a new tensor will be created.
+        enable_es (tuple(bool, bool)): Flag indicating usage of expert specialization kernel for (up-projection, down-projection)
     Returns:
         torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`.
 
@@ -121,7 +124,7 @@ def cutlass_fused_experts_fp8(
     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8,
     )
 
+    es_up, es_down = enable_es
     out_dtype = a.dtype
     num_experts = w1_q.size(0)
     m = a.size(0)
@@ -156,52 +159,82 @@ def cutlass_fused_experts_fp8(
     a_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
     w_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
 
-    fp8_blockwise_scaled_grouped_mm(
-        c1,
-        a_ptrs,
-        b_ptrs,
-        out_ptrs,
-        a_scales_ptrs,
-        b_scales_ptrs,
-        rep_a_q,
-        w1_q,
-        rep_a1_scales,
-        w1_scale,
-        a1_strides,
-        a1_strides,
-        c1_strides,
-        a_sf_layout,
-        w_sf_layout,
-        problem_sizes1,
-        expert_offsets[:-1],
-        workspace,
-    )
+    if is_sm90_supported() and es_up:
+        es_fp8_blockwise_scaled_grouped_mm(
+            c1,
+            rep_a_q,
+            w1_q,
+            rep_a1_scales,
+            w1_scale,
+            a1_strides,
+            a1_strides,
+            c1_strides,
+            problem_sizes1,
+            expert_offsets[:-1],
+            workspace,
+        )
+    else:
+        fp8_blockwise_scaled_grouped_mm(
+            c1,
+            a_ptrs,
+            b_ptrs,
+            out_ptrs,
+            a_scales_ptrs,
+            b_scales_ptrs,
+            rep_a_q,
+            w1_q,
+            rep_a1_scales,
+            w1_scale,
+            a1_strides,
+            a1_strides,
+            c1_strides,
+            a_sf_layout,
+            w_sf_layout,
+            problem_sizes1,
+            expert_offsets[:-1],
+            workspace,
+        )
 
     intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
     silu_and_mul(c1, intermediate)
 
     intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
 
-    fp8_blockwise_scaled_grouped_mm(
-        c2,
-        a_ptrs,
-        b_ptrs,
-        out_ptrs,
-        a_scales_ptrs,
-        b_scales_ptrs,
-        intemediate_q,
-        w2_q,
-        a2_scale,
-        w2_scale,
-        a2_strides,
-        a2_strides,
-        c2_strides,
-        a_sf_layout,
-        w_sf_layout,
-        problem_sizes2,
-        expert_offsets[:-1],
-        workspace,
-    )
+    if is_sm90_supported() and es_down:
+        es_fp8_blockwise_scaled_grouped_mm(
+            c2,
+            intemediate_q,
+            w2_q,
+            a2_scale,
+            w2_scale,
+            a2_strides,
+            a2_strides,
+            c2_strides,
+            problem_sizes2,
+            expert_offsets[:-1],
+            workspace,
+        )
+    else:
+        fp8_blockwise_scaled_grouped_mm(
+            c2,
+            a_ptrs,
+            b_ptrs,
+            out_ptrs,
+            a_scales_ptrs,
+            b_scales_ptrs,
+            intemediate_q,
+            w2_q,
+            a2_scale,
+            w2_scale,
+            a2_strides,
+            a2_strides,
+            c2_strides,
+            a_sf_layout,
+            w_sf_layout,
+            problem_sizes2,
+            expert_offsets[:-1],
+            workspace,
+        )
 
Comment on lines +162 to +237

Contributor

There is significant code duplication between the GEMM calls for the up-projection (lines 162-196) and the down-projection (lines 203-237). Both blocks contain a nearly identical conditional dispatch between the two kernels. To improve this, you could refactor the logic into a helper function that encapsulates the conditional dispatch to either es_fp8_blockwise_scaled_grouped_mm or fp8_blockwise_scaled_grouped_mm. Here is an example of how you could structure such a helper:

def _dispatch_gemm(is_es_enabled, out, a, b, a_scales, b_scales, a_strides, c_strides, problem_sizes, expert_offsets, workspace, **kwargs):
    if is_sm90_supported() and is_es_enabled:
        es_fp8_blockwise_scaled_grouped_mm(
            out,
            a,
            b,
            a_scales,
            b_scales,
            a_strides,
            a_strides,  # b_strides for GEMM
            c_strides,
            problem_sizes,
            expert_offsets,
            workspace,
        )
    else:
        fp8_blockwise_scaled_grouped_mm(
            out,
            kwargs["a_ptrs"],
            kwargs["b_ptrs"],
            kwargs["out_ptrs"],
            kwargs["a_scales_ptrs"],
            kwargs["b_scales_ptrs"],
            a,
            b,
            a_scales,
            b_scales,
            a_strides,
            a_strides,  # b_strides for GEMM
            c_strides,
            kwargs["a_sf_layout"],
            kwargs["w_sf_layout"],
            problem_sizes,
            expert_offsets,
            workspace,
        )

This would make the main function body much cleaner and avoid repeating the same logic.
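
If the helper above were adopted, the two dispatch blocks in this diff could collapse to something like the sketch below. It is not standalone code: it reuses the local variables already present in cutlass_fused_experts_fp8 in this diff (c1, rep_a_q, w1_q, and so on) and only illustrates the intended call shape:

# Sketch only: assumes the surrounding locals of cutlass_fused_experts_fp8 from this diff.
ptr_kwargs = dict(
    a_ptrs=a_ptrs,
    b_ptrs=b_ptrs,
    out_ptrs=out_ptrs,
    a_scales_ptrs=a_scales_ptrs,
    b_scales_ptrs=b_scales_ptrs,
    a_sf_layout=a_sf_layout,
    w_sf_layout=w_sf_layout,
)

# Up-projection GEMM (replaces the first if/else block).
_dispatch_gemm(
    es_up, c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
    a1_strides, c1_strides, problem_sizes1, expert_offsets[:-1], workspace,
    **ptr_kwargs,
)

# Down-projection GEMM (replaces the second if/else block).
_dispatch_gemm(
    es_down, c2, intemediate_q, w2_q, a2_scale, w2_scale,
    a2_strides, c2_strides, problem_sizes2, expert_offsets[:-1], workspace,
    **ptr_kwargs,
)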
     if output is None:
         output = torch.empty((m, k), device=device, dtype=out_dtype)
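
To make the semantics of enable_es concrete: each projection independently falls back to the regular blockwise grouped GEMM unless its flag is set and SM90 is supported. A small, self-contained illustration of that selection (the function name below is illustrative, not part of the PR):

def select_grouped_mm_kernels(enable_es, sm90_supported):
    # Mirrors the gating added in this diff: the ES kernel is used only on SM90
    # and only for the projections whose flag in enable_es is set.
    es_up, es_down = enable_es
    up = "es_fp8_blockwise_scaled_grouped_mm" if (sm90_supported and es_up) else "fp8_blockwise_scaled_grouped_mm"
    down = "es_fp8_blockwise_scaled_grouped_mm" if (sm90_supported and es_down) else "fp8_blockwise_scaled_grouped_mm"
    return up, down

# ES enabled only for the up-projection on an SM90 GPU:
print(select_grouped_mm_kernels((True, False), sm90_supported=True))
# -> ('es_fp8_blockwise_scaled_grouped_mm', 'fp8_blockwise_scaled_grouped_mm')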
How can a user pass enable_es to cutlass_moe? Do we need an environment variable?
There are still two optimizations that haven't been implemented yet. Once those tasks are complete, I will add an environment variable to enable ES.
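
As a rough sketch of how such a switch could be wired up once added (the variable name SGLANG_CUTLASS_MOE_ENABLE_ES and the "up,down" format below are hypothetical; this PR does not define them):

import os

def enable_es_from_env():
    # Hypothetical flag: "1,0" would enable the ES kernel for the up-projection only.
    raw = os.environ.get("SGLANG_CUTLASS_MOE_ENABLE_ES", "0,0").split(",")
    up = len(raw) > 0 and raw[0].strip() == "1"
    down = len(raw) > 1 and raw[1].strip() == "1"
    return (up, down)

# The resulting tuple would then be forwarded by the calling MoE layer as the
# enable_es argument of cutlass_fused_experts_fp8.
enable_es = enable_es_from_env()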