
Conversation

vkuzo (Contributor) commented Oct 29, 2025

Summary:

Basic enablement of the a_1_128_w_128_128 float8 scaling recipe in
torchao inference. In detail:

  1. updates the triton kernels for this scaling type to (a) be importable in an environment without triton (for CI), and (b) support torch.compile for the gemm
  2. enables the new granularity in various utility functions
  3. wires the new granularity through the float8 inference configs
  4. adds a test for e2e numerical correctness via SQNR comparison against a high-precision baseline

For now we only have fallback kernels, which require triton and are numerically
correct but may not reach optimal performance. Performance optimization is
left for future PRs:

  1. we should map the gemm to torch._scaled_mm for CUDA 12.9+
  2. we should enable an fbgemm_gpu_genai path, if available in user env
  3. we should map to a triton kernel for quantizing the weights, as
    torch.compile is currently known to be slow for 128x128 block
    quantization

Further accuracy testing and enablement of more features are left for future PRs, to keep this PR small.
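For reference, here is a minimal sketch of the scaling math behind the recipe name: one float8 scale per 1x128 block of the activation and one per 128x128 block of the weight. This is an illustration only; the helper name and tensor layout are assumptions, not the kernels added in this PR:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def quantize_a_1_128_w_128_128(a: torch.Tensor, w: torch.Tensor, block: int = 128):
    """Illustrative blockwise float8 quantization for an (M, K) activation and (N, K) weight.

    Assumes K and N are multiples of `block`.
    """
    # activation: one scale per contiguous 1x128 block along the K dimension
    M, K = a.shape
    a_blocks = a.reshape(M, K // block, block)
    a_scale = a_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    a_fp8 = (a_blocks / a_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn).reshape(M, K)

    # weight: one scale per 128x128 tile
    N, K = w.shape
    w_tiles = w.reshape(N // block, block, K // block, block)
    w_scale = w_tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12) / FP8_MAX
    w_fp8 = (w_tiles / w_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn).reshape(N, K)

    return a_fp8, a_scale.squeeze(-1), w_fp8, w_scale.squeeze(3).squeeze(1)
```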

Test Plan:

pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -s -x
pytest test/dtypes/test_affine_quantized_float.py -s -x
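The SQNR comparison in the new test measures how close the float8 output is to the high-precision baseline. A rough sketch of the metric (the helper actually used in torchao's test utilities may differ):

```python
import torch

def sqnr(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in dB; higher means closer to the baseline
    error = reference - candidate
    return 10 * torch.log10(reference.float().pow(2).mean() / error.float().pow(2).mean())

# usage in a test would look roughly like:
# assert sqnr(out_hp, out_float8) > threshold_db  # threshold is illustrative
```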


vkuzo added 3 commits October 29, 2025 04:05


pytorch-bot bot commented Oct 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3257

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit c4769a6 with merge base 1e473ed:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

vkuzo added a commit that referenced this pull request Oct 29, 2025
Summary:

Basic enablement of the a_1_128_w_128_128 float8 scaling recipe in
torchao inference. In detail:
1. bring the 128x128 gemm triton kernel we have out of prototype and
   wrap it with a custom op for `torch.compile` compatibility
2. enable the new granularity in various utility functions
3. wire the new granularity through the float8 inference configs
4. add a test which tests for e2e numerical correctness via SQNR
   comparison vs high precision baseline

For now I added a fallback which only requires triton and is numerically
correct but may not reach optimal performance. Performance optimization is
left for future PRs:
1. we should map the gemm to `torch._scaled_mm` for CUDA 12.9+
2. we should enable an fbgemm_gpu_genai path, if available in user env
3. we should map to a triton kernel for quantizing the weights, as
   `torch.compile` is currently known to be slow for 128x128 block
   quantization

ghstack-source-id: db464e1
ghstack-comment-id: 3460951962
Pull-Request: #3257
meta-cla bot added the CLA Signed label Oct 29, 2025
vkuzo changed the title from "Add a_1_128_w_128_128 (DeepSeek style) float8 scaling for inference" to "(wip) Add a_1_128_w_128_128 (DeepSeek style) float8 scaling for inference" Oct 29, 2025
vkuzo added the topic: new feature label Oct 29, 2025
vkuzo added 2 commits October 29, 2025 07:07
vkuzo added a commit that referenced this pull request Oct 29, 2025
vkuzo added a commit that referenced this pull request Oct 29, 2025
vkuzo changed the title from "(wip) Add a_1_128_w_128_128 (DeepSeek style) float8 scaling for inference" to "skeleton of a_1_128_w_128_128 (DeepSeek) float8 scaling for inference" Oct 29, 2025
vkuzo changed the title from "skeleton of a_1_128_w_128_128 (DeepSeek) float8 scaling for inference" to "add a_1_128_w_128_128 (DeepSeek) float8 scaling for inference" Oct 29, 2025
vkuzo added a commit that referenced this pull request Oct 29, 2025
triton.cdiv(N, meta["BLOCK_SIZE"]),
from torch.utils._triton import has_triton

if has_triton():
vkuzo (author): most of the changes in this file are just indentation from adding the `if has_triton()` statement
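The guard the comment refers to looks roughly like the following sketch; the actual contents of the `else:` branch in the PR are not shown here, and `TRITON_AVAILABLE` is a hypothetical flag used only for illustration:

```python
from torch.utils._triton import has_triton

if has_triton():
    import triton
    import triton.language as tl

    # the @triton.jit blockwise kernels and their wrappers live in this branch,
    # now indented one level relative to the previous version of the file
    TRITON_AVAILABLE = True
else:
    # module stays importable in environments without triton (e.g. CI);
    # callers can check this flag (or has_triton()) before using the kernels
    TRITON_AVAILABLE = False
```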

mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, c, mask=mask)

@torch.library.custom_op("ao::blockwise_fp8_gemm", mutates_args=())
vkuzo (author): non-indent change 1

)
return c

@blockwise_fp8_gemm.register_fake
vkuzo (author): non-indent change 2
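The two non-indent changes follow the usual torch.library pattern for making a triton-backed op work with torch.compile: register the eager wrapper as a custom op (so the compiler treats the kernel launch as opaque) and register a fake implementation that only produces output metadata for tracing. A minimal sketch of that pattern, using a hypothetical namespace and illustrative shapes/dtypes rather than the exact torchao signature:

```python
import torch

@torch.library.custom_op("demo::blockwise_fp8_gemm", mutates_args=())
def blockwise_fp8_gemm(
    a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor
) -> torch.Tensor:
    # eager path: launch the triton kernel (elided in this sketch)
    raise NotImplementedError("triton kernel launch elided")

@blockwise_fp8_gemm.register_fake
def _(a, a_s, b, b_s):
    # fake/meta path: only shape and dtype, so torch.compile can trace the op
    # (assumes an (M, K) x (N, K) -> (M, N) convention and a bfloat16 output)
    return a.new_empty(a.shape[0], b.shape[0], dtype=torch.bfloat16)
```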

fp8_blockwise_weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y

else:
vkuzo (author): non-indent change 3

vkuzo added 3 commits October 30, 2025 04:17
vkuzo changed the base branch from gh/vkuzo/157/head to main October 30, 2025 11:18
vkuzo added a commit that referenced this pull request Oct 30, 2025
jerryzh168 (Contributor) left a comment:

looks good, I think we can also keep the 1x128 block to align with what DeepSeek is calling it now

jerryzh168 (Contributor) left a comment:
oh since we are changing the meaning of PerBlock, please update the doc a bit as well:

class PerBlock(Granularity):
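For context, PerBlock describes the block shape over which a single scale is shared, so this recipe pairs a 1x128 activation granularity with a 128x128 weight granularity. A hedged sketch of how that might be spelled (the import path, constructor argument, and config field are assumptions about the API, not confirmed by this thread):

```python
from torchao.quantization.granularity import PerBlock

# one scale per 1x128 block of the activation, one per 128x128 block of the weight
act_granularity = PerBlock((1, 128))
weight_granularity = PerBlock((128, 128))

# roughly how the pair could flow into the float8 inference config:
# Float8DynamicActivationFloat8WeightConfig(granularity=[act_granularity, weight_granularity])
```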

vkuzo added 2 commits October 31, 2025 06:26
vkuzo (author) commented Oct 31, 2025

CI failures exist on the main branch; landing

vkuzo merged commit b49178c into main Oct 31, 2025
44 of 50 checks passed
Labels: CLA Signed, topic: new feature