Commit 24cbddd
[mxfp8 moe training] add CUDA kernel for per-group conversion of scale factors to blocked layout
stack-info: PR: #3504, branch: danielvegamyhre/stack/86
1 parent 1f9bfd7 commit 24cbddd
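
For orientation, below is a minimal usage sketch of the kernel entry point this commit adds, mirroring how the benchmark and test in this diff call it. The synthetic inputs, shapes, and helper usage here are illustrative assumptions based on the code in this commit, not canonical API documentation.

# Minimal usage sketch (assumption: call signature and dtypes mirror the benchmark/test below).
import torch
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    mx_block_rearrange_2d_K_groups_cuda,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs

device = torch.device("cuda")
block_size = 32
M, total_K, n_groups = 5120, 16384, 8

# E8M0 scales: one scale per 32-element block along K, synthesized here as raw bytes.
scales = torch.randint(
    0, 256, (M, total_K // block_size), dtype=torch.uint8, device=device
).view(torch.float8_e8m0fnu)

# Jagged per-group end offsets along K, converted to scale-column offsets.
k_group_offsets = generate_jagged_offs(n_groups, total_K, multiple_of=block_size, device=device)
scale_group_offsets = k_group_offsets // block_size

# Per-group rearrangement of the scales into the blocked (swizzled) layout.
blocked_scales = mx_block_rearrange_2d_K_groups_cuda(scales, scale_group_offsets)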

File tree

7 files changed: +846 −0 lines changed
Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    mx_block_rearrange_2d_K_groups_cuda,
    torch_to_blocked_2d_K_groups,
    triton_mx_block_rearrange_2d_K_groups,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs

device = torch.device("cuda")

# Needed since changing args to the function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    input_shape: tuple[int, int]
    num_groups: int
    version: str  # "torch", "triton", or "cuda_pipelined_64_chunks4"


@dataclass(frozen=True)
class ExperimentResult:
    time_us: float
    mem_bw_gbps: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    # Llama4 and DSV3 671b shapes. Input activations are scaled along the total_M dim,
    # which contains all the token groups.
    block_size = 32
    input_shapes = [
        (8192, 16384 // block_size),
        (8192, 32768 // block_size),
        (8192, 65536 // block_size),
        (8192, 131072 // block_size),
        (8192, 1048576 // block_size),
        (5120, 16384 // block_size),
        (5120, 32768 // block_size),
        (5120, 65536 // block_size),
        (5120, 131072 // block_size),
        (5120, 1048576 // block_size),
        (7168, 16384 // block_size),
        (7168, 32768 // block_size),
        (7168, 65536 // block_size),
        (7168, 131072 // block_size),
        (7168, 1048576 // block_size),
        (2048, 16384 // block_size),
        (2048, 32768 // block_size),
        (2048, 65536 // block_size),
        (2048, 131072 // block_size),
        (2048, 1048576 // block_size),
    ]
    num_groups = [8]
    versions = [
        "torch",
        "triton",
        "cuda_pipelined_64_chunks4",
    ]

    configs = []
    for shape, groups, version in itertools.product(
        input_shapes,
        num_groups,
        versions,
    ):
        configs.append(
            ExperimentConfig(
                input_shape=shape,
                num_groups=groups,
                version=version,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    input_shape, num_groups, version = (
        config.input_shape,
        config.num_groups,
        config.version,
    )
    input_tensor = torch.randint(
        low=0,
        high=256,
        size=input_shape,
        dtype=torch.uint8,
        device=device,
    )

    M, Kg = input_shape
    block_size = 32
    input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size)

    # Select which kernel to benchmark based on version
    if version == "torch":
        kernel_fn = torch_to_blocked_2d_K_groups
        kernel_input = input_tensor
    elif version == "triton":
        kernel_fn = triton_mx_block_rearrange_2d_K_groups
        # Triton uses row-major input
        kernel_input = input_tensor
    elif version == "cuda_pipelined_64_chunks4":
        kernel_fn = mx_block_rearrange_2d_K_groups_cuda
        kernel_input = input_tensor.view(torch.float8_e8m0fnu)
    else:
        raise ValueError(f"Unknown version: {version}")

    # Run kernel once to get output shape
    outputs = kernel_fn(
        kernel_input,
        input_group_offsets,
    )
    if isinstance(outputs, tuple):  # torch returns a tuple with extra metadata
        out_scales, _ = outputs
    else:
        out_scales = outputs

    # Benchmark the kernel
    time_us = benchmark_cuda_function_in_microseconds(
        kernel_fn,
        kernel_input,
        input_group_offsets,
    )

    # Calculate memory bandwidth
    bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8

    read_bytes = input_tensor.numel() * bytes_per_input_el
    write_bytes = out_scales.numel() * bytes_per_output_el

    mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6)

    return ExperimentResult(
        time_us=time_us,
        mem_bw_gbps=mem_bw_gbps,
    )


def print_results(experiments: List[Experiment]):
    # Group experiments by input shape
    shapes_dict = {}
    for exp in experiments:
        shape_key = exp.config.input_shape
        if shape_key not in shapes_dict:
            shapes_dict[shape_key] = {}
        shapes_dict[shape_key][exp.config.version] = exp.result

    headers = [
        "kernel_version",
        "scale_shape",
        "time_us",
        "mem_bw_gbps",
        "fastest_version",
        "speedup_vs_torch",
    ]

    rows = []
    for shape, versions in shapes_dict.items():
        # Find fastest version for this shape
        fastest_version = min(versions.items(), key=lambda x: x[1].time_us)[0]

        # Get torch baseline time for speedup calculation
        torch_time_us = versions["torch"].time_us if "torch" in versions else None

        # Add rows for each version
        for version, result in versions.items():
            # Calculate speedup vs the torch baseline
            speedup_str = ""
            if version != "torch" and torch_time_us is not None:
                speedup = torch_time_us / result.time_us
                speedup_str = f"{speedup:.2f}x"

            rows.append(
                [
                    version,
                    f"({shape[0]}, {shape[1]})",
                    f"{result.time_us:.2f}",
                    round(result.mem_bw_gbps, 3),
                    fastest_version,
                    speedup_str,
                ]
            )

    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
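
To make the bandwidth model above concrete, here is a hand-worked instance of the mem_bw_gbps formula for one benchmark shape; the kernel time is a made-up placeholder and per-group output padding is ignored, so treat this only as an illustration of the arithmetic.

# Worked example for input_shape = (8192, 32768 // 32) = (8192, 1024).
# E8M0 input scales and blocked output scales are both 1 byte per element;
# the small per-group padding in the output is ignored here (a simplification).
read_bytes = 8192 * 1024          # ~8.39e6 bytes read
write_bytes = 8192 * 1024         # ~8.39e6 bytes written
time_us = 20.0                    # hypothetical kernel time in microseconds
mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6)
print(f"{mem_bw_gbps:.1f} GB/s")  # ~838.9 GB/s for this made-up timing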

dsv3_roofline.png

395 KB

setup.py

Lines changed: 1 addition & 0 deletions
@@ -709,6 +709,7 @@ def get_extensions():
         mxfp8_sources = [
             os.path.join(mxfp8_extension_dir, "mxfp8_extension.cpp"),
             os.path.join(mxfp8_extension_dir, "mxfp8_cuda.cu"),
+            os.path.join(mxfp8_extension_dir, "mx_block_rearrange_2d_K_groups.cu"),
         ]

         # Only add the extension if the source files exist AND we are building for sm100
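
The context line above notes that the extension is only built when its source files exist and the build targets sm100. As a hypothetical sketch of that kind of guard (the actual setup.py logic in torchao may differ, and the function name here is invented for illustration):

# Hypothetical sketch of the build guard described above; not the actual setup.py code.
import os

def should_build_mxfp8(mxfp8_sources: list[str], cuda_arch_list: list[str]) -> bool:
    # Require every listed .cpp/.cu source to be present on disk.
    sources_exist = all(os.path.exists(p) for p in mxfp8_sources)
    # Require compute capability 10.0 (sm100) among the requested build architectures.
    targets_sm100 = any(arch.startswith("10.0") for arch in cuda_arch_list)
    return sources_exist and targets_sm100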

test/prototype/moe_training/test_kernels.py

Lines changed: 64 additions & 0 deletions
@@ -352,3 +352,67 @@ def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
    # Check quantized values
    torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
    assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"


@pytest.mark.skipif(
    not is_sm_at_least_100(),
    reason="MXFP8 requires CUDA capability 10.0 or greater",
)
@pytest.mark.parametrize("m", [256, 512, 1024, 5120])
@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384])
@pytest.mark.parametrize("n_groups", [1, 4, 8, 16])
def test_cuda_mx_block_rearrange_2d_K_groups(
    m: int,
    total_k: int,
    n_groups: int,
):
    """
    Test CUDA kernel for mx_block_rearrange_2d_K_groups against Triton reference.
    This kernel rearranges E8M0 scales to block-scaled swizzle format for cuBLAS Tmem.
    """
    from torchao.prototype.moe_training.kernels.mxfp8.quant import (
        mx_block_rearrange_2d_K_groups_cuda,
    )

    device = "cuda"
    block_size = 32
    input_data = torch.randn(m, total_k, device=device)

    e8m0_scales, _ = to_mx(
        input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size
    )

    # Generate group end offsets along total_K, then divide by block_size to get scale group end offsets
    input_group_offsets = generate_jagged_offs(
        n_groups, total_k, multiple_of=block_size, device=device
    )
    scale_group_offsets = input_group_offsets // block_size

    # Triton reference implementation
    triton_out_scales = triton_mx_block_rearrange_2d_K_groups(
        e8m0_scales,
        scale_group_offsets,
    )

    # CUDA kernel implementation
    cuda_out_scales = mx_block_rearrange_2d_K_groups_cuda(
        e8m0_scales,
        scale_group_offsets,
    )

    # Check that outputs match
    assert torch.equal(triton_out_scales, cuda_out_scales.view(torch.float8_e8m0fnu)), (
        "CUDA and Triton blocked scales not equal"
    )

    # Verify output shape
    expected_rows = ((m + 127) // 128) * 128  # Padded to multiple of 128
    expected_cols = (
        e8m0_scales.size(1) + n_groups * 4
    )  # Original cols + padding per group
    assert cuda_out_scales.shape == (
        expected_rows,
        expected_cols,
    ), (
        f"Output shape mismatch: expected {(expected_rows, expected_cols)}, got {cuda_out_scales.shape}"
    )
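
As a concrete instance of the shape check above, using one parametrization from the test and the same padding rule it asserts:

# Worked example of the expected blocked-scale shape for m=5120, total_k=16384, n_groups=8.
m, total_k, n_groups, block_size = 5120, 16384, 8, 32
scale_cols = total_k // block_size            # 512 scale columns before padding
expected_rows = ((m + 127) // 128) * 128      # 5120 (already a multiple of 128)
expected_cols = scale_cols + n_groups * 4     # 512 + 32 = 544, per the padding rule asserted above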
