Commit a8fa9e5

[mxfp8 moe training] add wgrad_with_hp option (#3508)
1 parent 23a58c0 commit a8fa9e5
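
What the new flag does: with wgrad_with_hp=True, the backward pass computes the weight gradient with a plain bf16 torch._grouped_mm instead of quantizing grad_out and the activations to mxfp8 and calling torch._scaled_grouped_mm. The sketch below is a hedged usage example, not part of this diff: it assumes _to_mxfp8_then_scaled_grouped_mm is importable from torchao.prototype.moe_training.scaled_grouped_mm, that a CUDA device with mxfp8 grouped-GEMM support is available, and that offs holds int32 group-end offsets along the token dimension; all sizes are illustrative.

import torch
from torchao.prototype.moe_training.scaled_grouped_mm import (
    _to_mxfp8_then_scaled_grouped_mm,  # assumed import path, mirroring the test module
)

M, K, N, num_experts, block_size = 256, 512, 1024, 4, 32

# A must be 2D (total_M, K); B_t must be 3D (E, K, N), per the comment in forward().
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
w_t = torch.randn(
    num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
)
# Equal-sized token groups per expert (assumption for this sketch): [64, 128, 192, 256].
offs = torch.arange(
    M // num_experts, M + 1, M // num_experts, device="cuda", dtype=torch.int32
)

out = _to_mxfp8_then_scaled_grouped_mm(
    x, w_t, offs, block_size,
    torch.bfloat16,  # out_dtype
    False,           # emulated
    False,           # use_triton_for_dim0_cast
    True,            # wgrad_with_hp: compute the weight grad in bf16
)
out.sum().backward()  # dW now uses torch._grouped_mm instead of the mxfp8 path

The positional argument order mirrors the updated test below: x, w_t, offs, block_size, out_dtype, emulated, use_triton_for_dim0_cast, wgrad_with_hp.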

File tree

2 files changed: +81, -54 lines


test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 17 additions & 2 deletions
@@ -319,8 +319,14 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
 )
 @pytest.mark.parametrize("num_experts", (2, 4, 8, 16))
 @pytest.mark.parametrize("use_triton_for_dim0_cast", (True, False))
+@pytest.mark.parametrize("wgrad_with_hp", (True, False))
 def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
-    M, K, N, num_experts, use_triton_for_dim0_cast
+    M,
+    K,
+    N,
+    num_experts,
+    use_triton_for_dim0_cast,
+    wgrad_with_hp,
 ):
     block_size = 32
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
@@ -340,8 +346,17 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
     )

     # Forward
+    out_dtype = torch.bfloat16
+    emulated = False
     out = _to_mxfp8_then_scaled_grouped_mm(
-        x, w_t, offs, block_size, torch.bfloat16, use_triton_for_dim0_cast
+        x,
+        w_t,
+        offs,
+        block_size,
+        out_dtype,
+        emulated,
+        use_triton_for_dim0_cast,
+        wgrad_with_hp,
     )
     ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
     sqnr = compute_error(ref_out, out)
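
The forward check above compares the quantized path against a bf16 torch._grouped_mm reference and measures SQNR via compute_error. As a reference point, here is a standalone SQNR helper in the same spirit; the name and exact formula below are an assumption for illustration, not torchao's implementation.

import torch

def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: 20 * log10(||ref|| / ||ref - actual||).
    # Higher values mean the quantized output is closer to the bf16 reference.
    signal = torch.linalg.norm(ref.float())
    noise = torch.linalg.norm(ref.float() - actual.float())
    return 20 * torch.log10(signal / noise)

The assertion on the measured SQNR sits outside the hunks shown here.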

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 64 additions & 52 deletions
@@ -294,6 +294,7 @@ def forward(
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
         emulated: bool = False,
         use_triton_for_dim0_cast: bool = False,
+        wgrad_with_hp: bool = False,
         scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL,
     ) -> torch.Tensor:
         # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
@@ -352,6 +353,7 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
         ctx.use_triton_for_dim0_cast = use_triton_for_dim0_cast
+        ctx.wgrad_with_hp = wgrad_with_hp
         ctx.scale_calculation_mode = scale_calculation_mode
         return out

@@ -361,6 +363,7 @@ def backward(ctx, grad_out: torch.Tensor):
         block_size = ctx.block_size
         out_dtype = ctx.out_dtype
         use_triton_for_dim0_cast = ctx.use_triton_for_dim0_cast
+        wgrad_with_hp = ctx.wgrad_with_hp
         scale_calculation_mode = ctx.scale_calculation_mode

         # grad_out_data shape: (M, N)
@@ -405,59 +408,68 @@ def backward(ctx, grad_out: torch.Tensor):
             out_dtype=out_dtype,
         )

-        # grad_out_t_data shape: (M, N)
-        # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_mx = _to_mxfp8_dim1_kernel_wrapper(
-            grad_out,
-            block_size,
-            elem_dtype=torch.float8_e4m3fn,
-            hp_dtype=grad_out.dtype,
-            kernel_preference=KernelPreference.AUTO,  # Not used
-            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
-            scale_calculation_mode=scale_calculation_mode,
-        )
-        grad_out_t_data = grad_out_t_mx.qdata
-        grad_out_t_scales = grad_out_t_mx.scale
-
-        # Transpose A so we can scale along the M dimension, then un-transpose.
-        # A shape: (M, K)
-        # A_t_data shape: (K, M)
-        # A_t_scales shape: (K, M//block_size)
-        A_t_mx = _to_mxfp8_dim1_kernel_wrapper(
-            A,
-            block_size,
-            elem_dtype=torch.float8_e4m3fn,
-            hp_dtype=A.dtype,
-            kernel_preference=KernelPreference.AUTO,  # Not used
-            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
-            scale_calculation_mode=scale_calculation_mode,
-        )
-        A_t_data = A_t_mx.qdata
-        A_t_scales = A_t_mx.scale
-
-        # Convert scales to blocked format for 2d-2d grouped mm
-        scale_group_offsets = offs // block_size
-        grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
-            grad_out_t_scales,
-            scale_group_offsets,
-        )
-        A_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
-            A_t_scales,
-            scale_group_offsets,
-        )
+        # Optionally compute wgrad in high precision, if specified.
+        if wgrad_with_hp:
+            # TODO: migrate all grouped gemms in this file to new torch.nn.functional API
+            # grad_B_t = scaled grouped mm of (N,total_M) @ (total_M,K) = (E,N,K)
+            grad_B = torch._grouped_mm(
+                grad_out.transpose(-2, -1), A, offs=offs, out_dtype=out_dtype
+            )
+            grad_B_t = grad_B.transpose(-2, -1)
+        else:
+            # grad_out_t_data shape: (M, N)
+            # grad_out_t_scales shape: (N, M//block_size)
+            grad_out_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+                grad_out,
+                block_size,
+                elem_dtype=torch.float8_e4m3fn,
+                hp_dtype=grad_out.dtype,
+                kernel_preference=KernelPreference.AUTO,  # Not used
+                cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+                scale_calculation_mode=scale_calculation_mode,
+            )
+            grad_out_t_data = grad_out_t_mx.qdata
+            grad_out_t_scales = grad_out_t_mx.scale
+
+            # Transpose A so we can scale along the M dimension, then un-transpose.
+            # A shape: (M, K)
+            # A_t_data shape: (K, M)
+            # A_t_scales shape: (K, M//block_size)
+            A_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+                A,
+                block_size,
+                elem_dtype=torch.float8_e4m3fn,
+                hp_dtype=A.dtype,
+                kernel_preference=KernelPreference.AUTO,  # Not used
+                cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+                scale_calculation_mode=scale_calculation_mode,
+            )
+            A_t_data = A_t_mx.qdata
+            A_t_scales = A_t_mx.scale
+
+            # Convert scales to blocked format for 2d-2d grouped mm
+            scale_group_offsets = offs // block_size
+            grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+                grad_out_t_scales,
+                scale_group_offsets,
+            )
+            A_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+                A_t_scales,
+                scale_group_offsets,
+            )

-        # grad_B_t = scaled grouped mm of (N,total_M) @ (total_M,K) = (E,N,K)
-        grad_B = torch._scaled_grouped_mm(
-            grad_out_t_data,
-            A_t_data.transpose(-2, -1),
-            grad_out_t_scales_blocked,
-            A_t_scales_blocked,
-            offs=offs,
-            out_dtype=out_dtype,
-        )
-        # grad_B_t shape = (E,K,N)
-        grad_B_t = grad_B.transpose(-2, -1)
-        return grad_A, grad_B_t, None, None, None, None
+            # grad_B_t = scaled grouped mm of (N,total_M) @ (total_M,K) = (E,N,K)
+            grad_B = torch._scaled_grouped_mm(
+                grad_out_t_data,
+                A_t_data.transpose(-2, -1),
+                grad_out_t_scales_blocked,
+                A_t_scales_blocked,
+                offs=offs,
+                out_dtype=out_dtype,
+            )
+            # grad_B_t shape = (E,K,N)
+            grad_B_t = grad_B.transpose(-2, -1)
+        return grad_A, grad_B_t, None, None, None, None, None, None, None


 def _to_mxfp8_dim1_3d(
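
To make the shapes in the new high-precision wgrad branch concrete: grad_out is (total_M, N), A is (total_M, K), and each expert's weight gradient is grad_out_e^T @ A_e, stacked to (E, N, K) and transposed to (E, K, N) to match B_t. Below is a minimal loop-based emulation of that computation, for illustration only; the actual branch dispatches a single torch._grouped_mm call on device, and the group sizes here are made up.

import torch

# Hypothetical sizes: 4 experts, K=64, N=128, with an uneven token split.
E, K, N = 4, 64, 128
group_sizes = [32, 16, 48, 8]
total_M = sum(group_sizes)
offs = torch.tensor(group_sizes).cumsum(0)  # group-end offsets along the M dimension

grad_out = torch.randn(total_M, N, dtype=torch.bfloat16)  # (total_M, N)
A = torch.randn(total_M, K, dtype=torch.bfloat16)         # (total_M, K)

# grad_B[e] = grad_out[group_e].T @ A[group_e] -> (N, K) per expert.
grad_B = torch.empty(E, N, K, dtype=torch.bfloat16)
start = 0
for e, end in enumerate(offs.tolist()):
    grad_B[e] = grad_out[start:end].transpose(-2, -1) @ A[start:end]
    start = end

grad_B_t = grad_B.transpose(-2, -1)  # (E, K, N), matching B_t's layout
print(grad_B_t.shape)                # torch.Size([4, 64, 128])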
