Skip to content

Commit 01374eb

Browse files
[mxfp8 moe training] compute prefix sum of group sizes inside kernel instead of precomputing (#3285)
1 parent 8e3b3da commit 01374eb

File tree

5 files changed

+42
-52
lines changed

5 files changed

+42
-52
lines changed

benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_M_groups.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from benchmarks.utils import benchmark_cuda_function_in_microseconds
1717
from torchao.prototype.moe_training.kernels.mxfp8 import (
18-
compute_blocked_scale_offsets_for_M_groups,
1918
torch_to_blocked_2d_M_groups,
2019
triton_mx_block_rearrange_2d_M_groups,
2120
)
@@ -80,9 +79,6 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
8079

8180
Mg, K = input_shape
8281
input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=32)
83-
_, output_group_offsets = compute_blocked_scale_offsets_for_M_groups(
84-
input_group_offsets
85-
)
8682

8783
# bench torch
8884
compiled_run_torch = torch.compile(torch_to_blocked_2d_M_groups)
@@ -100,13 +96,11 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
10096
triton_out_scales = triton_mx_block_rearrange_2d_M_groups(
10197
input_tensor,
10298
input_group_offsets,
103-
output_group_offsets,
10499
)
105100
triton_time_us = benchmark_cuda_function_in_microseconds(
106101
triton_mx_block_rearrange_2d_M_groups,
107102
input_tensor,
108103
input_group_offsets,
109-
output_group_offsets,
110104
)
111105

112106
# mem bw calculations

test/prototype/moe_training/test_kernels.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
triton_fp8_per_group_rowwise_scales,
2222
)
2323
from torchao.prototype.moe_training.kernels.mxfp8 import (
24-
compute_blocked_scale_offsets_for_K_groups,
25-
compute_blocked_scale_offsets_for_M_groups,
2624
torch_to_blocked_2d_K_groups,
2725
torch_to_blocked_2d_M_groups,
2826
torch_to_blocked_per_group_3d,
@@ -236,13 +234,9 @@ def test_triton_mx_block_rearrange_2d_M_groups(
236234
)
237235

238236
# triton kernel
239-
_, output_group_offsets = compute_blocked_scale_offsets_for_M_groups(
240-
input_group_offsets
241-
)
242237
triton_out_scales = triton_mx_block_rearrange_2d_M_groups(
243238
e8m0_scales,
244239
input_group_offsets,
245-
output_group_offsets,
246240
)
247241
assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
248242
"blocked scales not equal"
@@ -306,16 +300,9 @@ def test_triton_mx_block_rearrange_2d_K_groups(
306300
)
307301

308302
# triton kernel
309-
_, output_group_offsets = compute_blocked_scale_offsets_for_K_groups(
310-
scale_group_offsets
311-
)
312-
assert torch.equal(output_group_offsets, ref_start_cols_after_padding), (
313-
"output scale group start offsets not equal"
314-
)
315303
triton_out_scales = triton_mx_block_rearrange_2d_K_groups(
316304
e8m0_scales,
317305
scale_group_offsets,
318-
output_group_offsets,
319306
)
320307
assert torch.equal(ref_out_scales, triton_out_scales), "blocked scales not equal"
321308

torchao/prototype/moe_training/kernels/mxfp8/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
2-
compute_blocked_scale_offsets_for_K_groups, # noqa: F401
3-
compute_blocked_scale_offsets_for_M_groups, # noqa: F401
42
mxfp8_quantize_cuda_3d, # noqa: F401
53
torch_to_blocked_2d_K_groups, # noqa: F401
64
torch_to_blocked_2d_M_groups, # noqa: F401

torchao/prototype/moe_training/kernels/mxfp8/quant.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,6 @@ def compute_blocked_scale_offsets_for_K_groups(
223223
def triton_mx_block_rearrange_2d_M_groups(
224224
scales_tensor: torch.Tensor,
225225
input_group_end_offsets: torch.Tensor,
226-
output_group_start_offsets: torch.Tensor,
227226
) -> torch.Tensor:
228227
"""
229228
Rearranges an E8M0 tensor scale to block-scaled swizzle format,
@@ -275,15 +274,14 @@ def triton_mx_block_rearrange_2d_M_groups(
275274
scales_tensor.stride(1),
276275
rows,
277276
cols,
278-
num_groups,
279277
# Original offsets (to read from)
280278
input_group_end_offsets,
281279
# Output scales tensor and group offsets after padding (to write to)
282280
output.view(torch.uint8),
283281
output.stride(0),
284-
output_group_start_offsets,
285282
output_stride_per_block,
286283
output_stride_per_row_of_blocks,
284+
num_groups=num_groups,
287285
BLOCK_ROWS=BLOCK_ROWS,
288286
BLOCK_COLS=BLOCK_COLS,
289287
)
@@ -297,13 +295,12 @@ def triton_scale_swizzle_M_groups(
297295
scales_stride_dim1,
298296
scale_rows,
299297
scale_cols,
300-
num_groups,
301298
orig_offsets, # (num_groups,)
302299
output_scales_ptr,
303300
output_scales_stride_dim0,
304-
output_scales_group_offsets, # (num_groups,)
305301
output_stride_per_block,
306302
output_stride_per_row_of_blocks,
303+
num_groups: tl.constexpr,
307304
BLOCK_ROWS: tl.constexpr,
308305
BLOCK_COLS: tl.constexpr,
309306
):
@@ -316,10 +313,13 @@ def triton_scale_swizzle_M_groups(
316313
input_group_end_row = tl.load(
317314
orig_offsets + group_pid, mask=group_pid < num_groups, other=0
318315
)
319-
# Output scales start row we will begin writing to
320-
output_group_start_row = tl.load(
321-
output_scales_group_offsets + group_pid, mask=group_pid < num_groups, other=0
316+
317+
# Calculate this group's start row after blocked format padding, by doing a prefix sum
318+
# of each previous group's padded size.
319+
output_group_start_row = _blocked_group_start_idx(
320+
group_pid, orig_offsets, num_groups, 128
322321
)
322+
323323
# Calculate destination indices for each row and col in block swizzled layout.
324324
# We can reuse this swizzle transformation on each block of data we read.
325325
row_offs = tl.arange(0, BLOCK_ROWS)[:, None]
@@ -489,7 +489,6 @@ def triton_scale_swizzle_per_group_3d(
489489
def triton_mx_block_rearrange_2d_K_groups(
490490
scales_tensor: torch.Tensor,
491491
input_group_end_offsets: torch.Tensor,
492-
output_group_start_offsets: torch.Tensor,
493492
) -> torch.Tensor:
494493
"""
495494
Rearranges an E8M0 tensor scale to block-scaled swizzle format on a per group basis,
@@ -538,13 +537,10 @@ def triton_mx_block_rearrange_2d_K_groups(
538537
rows,
539538
cols,
540539
padded_rows,
541-
num_groups,
542-
# Original offsets (to read from)
543540
input_group_end_offsets,
544-
# Output scales tensor and group offsets after padding (to write to)
545541
output.view(torch.uint8),
546-
output_group_start_offsets,
547542
output_stride_per_block,
543+
num_groups=num_groups,
548544
BLOCK_ROWS=BLOCK_ROWS,
549545
BLOCK_COLS=BLOCK_COLS,
550546
DEBUG=False,
@@ -560,11 +556,10 @@ def triton_scale_swizzle_2d_K_groups(
560556
scale_rows,
561557
scale_cols,
562558
padded_rows,
563-
num_groups,
564559
orig_offsets, # (num_groups,)
565560
output_scales_ptr,
566-
output_scales_group_offsets, # (num_groups,)
567561
output_stride_per_block,
562+
num_groups: tl.constexpr,
568563
BLOCK_ROWS: tl.constexpr,
569564
BLOCK_COLS: tl.constexpr,
570565
DEBUG: tl.constexpr = False,
@@ -578,8 +573,11 @@ def triton_scale_swizzle_2d_K_groups(
578573
)
579574
input_group_end_col = tl.load(orig_offsets + group_pid)
580575

581-
# Output scales start row we will begin writing to
582-
output_group_start_col = tl.load(output_scales_group_offsets + group_pid)
576+
# Calculate this group's start col after blocked format padding, by doing a prefix sum
577+
# of each previous group's padded size.
578+
output_group_start_col = _blocked_group_start_idx(
579+
group_pid, orig_offsets, num_groups, 4
580+
)
583581

584582
row_offs = tl.arange(0, BLOCK_ROWS)[:, None]
585583
col_offs = tl.arange(0, BLOCK_COLS)[None, :]
@@ -651,6 +649,31 @@ def _dest_indices_for_block(
651649
return dest_indices_flat
652650

653651

652+
@triton.jit
653+
def _blocked_group_start_idx(
654+
group_pid,
655+
orig_offsets,
656+
num_groups: tl.constexpr,
657+
padding_size: tl.constexpr,
658+
):
659+
"""Prefix sum to compute the start index of a given group."""
660+
offsets = tl.load(orig_offsets + tl.arange(0, num_groups))
661+
prev_offsets = tl.load(
662+
orig_offsets + tl.arange(0, num_groups) - 1,
663+
mask=tl.arange(0, num_groups) > 0,
664+
other=0,
665+
)
666+
group_sizes = tl.where(
667+
tl.arange(0, num_groups) > 0,
668+
offsets - prev_offsets,
669+
offsets,
670+
)
671+
padded_sizes = tl.cdiv(group_sizes, padding_size) * padding_size
672+
prefix_mask = tl.arange(0, num_groups) < group_pid
673+
group_start_idx = tl.sum(tl.where(prefix_mask, padded_sizes, 0))
674+
return group_start_idx
675+
676+
654677
mxfp8_cuda_extension_available = False
655678
if is_sm_at_least_100():
656679
try:

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
triton_fp8_rowwise_3d_transpose_rhs,
1818
)
1919
from torchao.prototype.moe_training.kernels.mxfp8 import (
20-
compute_blocked_scale_offsets_for_K_groups,
21-
compute_blocked_scale_offsets_for_M_groups,
2220
mxfp8_quantize_cuda_3d,
2321
triton_mx_block_rearrange_2d_K_groups,
2422
triton_mx_block_rearrange_2d_M_groups,
@@ -332,13 +330,9 @@ def forward(
332330
)
333331

334332
# Convert scales to blocked format for 2d-3d grouped mm
335-
_, blocked_scales_group_offsets_2d3d = (
336-
compute_blocked_scale_offsets_for_M_groups(offs)
337-
)
338333
A_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
339334
A_scale,
340335
offs,
341-
blocked_scales_group_offsets_2d3d,
342336
)
343337
B_scales_blocked = triton_mx_block_rearrange_per_group_3d(B_scales)
344338

@@ -353,7 +347,7 @@ def forward(
353347
out_dtype=out_dtype,
354348
)
355349

356-
ctx.save_for_backward(A, B_t, offs, blocked_scales_group_offsets_2d3d)
350+
ctx.save_for_backward(A, B_t, offs)
357351
ctx.block_size = block_size
358352
ctx.out_dtype = out_dtype
359353
ctx.emulated = emulated
@@ -363,7 +357,7 @@ def forward(
363357

364358
@staticmethod
365359
def backward(ctx, grad_out: torch.Tensor):
366-
A, B_t, offs, blocked_scales_group_offsets_2d3d = ctx.saved_tensors
360+
A, B_t, offs = ctx.saved_tensors
367361
block_size = ctx.block_size
368362
out_dtype = ctx.out_dtype
369363
use_triton_for_dim0_cast = ctx.use_triton_for_dim0_cast
@@ -398,7 +392,6 @@ def backward(ctx, grad_out: torch.Tensor):
398392
grad_out_scales_blocked = triton_mx_block_rearrange_2d_M_groups(
399393
grad_out_scale,
400394
offs,
401-
blocked_scales_group_offsets_2d3d,
402395
)
403396
B_scales_blocked = triton_mx_block_rearrange_per_group_3d(B_scales)
404397

@@ -444,18 +437,13 @@ def backward(ctx, grad_out: torch.Tensor):
444437

445438
# Convert scales to blocked format for 2d-2d grouped mm
446439
scale_group_offsets = offs // block_size
447-
_, blocked_scale_group_offsets = compute_blocked_scale_offsets_for_K_groups(
448-
scale_group_offsets
449-
)
450440
grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
451441
grad_out_t_scales,
452442
scale_group_offsets,
453-
blocked_scale_group_offsets,
454443
)
455444
A_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
456445
A_t_scales,
457446
scale_group_offsets,
458-
blocked_scale_group_offsets,
459447
)
460448

461449
# grad_B_t = scaled grouped mm of (N,total_M) @ (total_M,K) = (E,N,K)

0 commit comments

Comments
 (0)