Merged
Changes from 6 commits
119 changes: 76 additions & 43 deletions python/sglang/srt/layers/moe/cutlass_moe.py
@@ -1,17 +1,18 @@
"""CUTLASS based Fused MoE kernels."""

from typing import Optional
from typing import Optional, Tuple

import torch

from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
from sglang.srt.utils import is_cuda
from sglang.srt.utils import is_cuda, is_sm90_supported

_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import (
apply_shuffle_mul_sum,
cutlass_fp4_group_mm,
es_fp8_blockwise_scaled_grouped_mm,
fp8_blockwise_scaled_grouped_mm,
prepare_moe_input,
scaled_fp4_experts_quant,
@@ -43,6 +44,7 @@ def cutlass_fused_experts_fp8(
problem_sizes2: torch.Tensor,
use_fp8_blockscale: bool = True,
output: Optional[torch.Tensor] = None,
enable_es: Tuple[bool, bool] = (False, False),
) -> torch.Tensor:
"""Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.

@@ -98,6 +100,7 @@ def cutlass_fused_experts_fp8(
use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with
block scaling. Currently, only `True` is supported. Defaults to `True`.
output (torch.Tensor, optional): Output tensor. If not provided, a new tensor will be created.
enable_es (Tuple[bool, bool], optional): Flags indicating whether to use the expert specialization kernel for the (up-projection, down-projection) GEMMs. Defaults to (False, False).
Returns:
torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`.

@@ -121,7 +124,7 @@ def cutlass_fused_experts_fp8(
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)

es_up, es_down = enable_es

Collaborator:
How can a user pass enable_es to cutlass_moe? Do we need an environment variable?

Collaborator (Author):
There are still two optimizations that haven't been implemented yet:

  1. Dynamically selecting a suitable kernel based on arithmetic intensity combined with the MNK problem shape, to adapt to more scenarios.
  2. PDL support: Support PDL for SM90 Array TMA GEMM NVIDIA/cutlass#2719

Once these tasks are complete, I will add an environment variable to enable ES.
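
As a rough illustration of how such an environment variable could eventually map onto the enable_es tuple (the variable name SGLANG_CUTLASS_MOE_ES and the parsing below are hypothetical, not part of this PR):

import os
from typing import Tuple

def _es_flags_from_env() -> Tuple[bool, bool]:
    # Hypothetical helper: parse e.g. SGLANG_CUTLASS_MOE_ES="0,1" into
    # (es_up, es_down). Defaults to (False, False) when the variable is unset.
    raw = os.environ.get("SGLANG_CUTLASS_MOE_ES", "")
    if not raw:
        return (False, False)
    flags = [p.strip().lower() in ("1", "true", "yes") for p in raw.split(",")]
    flags += [False] * (2 - len(flags))  # pad so a single value only affects the up-projection
    return (flags[0], flags[1])

For now, the PR passes enable_es explicitly; the test file below selects it per device name.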

out_dtype = a.dtype
num_experts = w1_q.size(0)
m = a.size(0)
@@ -156,52 +159,82 @@ def cutlass_fused_experts_fp8(
a_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
w_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)

fp8_blockwise_scaled_grouped_mm(
c1,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
rep_a_q,
w1_q,
rep_a1_scales,
w1_scale,
a1_strides,
a1_strides,
c1_strides,
a_sf_layout,
w_sf_layout,
problem_sizes1,
expert_offsets[:-1],
workspace,
)
if is_sm90_supported() and es_up:
es_fp8_blockwise_scaled_grouped_mm(
c1,
rep_a_q,
w1_q,
rep_a1_scales,
w1_scale,
a1_strides,
a1_strides,
c1_strides,
problem_sizes1,
expert_offsets[:-1],
workspace,
)
else:
fp8_blockwise_scaled_grouped_mm(
c1,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
rep_a_q,
w1_q,
rep_a1_scales,
w1_scale,
a1_strides,
a1_strides,
c1_strides,
a_sf_layout,
w_sf_layout,
problem_sizes1,
expert_offsets[:-1],
workspace,
)

intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
silu_and_mul(c1, intermediate)

intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)

fp8_blockwise_scaled_grouped_mm(
c2,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
intemediate_q,
w2_q,
a2_scale,
w2_scale,
a2_strides,
a2_strides,
c2_strides,
a_sf_layout,
w_sf_layout,
problem_sizes2,
expert_offsets[:-1],
workspace,
)
if is_sm90_supported() and es_down:
es_fp8_blockwise_scaled_grouped_mm(
c2,
intemediate_q,
w2_q,
a2_scale,
w2_scale,
a2_strides,
a2_strides,
c2_strides,
problem_sizes2,
expert_offsets[:-1],
workspace,
)
else:
fp8_blockwise_scaled_grouped_mm(
c2,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
intemediate_q,
w2_q,
a2_scale,
w2_scale,
a2_strides,
a2_strides,
c2_strides,
a_sf_layout,
w_sf_layout,
problem_sizes2,
expert_offsets[:-1],
workspace,
)
Comment on lines +162 to +237
Contributor:

Severity: medium

There is significant code duplication between the GEMM calls for the up-projection (lines 162-196) and the down-projection (lines 203-237). Both blocks contain similar if/else logic to switch between the standard and expert specialization kernels. This redundancy makes the code harder to read and maintain.

To improve this, you could refactor the logic into a helper function. This function would encapsulate the conditional dispatch to either es_fp8_blockwise_scaled_grouped_mm or fp8_blockwise_scaled_grouped_mm.

Here is an example of how you could structure such a helper:

def _dispatch_gemm(is_es_enabled, out, a, b, a_scales, b_scales, a_strides, c_strides, problem_sizes, expert_offsets, workspace, **kwargs):
    if is_sm90_supported() and is_es_enabled:
        es_fp8_blockwise_scaled_grouped_mm(
            out,
            a,
            b,
            a_scales,
            b_scales,
            a_strides,
            a_strides,  # b_strides for GEMM
            c_strides,
            problem_sizes,
            expert_offsets,
            workspace,
        )
    else:
        fp8_blockwise_scaled_grouped_mm(
            out,
            kwargs["a_ptrs"],
            kwargs["b_ptrs"],
            kwargs["out_ptrs"],
            kwargs["a_scales_ptrs"],
            kwargs["b_scales_ptrs"],
            a,
            b,
            a_scales,
            b_scales,
            a_strides,
            a_strides,  # b_strides for GEMM
            c_strides,
            kwargs["a_sf_layout"],
            kwargs["w_sf_layout"],
            problem_sizes,
            expert_offsets,
            workspace,
        )

This would make the main function body much cleaner and avoid repeating the same logic.
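
If the helper were adopted, the two dispatch sites above could collapse to something like this sketch (illustrative only; names follow the diff, and _dispatch_gemm is the reviewer's proposed helper, not an existing API):

# Up-projection GEMM
_dispatch_gemm(
    es_up, c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
    a1_strides, c1_strides, problem_sizes1, expert_offsets[:-1], workspace,
    a_ptrs=a_ptrs, b_ptrs=b_ptrs, out_ptrs=out_ptrs,
    a_scales_ptrs=a_scales_ptrs, b_scales_ptrs=b_scales_ptrs,
    a_sf_layout=a_sf_layout, w_sf_layout=w_sf_layout,
)

# Down-projection GEMM (after silu_and_mul and re-quantization)
_dispatch_gemm(
    es_down, c2, intemediate_q, w2_q, a2_scale, w2_scale,
    a2_strides, c2_strides, problem_sizes2, expert_offsets[:-1], workspace,
    a_ptrs=a_ptrs, b_ptrs=b_ptrs, out_ptrs=out_ptrs,
    a_scales_ptrs=a_scales_ptrs, b_scales_ptrs=b_scales_ptrs,
    a_sf_layout=a_sf_layout, w_sf_layout=w_sf_layout,
)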


if output is None:
output = torch.empty((m, k), device=device, dtype=out_dtype)
8 changes: 8 additions & 0 deletions python/sglang/test/test_cutlass_moe.py
@@ -128,6 +128,12 @@ def run_test(tp_size, batch_size, model_config, check=False):
problem_sizes1 = torch.empty((E, 3), dtype=torch.int32, device="cuda")
problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda")

enable_es = (False, False)
if torch.cuda.get_device_name(torch.cuda.current_device()) == "NVIDIA H200":
enable_es = (False, True)
elif torch.cuda.get_device_name(torch.cuda.current_device()) == "NVIDIA H20":
enable_es = (True, True)

# --- Lambdas for Benchmarking ---
cutlass_lambda = lambda: cutlass_fused_experts_fp8(
x,
@@ -150,6 +156,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
expert_offsets,
problem_sizes1,
problem_sizes2,
enable_es=enable_es,
)

topk_output = StandardTopKOutput(
@@ -234,6 +241,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
expert_offsets,
problem_sizes1,
problem_sizes2,
enable_es=enable_es,
)

# Run Triton version (requires original shape weights, use inplace=False)