diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_K_groups.py b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_K_groups.py new file mode 100644 index 0000000000..2cba65d151 --- /dev/null +++ b/benchmarks/prototype/moe_training/mxfp8/bench_triton_mx_block_rearrange_2d_K_groups.py @@ -0,0 +1,265 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm + +from benchmarks.utils import benchmark_cuda_function_in_microseconds +from torchao.prototype.moe_training.kernels.mxfp8.quant import ( + mx_block_rearrange_2d_K_groups_cuda, + torch_to_blocked_2d_K_groups, + triton_mx_block_rearrange_2d_K_groups, +) +from torchao.prototype.moe_training.utils import generate_jagged_offs + +device = torch.device("cuda") + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: tuple[int] + num_groups: int + version: str # "naive" or "parallel" + + +@dataclass(frozen=True) +class ExperimentResult: + time_us: float + mem_bw_gbps: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + # Llama4 and DSV3 671b shapes. Input activations are scaled along the total_M dim, which contains all the token groups. + block_size = 32 + input_shapes = [ + (8192, 32768 // block_size), + (8192, 65536 // block_size), + (8192, 131072 // block_size), + (5120, 32768 // block_size), + (5120, 65536 // block_size), + (5120, 131072 // block_size), + (7168, 32768 // block_size), + (7168, 65536 // block_size), + (7168, 131072 // block_size), + (2048, 32768 // block_size), + (2048, 65536 // block_size), + (2048, 131072 // block_size), + ] + num_groups = [8] + versions = [ + "torch", + "triton", + # CUDA kernel versions: cuda_{max_cols}_{chunks_per_tb} + "cuda_64_4", + "cuda_64_8", + "cuda_64_16", + "cuda_128_4", + "cuda_128_8", + "cuda_128_16", + ] + + configs = [] + for shape, groups, version in itertools.product( + input_shapes, + num_groups, + versions, + ): + configs.append( + ExperimentConfig( + input_shape=shape, + num_groups=groups, + version=version, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + input_shape, num_groups, version = ( + config.input_shape, + config.num_groups, + config.version, + ) + input_tensor = torch.randint( + low=0, + high=256, + size=input_shape, + dtype=torch.uint8, + device=device, + ) + + M, Kg = input_shape + block_size = 32 + input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size) + + # Select which kernel to benchmark based on version + if version == "torch": + kernel_fn = torch_to_blocked_2d_K_groups + kernel_input = input_tensor + elif version == "triton": + kernel_fn = triton_mx_block_rearrange_2d_K_groups + # Triton uses row-major input + kernel_input = input_tensor + elif version.startswith("cuda_"): + # Parse version string: cuda_{max_cols}_{chunks_per_tb} + parts = version.split("_") + max_cols = int(parts[1]) + chunks_per_tb = int(parts[2]) + kernel_fn = ( + lambda t, + o, + mc=max_cols, + cptb=chunks_per_tb: 
mx_block_rearrange_2d_K_groups_cuda( + t, + o, + max_cols=mc, + chunks_per_tb=cptb, + ) + ) + kernel_input = input_tensor.view(torch.float8_e8m0fnu) + else: + raise ValueError(f"Unknown version: {version}") + + # Run kernel to get output shape + outputs = kernel_fn( + kernel_input, + input_group_offsets, + ) + if isinstance(outputs, tuple): # torch returns a tuple with extra metadata + out_scales, _ = outputs + else: + out_scales = outputs + + # Benchmark the kernel + time_us = benchmark_cuda_function_in_microseconds( + kernel_fn, + kernel_input, + input_group_offsets, + ) + + # Calculate memory bandwidth + bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8 + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 + + read_bytes = input_tensor.numel() * bytes_per_input_el + write_bytes = out_scales.numel() * bytes_per_output_el + + mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6) + + return ExperimentResult( + time_us=time_us, + mem_bw_gbps=mem_bw_gbps, + ) + + +def print_results(experiments: List[Experiment]): + # Group experiments by input shape + shapes_dict = {} + for exp in experiments: + shape_key = exp.config.input_shape + if shape_key not in shapes_dict: + shapes_dict[shape_key] = {} + shapes_dict[shape_key][exp.config.version] = exp.result + + headers = [ + "kernel_version", + "scale_shape", + "time_us", + "mem_bw_gbps", + "speedup_vs_torch", + "speedup_vs_triton", + ] + + rows = [] + for shape, versions in shapes_dict.items(): + # Get torch baseline time for speedup calculation + torch_time_us = versions.get("torch").time_us if "torch" in versions else None + + # Get triton baseline time for speedup calculation + triton_time_us = ( + versions.get("triton").time_us if "triton" in versions else None + ) + + # Add rows for each version + for version, result in versions.items(): + # Calculate speedup vs torch + speedup_vs_torch_str = "" + if version != "torch" and torch_time_us is not None: + speedup = torch_time_us / result.time_us + speedup_vs_torch_str = f"{speedup:.2f}x" + + # Calculate speedup vs triton (only for CUDA kernels) + speedup_vs_triton_str = "" + if version.startswith("cuda_") and triton_time_us is not None: + speedup = triton_time_us / result.time_us + speedup_vs_triton_str = f"{speedup:.2f}x" + + rows.append( + [ + version, + f"({shape[0]}, {shape[1]})", + f"{result.time_us:.2f}", + round(result.mem_bw_gbps, 3), + speedup_vs_torch_str, + speedup_vs_triton_str, + ] + ) + + # Find best CUDA kernel speedup vs triton for this shape + best_cuda_speedup = 0.0 + best_cuda_version = None + for version, result in versions.items(): + if version.startswith("cuda_") and triton_time_us is not None: + speedup = triton_time_us / result.time_us + if speedup > best_cuda_speedup: + best_cuda_speedup = speedup + best_cuda_version = version + + if best_cuda_version is not None: + rows.append( + [ + f">>> BEST: {best_cuda_speedup:.2f}x vs triton with {best_cuda_version}", + "", + "", + "", + "", + ] + ) + + # Add empty row for visual separation between shapes + rows.append([""] * len(headers)) + + print(tabulate(rows, headers=headers)) + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index bd7abb313a..2c6aea96c1 100644 --- a/setup.py +++ b/setup.py @@ 
-709,6 +709,7 @@ def get_extensions(): mxfp8_sources = [ os.path.join(mxfp8_extension_dir, "mxfp8_extension.cpp"), os.path.join(mxfp8_extension_dir, "mxfp8_cuda.cu"), + os.path.join(mxfp8_extension_dir, "mx_block_rearrange_2d_K_groups.cu"), ] # Only add the extension if the source files exist AND we are building for sm100 diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py index 0387cc28e0..81d3c2aa94 100644 --- a/test/prototype/moe_training/test_kernels.py +++ b/test/prototype/moe_training/test_kernels.py @@ -352,3 +352,57 @@ def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode): # Check quantized values torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0) assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match" + + +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="MXFP8 requires CUDA capability 10.0 or greater", +) +@pytest.mark.parametrize("m", [256, 512, 1024, 5120]) +@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384]) +@pytest.mark.parametrize("n_groups", [1, 4, 8, 16]) +def test_cuda_mx_block_rearrange_2d_K_groups( + m: int, + total_k: int, + n_groups: int, +): + """ + Test CUDA kernel for mx_block_rearrange_2d_K_groups against Triton reference. + """ + from torchao.prototype.moe_training.kernels.mxfp8.quant import ( + mx_block_rearrange_2d_K_groups_cuda, + ) + + device = "cuda" + block_size = 32 + input_data = torch.randn(m, total_k, device=device) + + e8m0_scales, _ = to_mx( + input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size + ) + + # Generate group end offsets along total_K, then divide by block_size to get scale group end offsets + input_group_offsets = generate_jagged_offs( + n_groups, total_k, multiple_of=block_size, device=device + ) + scale_group_offsets = input_group_offsets // block_size + + # Triton reference implementation + triton_out_scales = triton_mx_block_rearrange_2d_K_groups( + e8m0_scales, + scale_group_offsets, + ) + + # CUDA kernel implementation + cuda_out_scales = mx_block_rearrange_2d_K_groups_cuda( + e8m0_scales, + scale_group_offsets, + ) + + # Check that outputs match + assert torch.equal(triton_out_scales, cuda_out_scales.view(torch.float8_e8m0fnu)), ( + "CUDA and Triton blocked scales not equal" + ) + + # Check strides + assert triton_out_scales.stride() == cuda_out_scales.stride(), "strides not equal" diff --git a/torchao/csrc/cuda/mx_kernels/mx_block_rearrange_2d_K_groups.cu b/torchao/csrc/cuda/mx_kernels/mx_block_rearrange_2d_K_groups.cu new file mode 100644 index 0000000000..1d7bf21ece --- /dev/null +++ b/torchao/csrc/cuda/mx_kernels/mx_block_rearrange_2d_K_groups.cu @@ -0,0 +1,418 @@ +#include +#include +#include +#include + +#define BLOCK_ROWS 128 +#define BLOCK_COLS 4 +#define BYTES_PER_THREAD 16 +#define SCALE_FACTOR_ROWS 128 + +__device__ __forceinline__ int ceil_div(int a, int b) { + return (a + b - 1) / b; +} + +// Terminology: +// tile = 128x4 scaling factor tile +// chunk = chunk of data consisting of multiple tiles (e.g., 128x64 or 128x128) +// superblock = consists of CHUNKS_PER_TB chunks along the column dimension +template +__device__ void find_group_and_local_offset_for_superblock( + int super_col_block_pid, + const int32_t* __restrict__ input_group_end_offsets, + int num_groups, + int cols_per_block, + int* __restrict__ smem_data, // smem_data[0..num_groups-1] = chunks_in_group, smem_data[num_groups..2*num_groups-1] = super_blocks_in_group cumsum + int& group_id, + int& 
first_chunk_in_group, + int& chunks_until_group_end +) { + if (threadIdx.x == 0) { + int superblock_cumsum = 0; + for (int g = 0; g < num_groups; g++) { + int input_group_start = (g > 0) ? input_group_end_offsets[g - 1] : 0; + int input_group_end = input_group_end_offsets[g]; + int group_size = input_group_end - input_group_start; + int chunks_in_group = ceil_div(group_size, cols_per_block); + int superblocks_in_group = ceil_div(chunks_in_group, CHUNKS_PER_TB); + smem_data[g] = chunks_in_group; + superblock_cumsum += superblocks_in_group; + smem_data[num_groups + g] = superblock_cumsum; + } + } + __syncthreads(); + + group_id = 0; + int superblock_cumsum_before = 0; + for (int g = 0; g < num_groups; g++) { + int cumsum_at_g = smem_data[num_groups + g]; + if (super_col_block_pid < cumsum_at_g) { + group_id = g; + int local_superblock = super_col_block_pid - superblock_cumsum_before; + first_chunk_in_group = local_superblock * CHUNKS_PER_TB; + int chunks_in_group = smem_data[g]; + chunks_until_group_end = chunks_in_group - first_chunk_in_group; + return; + } + superblock_cumsum_before = cumsum_at_g; + } + + first_chunk_in_group = 0; + chunks_until_group_end = 0; +} + +__device__ __forceinline__ int compute_output_group_start_col( + int group_id, + const int32_t* input_group_end_offsets, + int num_groups, + int padding_size +) { + int start_idx = 0; + for (int i = 0; i < group_id; i++) { + int prev_offset = (i > 0) ? input_group_end_offsets[i - 1] : 0; + int curr_offset = input_group_end_offsets[i]; + int group_size = curr_offset - prev_offset; + int padded_size = ceil_div(group_size, padding_size) * padding_size; + start_idx += padded_size; + } + return start_idx; +} + + +// Uses 2-stage software pipelining with double buffering to overlap memory +// transfers with compute. Each threadblock processes CHUNKS_PER_TB consecutive +// chunks within the same group. 
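+// For example, with MAX_COLS=64 and CHUNKS_PER_TB=4, each chunk is a 128x64 slice of
+// scale factors (sixteen 128x4 tiles) and each superblock/threadblock covers up to
+// 128x256 scales, streamed chunk-by-chunk through the double-buffered shared memory.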
+template +__global__ void mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined_kernel( + const uint8_t* __restrict__ scales_ptr, + int scales_stride_dim0, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* __restrict__ input_group_end_offsets, + uint8_t* __restrict__ output_scales_ptr, + int output_stride_per_block, + int num_groups +) { + constexpr int THREADS_PER_ROW = MAX_COLS / 16; + constexpr int TILES_PER_THREAD = 4; // Each thread processes 16 bytes = 4 tiles (4 bytes per tile) + constexpr int TILE_SIZE = SCALE_FACTOR_ROWS * BLOCK_COLS; // 128 rows * 4 cols = 512 bytes per tile + constexpr int SMEM_SIZE = BLOCK_ROWS * MAX_COLS; + constexpr int NUM_BUFFERS = 2; + + const int super_col_block_pid = blockIdx.x; + const int row_block_pid = blockIdx.y; + const int tid = threadIdx.x; + + __shared__ __align__(16) uint8_t smem[NUM_BUFFERS][SMEM_SIZE]; + __shared__ int smem_group_data[64]; // max 32 groups; 1 int per group size, 1 int per group prefix sum + __shared__ int s_output_group_start_col; + __shared__ int s_input_group_start_col; + __shared__ int s_input_group_end_col; + __shared__ int s_first_chunk_in_group; + __shared__ int s_num_chunks_to_process; + + // PHASE 0: Map super-block to (group, first_chunk) pair + int group_id, first_chunk_in_group, chunks_until_group_end; + find_group_and_local_offset_for_superblock( + super_col_block_pid, + input_group_end_offsets, + num_groups, + MAX_COLS, + smem_group_data, + group_id, + first_chunk_in_group, + chunks_until_group_end + ); + + // Use one thread in the threadblock to compute group boundaries and output group start, + // then broadcast the values via SMEM. + // This avoids (1) unnecessary extra global accesses, and (2) extra register pressure, and + // (3) extra ALU usage by every thread computing redundant values, in a kernel which is already ALU heavy. + // It comes at the cost of these few SMEM accesses and thread block sync, but benchmarks show this is slightly + // better than having all threads do this redundant work. + if (tid == 0) { + s_input_group_start_col = (group_id > 0) ? input_group_end_offsets[group_id - 1] : 0; + s_input_group_end_col = input_group_end_offsets[group_id]; + s_first_chunk_in_group = first_chunk_in_group; + s_output_group_start_col = compute_output_group_start_col( + group_id, input_group_end_offsets, num_groups, 4 + ); + s_num_chunks_to_process = (chunks_until_group_end > 0) ? 
min(CHUNKS_PER_TB, chunks_until_group_end) : 0; + } + + __syncthreads(); + + int input_group_start_col = s_input_group_start_col; + int input_group_end_col = s_input_group_end_col; + first_chunk_in_group = s_first_chunk_in_group; + int output_group_start_col = s_output_group_start_col; + int num_chunks_to_process = s_num_chunks_to_process; + + if (num_chunks_to_process <= 0) { + return; + } + + // PHASE 1: Precompute thread-constant values + int global_row_base = row_block_pid * BLOCK_ROWS; + int row_idx = tid / THREADS_PER_ROW; + int col_idx = tid % THREADS_PER_ROW; + int global_row = global_row_base + row_idx; + bool row_valid = (global_row < scale_rows); + + int r_div_32 = row_idx >> 5; // row / 32 + int r_mod_32 = row_idx & 31; // row % 32 + int swizzle_base = (r_mod_32 << 4) + (r_div_32 << 2); // (row % 32) * 16 + (row / 32) * 4 + int thread_col_start = col_idx * 16; + int first_tile_idx = thread_col_start >> 2; // thread_col_start / 4 + + int out_group_base_offset = output_group_start_col * padded_rows; + int num_cols_in_group = input_group_end_col - input_group_start_col; + int num_tiles_in_group = ceil_div(num_cols_in_group, BLOCK_COLS); + int tiles_stride_per_row_block = num_tiles_in_group * TILE_SIZE; + + const uint8_t* row_base_ptr = scales_ptr + + static_cast(global_row) * scales_stride_dim0; + + // PHASE 2: Pipelined execution with double buffering + auto load_chunk_async = [&](int chunk_idx, int buf_idx) { + int curr_chunk_in_group = first_chunk_in_group + chunk_idx; + int curr_input_start_col = input_group_start_col + curr_chunk_in_group * MAX_COLS; + int cols_remaining = input_group_end_col - curr_input_start_col; + int cols_to_load = min(MAX_COLS, cols_remaining); + bool can_load = row_valid && (thread_col_start < cols_to_load); + + if (can_load) { + const uint8_t* src_ptr = row_base_ptr + curr_input_start_col + thread_col_start; + uintptr_t gmem_addr = reinterpret_cast(src_ptr); + bool aligned = (gmem_addr % 16 == 0); + bool full_vec = (thread_col_start + 16 <= cols_to_load) && aligned; + + if (full_vec) { + asm volatile( + "cp.async.cg.shared.global [%0], [%1], 16;\n" + : + : "l"(&smem[buf_idx][row_idx * MAX_COLS + thread_col_start]), + "l"(src_ptr) + ); + } else { + uint4 data = make_uint4(0, 0, 0, 0); + uint8_t* bytes = reinterpret_cast(&data); + int bytes_to_load = min(16, cols_to_load - thread_col_start); + #pragma unroll + for (int i = 0; i < 16; i++) { + if (i < bytes_to_load) bytes[i] = __ldg(src_ptr + i); + } + *reinterpret_cast(&smem[buf_idx][row_idx * MAX_COLS + thread_col_start]) = data; + } + } else { + *reinterpret_cast(&smem[buf_idx][row_idx * MAX_COLS + thread_col_start]) = make_uint4(0, 0, 0, 0); + } + }; + + // Process chunk: read from linear SMEM, swizzle, store to GMEM + auto process_chunk = [&](int chunk_idx, int buf_idx) { + int curr_chunk_in_group = first_chunk_in_group + chunk_idx; + int curr_input_start_col = input_group_start_col + curr_chunk_in_group * MAX_COLS; + int cols_remaining = input_group_end_col - curr_input_start_col; + int cols_to_load = min(MAX_COLS, cols_remaining); + + uint4 data = *reinterpret_cast(&smem[buf_idx][row_idx * MAX_COLS + thread_col_start]); + + __syncthreads(); + + int tile_xor = (col_idx & 3) << 2; // col_idx % 4 + int superrow_xor = ((swizzle_base >> 7) & 3) << 2; // ((swizzle_base / 128) % 4) * 4 + int combined_xor = tile_xor ^ superrow_xor; + + uint32_t* data32 = reinterpret_cast(&data); + #pragma unroll + for (int t = 0; t < TILES_PER_THREAD; t++) { + int tile_idx = first_tile_idx + t; + int tile_base = tile_idx 
* SCALE_FACTOR_ROWS * BLOCK_COLS; + int swizzled_idx = tile_base + (swizzle_base ^ combined_xor); + *reinterpret_cast(&smem[buf_idx][swizzled_idx]) = data32[t]; + } + + __syncthreads(); + + // Compute output pointer: skip past tiles from previous chunks in this group + int tiles_before_this_chunk = curr_chunk_in_group * (MAX_COLS / BLOCK_COLS); + uint8_t* out_base = output_scales_ptr + out_group_base_offset + + row_block_pid * tiles_stride_per_row_block + + tiles_before_this_chunk * TILE_SIZE; + + int num_tiles_this_chunk = ceil_div(cols_to_load, BLOCK_COLS); + int bytes_to_copy = num_tiles_this_chunk * TILE_SIZE; + + // Each thread writes 16 bytes (4 tiles worth of data for its row position) + int byte_offset = tid * 16; + if (byte_offset < bytes_to_copy) { + uint32_t out_data[4]; + + // Read 4 uint32s from swizzled SMEM layout, accounting for writer's XOR pattern + #pragma unroll + for (int i = 0; i < 4; i++) { + int out_byte = byte_offset + i * 4; + int tile_idx = out_byte / TILE_SIZE; + int within_tile_offset = out_byte % TILE_SIZE; + + int writer_col_idx = (tile_idx / TILES_PER_THREAD) % THREADS_PER_ROW; + int writer_tile_xor = (writer_col_idx & 3) << 2; // (writer_col % 4) * 4 + int writer_superrow_xor = ((within_tile_offset >> 7) & 3) << 2; // ((within_tile_offset / 128) % 4) * 4 + int writer_combined_xor = writer_tile_xor ^ writer_superrow_xor; + int smem_addr = tile_idx * TILE_SIZE + (within_tile_offset ^ writer_combined_xor); + out_data[i] = *reinterpret_cast(&smem[buf_idx][smem_addr]); + } + + *reinterpret_cast(out_base + byte_offset) = + *reinterpret_cast(out_data); + } + }; + + if (num_chunks_to_process == 1) + { + load_chunk_async(0, 0); + asm volatile("cp.async.commit_group;\n"); + asm volatile("cp.async.wait_group 0;\n"); + __syncthreads(); + process_chunk(0, 0); + } + else + { + // PROLOGUE: Load first chunk + load_chunk_async(0, 0); + asm volatile("cp.async.commit_group;\n"); + + // STEADY STATE: Overlap load N+1 with processing N + for (int chunk = 0; chunk < num_chunks_to_process - 1; chunk++) { + int curr_buf = chunk & 1; // chunk % 2 + int next_buf = (chunk + 1) & 1; // (chunk + 1) % 2 + + // Kick off async load of chunk N+1 + load_chunk_async(chunk + 1, next_buf); + asm volatile("cp.async.commit_group;\n"); + + // Wait for async load of chunk N, without waiting for chunk N+1, to achieve overlap. + // async loads completed in commit order (FIFO) so this should work. + asm volatile("cp.async.wait_group 1;\n"); + __syncthreads(); + + // Process chunk N, overlapping with async load of chunk N+1. 
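+      // Chunk N lives in buffer (chunk % 2) while chunk N+1 streams into the other
+      // buffer, so this read never overlaps the in-flight cp.async writes.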
+ process_chunk(chunk, curr_buf); + } + + // EPILOGUE: Process final chunk + int last_chunk = num_chunks_to_process - 1; + int last_buf = last_chunk & 1; + asm volatile("cp.async.wait_group 0;\n"); + __syncthreads(); + process_chunk(last_chunk, last_buf); + } +} + + +template __global__ void mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined_kernel<64, 4>( + const uint8_t* __restrict__, int, int, int, int, + const int32_t* __restrict__, uint8_t* __restrict__, int, int); + +template __global__ void mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined_kernel<64, 8>( + const uint8_t* __restrict__, int, int, int, int, + const int32_t* __restrict__, uint8_t* __restrict__, int, int); + +template __global__ void mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined_kernel<64, 16>( + const uint8_t* __restrict__, int, int, int, int, + const int32_t* __restrict__, uint8_t* __restrict__, int, int); + +template __global__ void mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined_kernel<128, 4>( + const uint8_t* __restrict__, int, int, int, int, + const int32_t* __restrict__, uint8_t* __restrict__, int, int); + +template __global__ void mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined_kernel<128, 8>( + const uint8_t* __restrict__, int, int, int, int, + const int32_t* __restrict__, uint8_t* __restrict__, int, int); + +template __global__ void mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined_kernel<128, 16>( + const uint8_t* __restrict__, int, int, int, int, + const int32_t* __restrict__, uint8_t* __restrict__, int, int); + +namespace mxfp8 { + +void launch_mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined( + const uint8_t* scales_ptr, + int scales_stride_dim0, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + int max_cols, // Template selector: 64 or 128 + int chunks_per_tb, // Chunks per super-block: 4, 8, or 16 + cudaStream_t stream +) { + int num_row_blocks = (scale_rows + BLOCK_ROWS - 1) / BLOCK_ROWS; + int output_stride_per_block = BLOCK_ROWS * BLOCK_COLS; + + int total_chunks = (scale_cols + max_cols - 1) / max_cols + num_groups; + int total_super_col_blocks = (total_chunks + chunks_per_tb - 1) / chunks_per_tb + num_groups; + + dim3 grid(total_super_col_blocks, num_row_blocks); + + if (max_cols == 64) { + dim3 block(512); + switch (chunks_per_tb) { + case 4: + mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined_kernel<64, 4><<>>( + scales_ptr, scales_stride_dim0, scale_rows, scale_cols, padded_rows, + input_group_end_offsets, output_scales_ptr, output_stride_per_block, num_groups); + break; + case 8: + mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined_kernel<64, 8><<>>( + scales_ptr, scales_stride_dim0, scale_rows, scale_cols, padded_rows, + input_group_end_offsets, output_scales_ptr, output_stride_per_block, num_groups); + break; + case 16: + mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined_kernel<64, 16><<>>( + scales_ptr, scales_stride_dim0, scale_rows, scale_cols, padded_rows, + input_group_end_offsets, output_scales_ptr, output_stride_per_block, num_groups); + break; + default: + printf("CUDA Error: chunks_per_tb must be 4, 8, or 16, got %d\n", chunks_per_tb); + return; + } + } else if (max_cols == 128) { + dim3 block(1024); + switch (chunks_per_tb) { + case 4: + mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined_kernel<128, 4><<>>( + scales_ptr, scales_stride_dim0, scale_rows, scale_cols, 
padded_rows, + input_group_end_offsets, output_scales_ptr, output_stride_per_block, num_groups); + break; + case 8: + mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined_kernel<128, 8><<>>( + scales_ptr, scales_stride_dim0, scale_rows, scale_cols, padded_rows, + input_group_end_offsets, output_scales_ptr, output_stride_per_block, num_groups); + break; + case 16: + mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined_kernel<128, 16><<>>( + scales_ptr, scales_stride_dim0, scale_rows, scale_cols, padded_rows, + input_group_end_offsets, output_scales_ptr, output_stride_per_block, num_groups); + break; + default: + printf("CUDA Error: chunks_per_tb must be 4, 8, or 16, got %d\n", chunks_per_tb); + return; + } + } else { + printf("CUDA Error: max_cols must be 64 or 128, got %d\n", max_cols); + return; + } + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("CUDA Error (pipelined max_cols=%d, chunks_per_tb=%d): %s\n", + max_cols, chunks_per_tb, cudaGetErrorString(err)); + } +} + +} // namespace mxfp8 diff --git a/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp index d51bf47b0a..3b49b4c1c0 100644 --- a/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp +++ b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp @@ -26,6 +26,20 @@ void mxfp8_quantize_3d_cuda(const at::Tensor &input, const std::string &fp8_format, const std::string &scaling_mode); + +void launch_mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined( + const uint8_t* scales_ptr, + int scales_stride_dim0, + int scale_rows, + int scale_cols, + int padded_rows, + const int32_t* input_group_end_offsets, + uint8_t* output_scales_ptr, + int num_groups, + int max_cols, // Max cols processed per thread block template selector: 64 or 128 + int tiles_per_tb, // Chunks per super-block: 4, 8, or 16 + cudaStream_t stream); + // Helper for tensor validation void check_cuda_tensor(const at::Tensor &t, const char *name) { TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor"); @@ -178,6 +192,75 @@ mxfp8_quantize_3d(const at::Tensor& input, int64_t scale_dim_n, return std::make_tuple(output_colwise, scales_colwise); } +// Converts e8m0 scale factors to blocked layout needed for MXFP8 Grouped GEMM. +// Layout transformation occurs per group, where groups are along the K dim / columns. 
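+// The output is allocated as (padded_rows, cols + 4 * num_groups): rows are padded up to
+// a multiple of 128, and the extra 4 columns per group are an upper bound on the padding
+// needed to round each group's column count up to a multiple of 4.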
+at::Tensor mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined( + at::Tensor scales_tensor, + at::Tensor input_group_end_offsets, + int64_t max_cols, + int64_t tiles_per_tb) { + + // Validate inputs + check_cuda_tensor(scales_tensor, "scales_tensor"); + check_cuda_tensor(input_group_end_offsets, "input_group_end_offsets"); + + TORCH_CHECK(scales_tensor.dim() == 2, "scales_tensor must be 2D"); + TORCH_CHECK(scales_tensor.is_contiguous(), "scales_tensor must be contiguous (row-major)"); + TORCH_CHECK(scales_tensor.scalar_type() == at::kFloat8_e8m0fnu, + "scales_tensor must be e8m0"); + TORCH_CHECK(input_group_end_offsets.scalar_type() == at::kInt, + "input_group_end_offsets must be int32"); + TORCH_CHECK(input_group_end_offsets.dim() == 1, + "input_group_end_offsets must be 1D"); + TORCH_CHECK(max_cols == 64 || max_cols == 128, + "max_cols must be 64 or 128, got: ", max_cols); + TORCH_CHECK(tiles_per_tb == 4 || tiles_per_tb == 8 || tiles_per_tb == 16, + "tiles_per_tb must be 4, 8, or 16, got: ", tiles_per_tb); + + c10::cuda::CUDAGuard device_guard(scales_tensor.device()); + + const int rows = scales_tensor.size(0); + const int cols = scales_tensor.size(1); + const int num_groups = input_group_end_offsets.size(0); + TORCH_CHECK(num_groups <= 32, "num_groups must be <= 32"); + + // Calculate blocks needed - uses 128-row blocks + const int BLOCK_ROWS = 128; + const int BLOCK_COLS = 4; + const int num_row_blocks = (rows + BLOCK_ROWS - 1) / BLOCK_ROWS; + const int padded_rows = num_row_blocks * BLOCK_ROWS; + + // Padding per group is variable/data dependent, so pad each group by upper bound + const int padded_cols = cols + num_groups * BLOCK_COLS; + + // Create output tensor + auto output = at::zeros({padded_rows, padded_cols}, + at::TensorOptions() + .dtype(scales_tensor.scalar_type()) + .device(scales_tensor.device())); + + // Get raw pointers - reinterpret float8 as uint8 + const uint8_t* scales_ptr = reinterpret_cast(scales_tensor.data_ptr()); + const int32_t* offsets_ptr = input_group_end_offsets.data_ptr(); + uint8_t* output_ptr = reinterpret_cast(output.data_ptr()); + + // Launch pipelined kernel with specified max_cols and tiles_per_tb + launch_mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined( + scales_ptr, + scales_tensor.stride(0), + rows, + cols, + padded_rows, + offsets_ptr, + output_ptr, + num_groups, + static_cast(max_cols), + static_cast(tiles_per_tb), + at::cuda::getCurrentCUDAStream()); + + return output; +} + } // namespace mxfp8 @@ -185,4 +268,5 @@ mxfp8_quantize_3d(const at::Tensor& input, int64_t scale_dim_n, TORCH_LIBRARY_IMPL(torchao, CUDA, m) { m.impl("mxfp8_quantize", &mxfp8::mxfp8_quantize); m.impl("mxfp8_quantize_3d", &mxfp8::mxfp8_quantize_3d); + m.impl("mx_block_rearrange_2d_K_groups", &mxfp8::mx_block_rearrange_2d_K_groups_rowmajor_128x4_vec_pipelined); } diff --git a/torchao/prototype/moe_training/kernels/mxfp8/quant.py b/torchao/prototype/moe_training/kernels/mxfp8/quant.py index 5c2c28112e..c5f92df369 100644 --- a/torchao/prototype/moe_training/kernels/mxfp8/quant.py +++ b/torchao/prototype/moe_training/kernels/mxfp8/quant.py @@ -730,6 +730,67 @@ def _fake_mxfp8_quantize_3d( scales = x.new_empty((E, N // scale_dim_n, K), dtype=torch.float8_e8m0fnu) return q_data, scales + # CUDA kernel for converting e8m0 scale factors to blocked layout on a per-group basis, + # where the groups are along the K/contracting dimension. 
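+    # Illustrative usage of the wrapper defined below (a sketch, not run here; assumes an
+    # sm100 build of the mxfp8 extension, and the names `scales` / `offsets` are hypothetical):
+    #
+    #   scales = torch.randint(0, 256, (128, 256), dtype=torch.uint8, device="cuda")
+    #   scales = scales.view(torch.float8_e8m0fnu)
+    #   # Group end offsets along the scale-K dim: 4 groups of 64 scale columns each.
+    #   offsets = torch.tensor([64, 128, 192, 256], dtype=torch.int32, device="cuda")
+    #   blocked = mx_block_rearrange_2d_K_groups_cuda(scales, offsets, max_cols=64, chunks_per_tb=4)
+    #   # -> shape (128, 272): rows padded to a multiple of 128, plus a 4-column padding budget per group.
+    #
+    # The custom op and its Python wrapper are registered below.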
+ lib.define( + "mx_block_rearrange_2d_K_groups(Tensor scales_tensor, Tensor input_group_end_offsets, int max_cols, int chunks_per_tb) -> Tensor", + tags=[torch._C.Tag.needs_fixed_stride_order], + ) + + def mx_block_rearrange_2d_K_groups_cuda( + scales_tensor: torch.Tensor, + input_group_end_offsets: torch.Tensor, + max_cols: int = 64, + chunks_per_tb: int = 4, + ) -> torch.Tensor: + """ + Rearranges an E8M0 tensor scale to block-scaled swizzle format on a per group basis, + where the groups are along the contraction dimension of the GEMM. + + This format is suitable for Tmem as described in NVIDIA documentation: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + scales_tensor: Input tensor containing e8m0 scales for each logical group of a target tensor. + input_group_end_offsets: tensor of int32 values representing group end indexes for the input scales + max_cols (int, optional): Maximum columns processed per 128x4 thread block. Defaults to 64 (i.e. 128x64 chunk size). + chunks_per_tb (int, optional): How many chunks to process per thread block (for pipelining). Defaults to 4. + """ + assert scales_tensor.ndim == 2, "scales tensor must be 2d" + assert scales_tensor.dtype == torch.float8_e8m0fnu, ( + "Expected dtype to be torch.float8_e8m0fnu" + ) + + return torch.ops.torchao.mx_block_rearrange_2d_K_groups.default( + scales_tensor, + input_group_end_offsets, + max_cols, + chunks_per_tb, + ) + + @torch.library.register_fake("torchao::mx_block_rearrange_2d_K_groups") + def _fake_mx_block_rearrange_2d_K_groups_cuda( + scales_tensor: torch.Tensor, + input_group_end_offsets: torch.Tensor, + max_cols: int, + chunks_per_tb: int, + ) -> torch.Tensor: + """Fake/meta implementation for mx_block_rearrange_2d_K_groups_cuda.""" + assert scales_tensor.ndim == 2, "scales tensor must be 2d" + assert scales_tensor.dtype == torch.float8_e8m0fnu, ( + "Expected dtype to be torch.float8_e8m0fnu" + ) + num_groups = input_group_end_offsets.shape[0] + M, total_K = scales_tensor.shape + + # Group sizes are dynamic, and must be padded to next multiple of 4. + # Therefore, when allocating a buffer for the output, we use upper bound + # padding 4 * num_groups. + blocked_scales = scales_tensor.new_empty( + (M, total_K + 4 * num_groups), dtype=torch.float8_e8m0fnu + ) + return blocked_scales + else: def mxfp8_quantize_cuda_3d(