Skip to content

Conversation

@danielvegamyhre
Copy link
Contributor

@danielvegamyhre danielvegamyhre commented Dec 17, 2025

Stacked PRs:


[mxfp8 moe training] add CUDA kernel for per-group conversion of scale factors to blocked layout

TL;DR this PR adds a new CUDA kernel for blocked layout scale factors with groups along K for 2d2d grouped GEMM that is ~2x to 3.5x faster than existing Triton kernel. For the DSV3 shapes we care most about, it's about 2.5x to 3x faster than Triton.

Kernel design

Summary

  • 128x4 thread blocks load 128x64 chunks of row major scale data using coalesced vectorized uint4 (16 byte) loads from global memory. (templating allows for other configurations but I got the best results with this).
  • Each thread has 16 bytes in register memory. It writes this to SMEM in 4 separate 4-byte contiguous chunks (uint32 write) to the appropriate locations in shared memory in order to perform the transformation to per-grouped blocked layout.
  • Avoiding bank conflicts:
    • XOR swizzle based on base column index in smem used to avoid bank conflicts on writes (tile_xor).
    • 2nd XOR swizzle based on "supperow" index (which 128 byte chunk) to avoid bank conflicts on reads when copying the result to GMEM (superrow_xor)
    • Compose these 2 via tile_xor ^ superrow_xor to avoid bank conflicts on both reads and writes.
  • Copy data from SMEM to GMEM via coalesced, vectorized uint4/16 byte writes. It first goes through registers because we need to unswizzle first.
  • Pipeline async load of chunk N+1 with processing of chunk N.

Benchmarks

  • Ranges from 1.8x to 3.5x faster than triton depending on input shape. Speedup is better for larger shapes.
  • Shapes below are (model_dim, total_M // 32) since this is for 2d2d grouped gemm where we scale along the total_M dim.
  • total_M (local batch size * seq len) ranges from 32768 to 131072 (scale widths of 1024 and 4096, respectively)
  • Ran benchmarks with all combinations of "chunk width" and "chunks per threadblock."

Note: memory bandwidth utilization is not the best metric for this kernel since the ALU pipelines are heavily utilized as well. I am just using it as a proxy here while also referencing NCU to verify there's no substantial performance issues remaining.

kernel_version                             scale_shape    time_us    mem_bw_gbps    speedup_vs_torch    speedup_vs_triton
-----------------------------------------  -------------  ---------  -------------  ------------------  -------------------
torch                                      (8192, 1024)   283.42     60.12
triton                                     (8192, 1024)   51.30      332.177        5.53x
cuda_64_4                                  (8192, 1024)   17.47      975.238        16.22x              2.94x
cuda_64_8                                  (8192, 1024)   17.50      973.455        16.19x              2.93x
cuda_64_16                                 (8192, 1024)   17.47      975.238        16.22x              2.94x
cuda_128_4                                 (8192, 1024)   21.54      791.204        13.16x              2.38x
cuda_128_8                                 (8192, 1024)   21.54      791.204        13.16x              2.38x
cuda_128_16                                (8192, 1024)   19.49      874.351        14.54x              2.63x
>>> BEST: 2.94x vs triton with cuda_64_4

torch                                      (8192, 2048)   432.67     78.158
triton                                     (8192, 2048)   51.20      660.48         8.45x
cuda_64_4                                  (8192, 2048)   23.74      1424.216       18.22x              2.16x
cuda_64_8                                  (8192, 2048)   23.71      1426.138       18.25x              2.16x
cuda_64_16                                 (8192, 2048)   27.65      1223.111       15.65x              1.85x
cuda_128_4                                 (8192, 2048)   27.81      1216.074       15.56x              1.84x
cuda_128_8                                 (8192, 2048)   27.78      1217.475       15.58x              1.84x
cuda_128_16                                (8192, 2048)   26.66      1268.629       16.23x              1.92x
>>> BEST: 2.16x vs triton with cuda_64_8

torch                                      (8192, 4096)   444.40     151.6
triton                                     (8192, 4096)   119.65     563.077        3.71x
cuda_64_4                                  (8192, 4096)   33.63      2003.182       13.21x              3.56x
cuda_64_8                                  (8192, 4096)   35.65      1889.896       12.47x              3.36x
cuda_64_16                                 (8192, 4096)   37.89      1778.162       11.73x              3.16x
cuda_128_4                                 (8192, 4096)   39.74      1695.124       11.18x              3.01x
cuda_128_8                                 (8192, 4096)   39.97      1685.624       11.12x              2.99x
cuda_128_16                                (8192, 4096)   46.11      1461.03        9.64x               2.59x
>>> BEST: 3.56x vs triton with cuda_64_4

torch                                      (5120, 1024)   429.76     24.78
triton                                     (5120, 1024)   47.30      225.169        9.09x
cuda_64_4                                  (5120, 1024)   15.42      690.456        27.86x              3.07x
cuda_64_8                                  (5120, 1024)   15.36      693.333        27.98x              3.08x
cuda_64_16                                 (5120, 1024)   15.36      693.333        27.98x              3.08x
cuda_128_4                                 (5120, 1024)   19.46      547.368        22.09x              2.43x
cuda_128_8                                 (5120, 1024)   19.46      547.368        22.09x              2.43x
cuda_128_16                                (5120, 1024)   17.41      611.765        24.69x              2.72x
>>> BEST: 3.08x vs triton with cuda_64_8

torch                                      (5120, 2048)   439.90     48.045
triton                                     (5120, 2048)   46.11      458.348        9.54x
cuda_64_4                                  (5120, 2048)   19.49      1084.532       22.57x              2.37x
cuda_64_8                                  (5120, 2048)   21.50      982.857        20.46x              2.14x
cuda_64_16                                 (5120, 2048)   21.50      982.857        20.46x              2.14x
cuda_128_4                                 (5120, 2048)   23.55      897.391        18.68x              1.96x
cuda_128_8                                 (5120, 2048)   23.55      897.391        18.68x              1.96x
cuda_128_16                                (5120, 2048)   23.49      899.837        18.73x              1.96x
>>> BEST: 2.37x vs triton with cuda_64_4

torch                                      (5120, 4096)   440.38     95.614
triton                                     (5120, 4096)   70.66      595.942        6.23x
cuda_64_4                                  (5120, 4096)   25.60      1644.8         17.20x              2.76x
cuda_64_8                                  (5120, 4096)   27.65      1522.963       15.93x              2.56x
cuda_64_16                                 (5120, 4096)   29.70      1417.931       14.83x              2.38x
cuda_128_4                                 (5120, 4096)   29.70      1417.931       14.83x              2.38x
cuda_128_8                                 (5120, 4096)   31.74      1326.452       13.87x              2.23x
cuda_128_16                                (5120, 4096)   33.63      1251.989       13.09x              2.10x
>>> BEST: 2.76x vs triton with cuda_64_4

torch                                      (7168, 1024)   431.89     34.522
triton                                     (7168, 1024)   48.42      307.944        8.92x
cuda_64_4                                  (7168, 1024)   17.44      854.899        24.76x              2.78x
cuda_64_8                                  (7168, 1024)   15.52      960.66         27.83x              3.12x
cuda_64_16                                 (7168, 1024)   17.44      854.899        24.76x              2.78x
cuda_128_4                                 (7168, 1024)   19.49      765.057        22.16x              2.48x
cuda_128_8                                 (7168, 1024)   19.49      765.057        22.16x              2.48x
cuda_128_16                                (7168, 1024)   19.49      765.057        22.16x              2.48x
>>> BEST: 3.12x vs triton with cuda_64_8

torch                                      (7168, 2048)   432.21     68.461
triton                                     (7168, 2048)   52.26      566.241        8.27x
cuda_64_4                                  (7168, 2048)   21.54      1373.955       20.07x              2.43x
cuda_64_8                                  (7168, 2048)   23.55      1256.348       18.35x              2.22x
cuda_64_16                                 (7168, 2048)   25.63      1154.397       16.86x              2.04x
cuda_128_4                                 (7168, 2048)   25.63      1154.397       16.86x              2.04x
cuda_128_8                                 (7168, 2048)   27.62      1071.462       15.65x              1.89x
cuda_128_16                                (7168, 2048)   23.58      1254.643       18.33x              2.22x
>>> BEST: 2.43x vs triton with cuda_64_4

torch                                      (7168, 4096)   477.47     123.462
triton                                     (7168, 4096)   78.88      747.333        6.05x
cuda_64_4                                  (7168, 4096)   31.71      1858.906       15.06x              2.49x
cuda_64_8                                  (7168, 4096)   31.74      1857.032       15.04x              2.48x
cuda_64_16                                 (7168, 4096)   35.87      1643.333       13.31x              2.20x
cuda_128_4                                 (7168, 4096)   35.84      1644.8         13.32x              2.20x
cuda_128_8                                 (7168, 4096)   37.86      1557.207       12.61x              2.08x
cuda_128_16                                (7168, 4096)   37.73      1562.49        12.66x              2.09x
>>> BEST: 2.49x vs triton with cuda_64_4

torch                                      (2048, 1024)   434.90     9.795
triton                                     (2048, 1024)   48.61      87.637         8.95x
cuda_64_4                                  (2048, 1024)   15.17      280.844        28.67x              3.20x
cuda_64_8                                  (2048, 1024)   15.23      279.664        28.55x              3.19x
cuda_64_16                                 (2048, 1024)   14.43      295.166        30.13x              3.37x
cuda_128_4                                 (2048, 1024)   17.31      246.063        25.12x              2.81x
cuda_128_8                                 (2048, 1024)   15.23      279.664        28.55x              3.19x
cuda_128_16                                (2048, 1024)   15.20      280.253        28.61x              3.20x
>>> BEST: 3.37x vs triton with cuda_64_16

torch                                      (2048, 2048)   430.05     19.659
triton                                     (2048, 2048)   50.18      168.49         8.57x
cuda_64_4                                  (2048, 2048)   15.36      550.4          28.00x              3.27x
cuda_64_8                                  (2048, 2048)   17.44      484.756        24.66x              2.88x
cuda_64_16                                 (2048, 2048)   21.50      393.143        20.00x              2.33x
cuda_128_4                                 (2048, 2048)   17.41      485.647        24.70x              2.88x
cuda_128_8                                 (2048, 2048)   19.52      433.102        22.03x              2.57x
cuda_128_16                                (2048, 2048)   19.46      434.526        22.10x              2.58x
>>> BEST: 3.27x vs triton with cuda_64_4

torch                                      (2048, 4096)   425.26     39.605
triton                                     (2048, 4096)   52.06      323.501        8.17x
cuda_64_4                                  (2048, 4096)   17.44      965.754        24.38x              2.99x
cuda_64_8                                  (2048, 4096)   19.49      864.263        21.82x              2.67x
cuda_64_16                                 (2048, 4096)   23.58      714.16         18.03x              2.21x
cuda_128_4                                 (2048, 4096)   17.57      958.718        24.21x              2.96x
cuda_128_8                                 (2048, 4096)   23.55      715.13         18.06x              2.21x
cuda_128_16                                (2048, 4096)   21.54      782.074        19.75x              2.42x
>>> BEST: 2.99x vs triton with cuda_64_4

Super detailed explanation for anyone interested

Loads from GMEM

Screenshot 2025-12-19 at 2 07 52 PM

Stores to SMEM (ignoring XOR swizzle for avoiding bank conflicts on writes)

Screenshot 2025-12-19 at 2 33 59 PM

Dual XOR swizzle to base SMEM address for each thread to resolve bank conflicts on both writes and reads

(pasting from a google doc blog post draft for this description, so the bit highlighting is visible)
Screenshot 2025-12-19 at 2 35 49 PM

Screenshot 2025-12-19 at 2 09 23 PM

Copy SMEM to GMEM

Here we are doing a linear read pattern from SMEM, doing vectorized coalesced uint4 loads from shared memory and storing in global memory. The issue here is 4 way bank conflicts, but I found no good solution here. We have 32 banks of 4 bytes each, so that's 128 total bytes before wrapping around and hitting the first bank again. This means with 8 threads reading 16 bytes each, 8 * 16 = 128, so every 8 threads in a warp will wrap around and experience a bank conflict for a total of a 4-way bank conflict per warp.

The best solution I found for this was a suggestion from Claude, actually, to basically add a second layer of XOR that shifts the bank every 128 bytes within a 512-byte tile (128 * 4 = 512). Composing this XOR with the existing XOR does avoid bank conflicts for both reads and writes. Composing two XORs like this forms a "Latin square" kind of like Sudoku where given a tile_xor, we have unique banks for every superrow xor (and vice versa).

(diagram coming)

Pipeline async load of chunk N+1 with processing of chunk N

This kernel is instruction-heavy and has a lot of complicated pointer math. It does the prefix sums, then does the lookup of which group this thread block is operating on and block layout transformations. With NCU, I saw the ALU was highly utilized. Therefore, I thought it would be useful to not block the next load from global memory on all of this heavy pointer math going on. So I implemented a pipelined approach with double buffering where we can overlap the load of the next chunk from global memory with the computation needed on the current chunk. Compared to a non-pipelined approach, this approach showed solid improvements on large shapes and neutral for small and medium shapes.

I templated the number of chunks per thread block, benchmarked a few values, and have just hard-coded the config that was best.

@pytorch-bot
Copy link

pytorch-bot bot commented Dec 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3504

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 243017f with merge base 7035fb7 (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 17, 2025
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/86 branch from 26e079e to 0dfef18 Compare December 17, 2025 23:53
@danielvegamyhre danielvegamyhre added mx topic: not user facing Use this tag if you don't want this PR to show up in release notes moe labels Dec 17, 2025
@danielvegamyhre danielvegamyhre marked this pull request as draft December 19, 2025 01:55
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/86 branch from 0dfef18 to 11101bc Compare December 19, 2025 01:55
@danielvegamyhre danielvegamyhre marked this pull request as ready for review December 19, 2025 01:55
@danielvegamyhre danielvegamyhre marked this pull request as draft December 19, 2025 01:59
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/86 branch from 11101bc to 7497fcf Compare December 19, 2025 02:00
@danielvegamyhre danielvegamyhre marked this pull request as ready for review December 19, 2025 02:00
@danielvegamyhre danielvegamyhre marked this pull request as draft December 19, 2025 02:01
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/86 branch from 7497fcf to c828ad0 Compare December 19, 2025 02:02
@danielvegamyhre danielvegamyhre marked this pull request as ready for review December 19, 2025 02:02
@danielvegamyhre danielvegamyhre marked this pull request as draft December 20, 2025 01:23
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/86 branch from c828ad0 to 24cbddd Compare December 20, 2025 01:24
@danielvegamyhre danielvegamyhre marked this pull request as ready for review December 20, 2025 01:24
@danielvegamyhre danielvegamyhre marked this pull request as draft December 20, 2025 19:40
@danielvegamyhre danielvegamyhre marked this pull request as ready for review December 20, 2025 19:40
@danielvegamyhre danielvegamyhre marked this pull request as draft December 20, 2025 19:43
@danielvegamyhre danielvegamyhre marked this pull request as ready for review December 20, 2025 19:45
@danielvegamyhre danielvegamyhre marked this pull request as draft December 21, 2025 02:12
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/86 branch from 24cbddd to 492b8ce Compare December 21, 2025 02:13
@danielvegamyhre danielvegamyhre marked this pull request as ready for review December 21, 2025 02:13
@danielvegamyhre danielvegamyhre marked this pull request as draft December 21, 2025 02:16
@danielvegamyhre danielvegamyhre marked this pull request as ready for review December 21, 2025 02:16
@danielvegamyhre danielvegamyhre marked this pull request as draft December 21, 2025 02:19
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/86 branch from 492b8ce to 29a521a Compare December 21, 2025 02:19
@danielvegamyhre danielvegamyhre marked this pull request as ready for review December 21, 2025 02:20
@danielvegamyhre danielvegamyhre marked this pull request as draft December 21, 2025 02:34
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/86 branch from 29a521a to afc2997 Compare December 21, 2025 02:34
@danielvegamyhre danielvegamyhre marked this pull request as ready for review December 21, 2025 02:34
@danielvegamyhre danielvegamyhre marked this pull request as draft December 22, 2025 01:19
@danielvegamyhre danielvegamyhre marked this pull request as ready for review December 22, 2025 01:19
@danielvegamyhre danielvegamyhre marked this pull request as draft December 22, 2025 01:25
@danielvegamyhre danielvegamyhre marked this pull request as ready for review December 22, 2025 01:26
…e factors to blocked layout

stack-info: PR: #3504, branch: danielvegamyhre/stack/86
@danielvegamyhre danielvegamyhre marked this pull request as draft December 22, 2025 01:29
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/86 branch from afc2997 to 243017f Compare December 22, 2025 01:30
@danielvegamyhre danielvegamyhre marked this pull request as ready for review December 22, 2025 01:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. moe mx topic: not user facing Use this tag if you don't want this PR to show up in release notes

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants