|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import itertools |
| 8 | +from dataclasses import dataclass |
| 9 | +from typing import List |
| 10 | + |
| 11 | +import torch |
| 12 | +from tabulate import tabulate |
| 13 | +from tqdm import tqdm |
| 14 | + |
| 15 | +from benchmarks.utils import benchmark_cuda_function_in_microseconds |
| 16 | +from torchao.prototype.moe_training.kernels.mxfp8.quant import ( |
| 17 | + mx_block_rearrange_2d_K_groups_cuda, |
| 18 | + torch_to_blocked_2d_K_groups, |
| 19 | + triton_mx_block_rearrange_2d_K_groups, |
| 20 | +) |
| 21 | +from torchao.prototype.moe_training.utils import generate_jagged_offs |
| 22 | + |
| 23 | +device = torch.device("cuda") |
| 24 | + |
| 25 | +# Needed since changing args to function causes recompiles |
| 26 | +torch._dynamo.config.cache_size_limit = 1000 |
| 27 | + |
| 28 | + |
| 29 | +@dataclass(frozen=True) |
| 30 | +class ExperimentConfig: |
| 31 | + input_shape: tuple[int] |
| 32 | + num_groups: int |
| 33 | + version: str # "naive" or "parallel" |
| 34 | + |
| 35 | + |
| 36 | +@dataclass(frozen=True) |
| 37 | +class ExperimentResult: |
| 38 | + time_us: float |
| 39 | + mem_bw_gbps: float |
| 40 | + |
| 41 | + |
| 42 | +@dataclass(frozen=True) |
| 43 | +class Experiment: |
| 44 | + config: ExperimentConfig |
| 45 | + result: ExperimentResult |
| 46 | + |
| 47 | + |
| 48 | +def get_configs() -> List[ExperimentConfig]: |
| 49 | + # Llama4 and DSV3 671b shapes. Input activations are scaled along the total_M dim, which contains all the token groups. |
| 50 | + block_size = 32 |
| 51 | + input_shapes = [ |
| 52 | + (8192, 16384 // block_size), |
| 53 | + (8192, 32768 // block_size), |
| 54 | + (8192, 65536 // block_size), |
| 55 | + (8192, 131072 // block_size), |
| 56 | + (8192, 1048576 // block_size), |
| 57 | + (5120, 16384 // block_size), |
| 58 | + (5120, 32768 // block_size), |
| 59 | + (5120, 65536 // block_size), |
| 60 | + (5120, 131072 // block_size), |
| 61 | + (5120, 1048576 // block_size), |
| 62 | + (7168, 16384 // block_size), |
| 63 | + (7168, 32768 // block_size), |
| 64 | + (7168, 65536 // block_size), |
| 65 | + (7168, 131072 // block_size), |
| 66 | + (7168, 1048576 // block_size), |
| 67 | + (2048, 16384 // block_size), |
| 68 | + (2048, 32768 // block_size), |
| 69 | + (2048, 65536 // block_size), |
| 70 | + (2048, 131072 // block_size), |
| 71 | + (2048, 1048576 // block_size), |
| 72 | + ] |
| 73 | + num_groups = [8] |
| 74 | + versions = [ |
| 75 | + "torch", |
| 76 | + "triton", |
| 77 | + "cuda_pipelined_64_chunks4", |
| 78 | + ] |
| 79 | + |
| 80 | + configs = [] |
| 81 | + for shape, groups, version in itertools.product( |
| 82 | + input_shapes, |
| 83 | + num_groups, |
| 84 | + versions, |
| 85 | + ): |
| 86 | + configs.append( |
| 87 | + ExperimentConfig( |
| 88 | + input_shape=shape, |
| 89 | + num_groups=groups, |
| 90 | + version=version, |
| 91 | + ) |
| 92 | + ) |
| 93 | + return configs |
| 94 | + |
| 95 | + |
| 96 | +def run_experiment(config: ExperimentConfig) -> ExperimentResult: |
| 97 | + input_shape, num_groups, version = ( |
| 98 | + config.input_shape, |
| 99 | + config.num_groups, |
| 100 | + config.version, |
| 101 | + ) |
| 102 | + input_tensor = torch.randint( |
| 103 | + low=0, |
| 104 | + high=256, |
| 105 | + size=input_shape, |
| 106 | + dtype=torch.uint8, |
| 107 | + device=device, |
| 108 | + ) |
| 109 | + |
| 110 | + M, Kg = input_shape |
| 111 | + block_size = 32 |
| 112 | + input_group_offsets = generate_jagged_offs(num_groups, Kg, multiple_of=block_size) |
| 113 | + |
| 114 | + # Select which kernel to benchmark based on version |
| 115 | + if version == "torch": |
| 116 | + kernel_fn = torch_to_blocked_2d_K_groups |
| 117 | + kernel_input = input_tensor |
| 118 | + elif version == "triton": |
| 119 | + kernel_fn = triton_mx_block_rearrange_2d_K_groups |
| 120 | + # Triton uses row-major input |
| 121 | + kernel_input = input_tensor |
| 122 | + elif version == "cuda_pipelined_64_chunks4": |
| 123 | + kernel_fn = lambda t, o: mx_block_rearrange_2d_K_groups_cuda( |
| 124 | + t, |
| 125 | + o, |
| 126 | + ) |
| 127 | + kernel_input = input_tensor.view(torch.float8_e8m0fnu) |
| 128 | + else: |
| 129 | + raise ValueError(f"Unknown version: {version}") |
| 130 | + |
| 131 | + # Run kernel to get output shape |
| 132 | + outputs = kernel_fn( |
| 133 | + kernel_input, |
| 134 | + input_group_offsets, |
| 135 | + ) |
| 136 | + if isinstance(outputs, tuple): # torch returns a tuple with extra metadata |
| 137 | + out_scales, _ = outputs |
| 138 | + else: |
| 139 | + out_scales = outputs |
| 140 | + |
| 141 | + # Benchmark the kernel |
| 142 | + time_us = benchmark_cuda_function_in_microseconds( |
| 143 | + kernel_fn, |
| 144 | + kernel_input, |
| 145 | + input_group_offsets, |
| 146 | + ) |
| 147 | + |
| 148 | + # Calculate memory bandwidth |
| 149 | + bytes_per_input_el = torch.finfo(torch.float8_e8m0fnu).bits / 8 |
| 150 | + bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8 |
| 151 | + |
| 152 | + read_bytes = input_tensor.numel() * bytes_per_input_el |
| 153 | + write_bytes = out_scales.numel() * bytes_per_output_el |
| 154 | + |
| 155 | + mem_bw_gbps = ((read_bytes + write_bytes) / 1e9) / (time_us / 1e6) |
| 156 | + |
| 157 | + return ExperimentResult( |
| 158 | + time_us=time_us, |
| 159 | + mem_bw_gbps=mem_bw_gbps, |
| 160 | + ) |
| 161 | + |
| 162 | + |
| 163 | +def print_results(experiments: List[Experiment]): |
| 164 | + # Group experiments by input shape |
| 165 | + shapes_dict = {} |
| 166 | + for exp in experiments: |
| 167 | + shape_key = exp.config.input_shape |
| 168 | + if shape_key not in shapes_dict: |
| 169 | + shapes_dict[shape_key] = {} |
| 170 | + shapes_dict[shape_key][exp.config.version] = exp.result |
| 171 | + |
| 172 | + headers = [ |
| 173 | + "kernel_version", |
| 174 | + "scale_shape", |
| 175 | + "time_us", |
| 176 | + "mem_bw_gbps", |
| 177 | + "fastest_version", |
| 178 | + "speedup_vs_torch", |
| 179 | + ] |
| 180 | + |
| 181 | + rows = [] |
| 182 | + for shape, versions in shapes_dict.items(): |
| 183 | + # Find fastest version for this shape |
| 184 | + fastest_version = min(versions.items(), key=lambda x: x[1].time_us)[0] |
| 185 | + |
| 186 | + # Get torch baseline time for speedup calculation |
| 187 | + torch_time_us = versions.get("torch").time_us if "torch" in versions else None |
| 188 | + |
| 189 | + # Add rows for each version |
| 190 | + for version, result in versions.items(): |
| 191 | + # Calculate speedup vs triton |
| 192 | + speedup_str = "" |
| 193 | + if version != "torch": |
| 194 | + speedup = torch_time_us / result.time_us |
| 195 | + speedup_str = f"{speedup:.2f}x" |
| 196 | + |
| 197 | + rows.append( |
| 198 | + [ |
| 199 | + version, |
| 200 | + f"({shape[0]}, {shape[1]})", |
| 201 | + f"{result.time_us:.2f}", |
| 202 | + round(result.mem_bw_gbps, 3), |
| 203 | + fastest_version, |
| 204 | + speedup_str, |
| 205 | + ] |
| 206 | + ) |
| 207 | + |
| 208 | + print(tabulate(rows, headers=headers)) |
| 209 | + |
| 210 | + |
| 211 | +def main(): |
| 212 | + torch.random.manual_seed(123) |
| 213 | + configs = get_configs() |
| 214 | + results = [] |
| 215 | + for config in tqdm(configs): |
| 216 | + result = run_experiment(config) |
| 217 | + results.append(Experiment(config=config, result=result)) |
| 218 | + |
| 219 | + # Use Tabulate to print results |
| 220 | + print_results(results) |
| 221 | + |
| 222 | + |
| 223 | +if __name__ == "__main__": |
| 224 | + main() |
0 commit comments