@@ -294,6 +294,7 @@ def forward(
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
         emulated: bool = False,
         use_triton_for_dim0_cast: bool = False,
+        wgrad_with_hp: bool = False,
         scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL,
     ) -> torch.Tensor:
         # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
@@ -352,6 +353,7 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
         ctx.use_triton_for_dim0_cast = use_triton_for_dim0_cast
+        ctx.wgrad_with_hp = wgrad_with_hp
         ctx.scale_calculation_mode = scale_calculation_mode
         return out
 
@@ -361,6 +363,7 @@ def backward(ctx, grad_out: torch.Tensor):
         block_size = ctx.block_size
         out_dtype = ctx.out_dtype
         use_triton_for_dim0_cast = ctx.use_triton_for_dim0_cast
+        wgrad_with_hp = ctx.wgrad_with_hp
         scale_calculation_mode = ctx.scale_calculation_mode
 
         # grad_out_data shape: (M, N)
@@ -405,59 +408,68 @@ def backward(ctx, grad_out: torch.Tensor):
             out_dtype=out_dtype,
         )
 
-        # grad_out_t_data shape: (M, N)
-        # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_mx = _to_mxfp8_dim1_kernel_wrapper(
-            grad_out,
-            block_size,
-            elem_dtype=torch.float8_e4m3fn,
-            hp_dtype=grad_out.dtype,
-            kernel_preference=KernelPreference.AUTO,  # Not used
-            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
-            scale_calculation_mode=scale_calculation_mode,
-        )
-        grad_out_t_data = grad_out_t_mx.qdata
-        grad_out_t_scales = grad_out_t_mx.scale
-
-        # Transpose A so we can scale along the M dimension, then un-transpose.
-        # A shape: (M, K)
-        # A_t_data shape: (K, M)
-        # A_t_scales shape: (K, M//block_size)
-        A_t_mx = _to_mxfp8_dim1_kernel_wrapper(
-            A,
-            block_size,
-            elem_dtype=torch.float8_e4m3fn,
-            hp_dtype=A.dtype,
-            kernel_preference=KernelPreference.AUTO,  # Not used
-            cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
-            scale_calculation_mode=scale_calculation_mode,
-        )
-        A_t_data = A_t_mx.qdata
-        A_t_scales = A_t_mx.scale
-
-        # Convert scales to blocked format for 2d-2d grouped mm
-        scale_group_offsets = offs // block_size
-        grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
-            grad_out_t_scales,
-            scale_group_offsets,
-        )
-        A_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
-            A_t_scales,
-            scale_group_offsets,
-        )
+        # Optionally compute wgrad in high precision, if specified.
+        if wgrad_with_hp:
+            # TODO: migrate all grouped gemms in this file to the new torch.nn.functional API
+            # grad_B = grouped mm of (N, total_M) @ (total_M, K) = (E, N, K)
+            grad_B = torch._grouped_mm(
+                grad_out.transpose(-2, -1), A, offs=offs, out_dtype=out_dtype
+            )
+            grad_B_t = grad_B.transpose(-2, -1)
+        else:
+            # grad_out_t_data shape: (M, N)
+            # grad_out_t_scales shape: (N, M//block_size)
+            grad_out_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+                grad_out,
+                block_size,
+                elem_dtype=torch.float8_e4m3fn,
+                hp_dtype=grad_out.dtype,
+                kernel_preference=KernelPreference.AUTO,  # Not used
+                cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+                scale_calculation_mode=scale_calculation_mode,
+            )
+            grad_out_t_data = grad_out_t_mx.qdata
+            grad_out_t_scales = grad_out_t_mx.scale
+
+            # Transpose A so we can scale along the M dimension, then un-transpose.
+            # A shape: (M, K)
+            # A_t_data shape: (K, M)
+            # A_t_scales shape: (K, M//block_size)
+            A_t_mx = _to_mxfp8_dim1_kernel_wrapper(
+                A,
+                block_size,
+                elem_dtype=torch.float8_e4m3fn,
+                hp_dtype=A.dtype,
+                kernel_preference=KernelPreference.AUTO,  # Not used
+                cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA,
+                scale_calculation_mode=scale_calculation_mode,
+            )
+            A_t_data = A_t_mx.qdata
+            A_t_scales = A_t_mx.scale
+
+            # Convert scales to blocked format for 2d-2d grouped mm
+            scale_group_offsets = offs // block_size
+            grad_out_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+                grad_out_t_scales,
+                scale_group_offsets,
+            )
+            A_t_scales_blocked = triton_mx_block_rearrange_2d_K_groups(
+                A_t_scales,
+                scale_group_offsets,
+            )
 
-        # grad_B_t = scaled grouped mm of (N,total_M) @ (total_M,K) = (E,N,K)
-        grad_B = torch._scaled_grouped_mm(
-            grad_out_t_data,
-            A_t_data.transpose(-2, -1),
-            grad_out_t_scales_blocked,
-            A_t_scales_blocked,
-            offs=offs,
-            out_dtype=out_dtype,
-        )
-        # grad_B_t shape = (E,K,N)
-        grad_B_t = grad_B.transpose(-2, -1)
-        return grad_A, grad_B_t, None, None, None, None
+            # grad_B_t = scaled grouped mm of (N,total_M) @ (total_M,K) = (E,N,K)
+            grad_B = torch._scaled_grouped_mm(
+                grad_out_t_data,
+                A_t_data.transpose(-2, -1),
+                grad_out_t_scales_blocked,
+                A_t_scales_blocked,
+                offs=offs,
+                out_dtype=out_dtype,
+            )
+            # grad_B_t shape = (E,K,N)
+            grad_B_t = grad_B.transpose(-2, -1)
+        return grad_A, grad_B_t, None, None, None, None, None, None, None
 
 
 def _to_mxfp8_dim1_3d(
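
For readers tracing the new wgrad_with_hp path, here is a minimal reference sketch (not part of the diff; the helper name reference_wgrad and the convention that offs holds cumulative end-row offsets per expert group are assumptions) of the per-group math that torch._grouped_mm(grad_out.transpose(-2, -1), A, offs=offs) performs, followed by the transpose to (E, K, N):

import torch

def reference_wgrad(grad_out: torch.Tensor, A: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    # grad_out: (total_M, N), A: (total_M, K), offs: (E,) cumulative end rows per group.
    # Returns grad_B_t with shape (E, K, N), matching grad_B.transpose(-2, -1) above.
    grads = []
    start = 0
    for end in offs.tolist():
        g = grad_out[start:end]                # (m_e, N) rows routed to this expert
        a = A[start:end]                       # (m_e, K)
        grads.append(g.transpose(-2, -1) @ a)  # (N, K) weight grad for this expert
        start = end
    grad_B = torch.stack(grads)                # (E, N, K)
    return grad_B.transpose(-2, -1)            # (E, K, N)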
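
And a small worked example (values are hypothetical) of the offs // block_size conversion used before the blocked-scale rearrangement in the mxfp8 branch: each scale row along the reduction dimension covers block_size rows of the original tensor, so the group boundaries shrink by the same factor.

import torch

# Hypothetical offsets; the mxfp8 path expects them to be multiples of block_size.
offs = torch.tensor([128, 288, 480])       # cumulative rows after experts 0, 1, 2
block_size = 32
scale_group_offsets = offs // block_size   # tensor([ 4,  9, 15])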