diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py
index 12d2c3991cd..1352112828b 100755
--- a/python/sglang/srt/layers/moe/cutlass_moe.py
+++ b/python/sglang/srt/layers/moe/cutlass_moe.py
@@ -1,17 +1,18 @@
 """CUTLASS based Fused MoE kernels."""
 
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, is_sm90_supported
 
 _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import (
         apply_shuffle_mul_sum,
         cutlass_fp4_group_mm,
+        es_fp8_blockwise_scaled_grouped_mm,
         fp8_blockwise_scaled_grouped_mm,
         prepare_moe_input,
         scaled_fp4_experts_quant,
@@ -43,6 +44,7 @@ def cutlass_fused_experts_fp8(
     problem_sizes2: torch.Tensor,
     use_fp8_blockscale: bool = True,
     output: Optional[torch.Tensor] = None,
+    enable_es: Tuple[bool, bool] = (False, False),
 ) -> torch.Tensor:
     """Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.
 
@@ -98,6 +100,7 @@ def cutlass_fused_experts_fp8(
         use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with block scaling.
            Currently, only `True` is supported. Defaults to `True`.
        output (torch.Tensor, optional): Output tensor. If not provided, a new tensor will be created.
+        enable_es (tuple(bool, bool), optional): Flags enabling the expert-specialization kernel for the (up-projection, down-projection) GEMMs. Defaults to `(False, False)`.
 
     Returns:
         torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`.
@@ -121,7 +124,7 @@ def cutlass_fused_experts_fp8(
     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8,
     )
-
+    es_up, es_down = enable_es
     out_dtype = a.dtype
     num_experts = w1_q.size(0)
     m = a.size(0)
@@ -156,52 +159,82 @@ def cutlass_fused_experts_fp8(
     a_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
     w_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
 
-    fp8_blockwise_scaled_grouped_mm(
-        c1,
-        a_ptrs,
-        b_ptrs,
-        out_ptrs,
-        a_scales_ptrs,
-        b_scales_ptrs,
-        rep_a_q,
-        w1_q,
-        rep_a1_scales,
-        w1_scale,
-        a1_strides,
-        a1_strides,
-        c1_strides,
-        a_sf_layout,
-        w_sf_layout,
-        problem_sizes1,
-        expert_offsets[:-1],
-        workspace,
-    )
+    if is_sm90_supported() and es_up:
+        es_fp8_blockwise_scaled_grouped_mm(
+            c1,
+            rep_a_q,
+            w1_q,
+            rep_a1_scales,
+            w1_scale,
+            a1_strides,
+            a1_strides,
+            c1_strides,
+            problem_sizes1,
+            expert_offsets[:-1],
+            workspace,
+        )
+    else:
+        fp8_blockwise_scaled_grouped_mm(
+            c1,
+            a_ptrs,
+            b_ptrs,
+            out_ptrs,
+            a_scales_ptrs,
+            b_scales_ptrs,
+            rep_a_q,
+            w1_q,
+            rep_a1_scales,
+            w1_scale,
+            a1_strides,
+            a1_strides,
+            c1_strides,
+            a_sf_layout,
+            w_sf_layout,
+            problem_sizes1,
+            expert_offsets[:-1],
+            workspace,
+        )
 
     intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
     silu_and_mul(c1, intermediate)
 
     intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
 
-    fp8_blockwise_scaled_grouped_mm(
-        c2,
-        a_ptrs,
-        b_ptrs,
-        out_ptrs,
-        a_scales_ptrs,
-        b_scales_ptrs,
-        intemediate_q,
-        w2_q,
-        a2_scale,
-        w2_scale,
-        a2_strides,
-        a2_strides,
-        c2_strides,
-        a_sf_layout,
-        w_sf_layout,
-        problem_sizes2,
-        expert_offsets[:-1],
-        workspace,
-    )
+    if is_sm90_supported() and es_down:
+        es_fp8_blockwise_scaled_grouped_mm(
+            c2,
+            intemediate_q,
+            w2_q,
+            a2_scale,
+            w2_scale,
+            a2_strides,
+            a2_strides,
+            c2_strides,
+            problem_sizes2,
+            expert_offsets[:-1],
+            workspace,
+        )
+    else:
+        fp8_blockwise_scaled_grouped_mm(
+            c2,
+            a_ptrs,
+            b_ptrs,
+            out_ptrs,
+            a_scales_ptrs,
+            b_scales_ptrs,
+            intemediate_q,
+            w2_q,
+            a2_scale,
+            w2_scale,
+            a2_strides,
+            a2_strides,
+            c2_strides,
+            a_sf_layout,
+            w_sf_layout,
+            problem_sizes2,
+            expert_offsets[:-1],
+            workspace,
+        )
 
     if output is None:
         output = torch.empty((m, k), device=device, dtype=out_dtype)
diff --git a/python/sglang/test/test_cutlass_moe.py b/python/sglang/test/test_cutlass_moe.py
index fdab5a3acb0..4e4eee376f6 100755
--- a/python/sglang/test/test_cutlass_moe.py
+++ b/python/sglang/test/test_cutlass_moe.py
@@ -128,6 +128,12 @@ def run_test(tp_size, batch_size, model_config, check=False):
     problem_sizes1 = torch.empty((E, 3), dtype=torch.int32, device="cuda")
     problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda")
 
+    enable_es = (False, False)
+    if torch.cuda.get_device_name(torch.cuda.current_device()) == "NVIDIA H200":
+        enable_es = (False, True)
+    elif torch.cuda.get_device_name(torch.cuda.current_device()) == "NVIDIA H20":
+        enable_es = (True, True)
+
     # --- Lambdas for Benchmarking ---
     cutlass_lambda = lambda: cutlass_fused_experts_fp8(
         x,
@@ -150,6 +156,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
         expert_offsets,
         problem_sizes1,
         problem_sizes2,
+        enable_es=enable_es,
     )
 
     topk_output = StandardTopKOutput(
@@ -234,6 +241,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
         expert_offsets,
         problem_sizes1,
         problem_sizes2,
+        enable_es=enable_es,
     )
 
     # Run Triton version (requires original shape weights, use inplace=False)
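
Usage note: the `enable_es` flags are only honored when `is_sm90_supported()` returns True; on other architectures both GEMMs silently fall back to the stock `fp8_blockwise_scaled_grouped_mm` path. Below is a minimal sketch (not part of the patch) of how a caller might pick the tuple, mirroring the device-name heuristic used in the test; the helper name `select_enable_es` is hypothetical.

import torch

def select_enable_es():
    # Returns (up-projection, down-projection) expert-specialization flags,
    # following the per-device heuristic from test_cutlass_moe.py above.
    name = torch.cuda.get_device_name(torch.cuda.current_device())
    if name == "NVIDIA H20":
        return (True, True)   # ES for both grouped GEMMs on H20
    if name == "NVIDIA H200":
        return (False, True)  # ES only for the down-projection on H200
    return (False, False)     # elsewhere, keep the stock grouped GEMM

The result is then forwarded as `cutlass_fused_experts_fp8(..., enable_es=select_enable_es())`.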