-
Notifications
You must be signed in to change notification settings - Fork 31.3k
[Quantization] per tensor quantization kernel #42560
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+151
−47
Merged
Changes from 9 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
425b7ec
fix
MekkCyber 924de16
style
MekkCyber 8a69f40
Merge branch 'main' into fix-deqant-fp8
MekkCyber e052da3
initial
MekkCyber 62c8601
Merge remote-tracking branch 'upstream/fix-deqant-fp8' into use-kerne…
MekkCyber fe3359d
fix
MekkCyber 75f7e6f
comment
MekkCyber 1738ca0
style
MekkCyber 78c5459
Merge remote-tracking branch 'upstream/HEAD' into use-kernel-fp8
MekkCyber 2144e7c
Merge remote-tracking branch 'upstream/HEAD' into use-kernel-fp8
MekkCyber 033e535
fix
MekkCyber File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -156,6 +156,79 @@ def _w8a8_block_fp8_matmul( | |
| tl.store(c_ptrs, c, mask=c_mask) | ||
|
|
||
|
|
||
| @triton.jit | ||
| def _w8a8_block_fp8_matmul_per_tensor( | ||
| # Pointers to inputs and output | ||
| A, | ||
| B, | ||
| C, | ||
| As, | ||
| Bs, | ||
| # Shape for matmul | ||
| M, | ||
| N, | ||
| K, | ||
| # Block size for block-wise quantization | ||
| group_n, | ||
| group_k, | ||
| # Stride for inputs and output | ||
| stride_am, | ||
| stride_ak, | ||
| stride_bk, | ||
| stride_bn, | ||
| stride_cm, | ||
| stride_cn, | ||
| # Meta-parameters | ||
| BLOCK_SIZE_M: tl.constexpr, | ||
| BLOCK_SIZE_N: tl.constexpr, | ||
| BLOCK_SIZE_K: tl.constexpr, | ||
| GROUP_SIZE_M: tl.constexpr, | ||
| ): | ||
| """Triton-accelerated function used to perform linear operations (dot | ||
| product) on input tensors `A` and `B` with block-wise quantization, and | ||
| store the result in output tensor `C`. | ||
| """ | ||
|
|
||
| pid = tl.program_id(axis=0) | ||
| num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) | ||
| num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) | ||
| num_pid_in_group = GROUP_SIZE_M * num_pid_n | ||
| group_id = pid // num_pid_in_group | ||
| first_pid_m = group_id * GROUP_SIZE_M | ||
| group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) | ||
| pid_m = first_pid_m + (pid % group_size_m) | ||
| pid_n = (pid % num_pid_in_group) // group_size_m | ||
|
|
||
| offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M | ||
| offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N | ||
| offs_k = tl.arange(0, BLOCK_SIZE_K) | ||
| a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) | ||
| b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) | ||
| scale_a = tl.load(As) | ||
| scale_b = tl.load(Bs) | ||
| accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) | ||
| for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): | ||
| a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) | ||
| b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) | ||
|
|
||
| accumulator += tl.dot(a, b) * scale_a * scale_b | ||
| a_ptrs += BLOCK_SIZE_K * stride_ak | ||
| b_ptrs += BLOCK_SIZE_K * stride_bk | ||
|
|
||
| if C.dtype.element_ty == tl.bfloat16: | ||
| c = accumulator.to(tl.bfloat16) | ||
| elif C.dtype.element_ty == tl.float16: | ||
| c = accumulator.to(tl.float16) | ||
| else: | ||
| c = accumulator.to(tl.float32) | ||
|
|
||
| offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) | ||
| offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) | ||
| c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] | ||
| c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) | ||
| tl.store(c_ptrs, c, mask=c_mask) | ||
|
|
||
|
|
||
| def w8a8_block_fp8_matmul_triton( | ||
| A: torch.Tensor, | ||
| B: torch.Tensor, | ||
|
|
@@ -182,16 +255,25 @@ def w8a8_block_fp8_matmul_triton( | |
| assert len(block_size) == 2 | ||
| block_n, block_k = block_size[0], block_size[1] | ||
|
|
||
| # if we have per-tensor quantization, we use 128x128 block size for tiled matmul multiplication | ||
| if block_n == B.shape[-2] and block_k == B.shape[-1]: | ||
| block_n = 128 | ||
| block_k = 128 | ||
|
|
||
|
Comment on lines
182
to
265
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it doesn't make sense before to set blocks to something else than None when doing per tensor in the FP8Linear. Can we change that so that we fix it here also ? |
||
| assert A.shape[-1] == B.shape[-1] | ||
|
|
||
| assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() | ||
| assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] | ||
| if As.numel() != 1: | ||
| assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() | ||
| assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] | ||
|
|
||
| M = A.numel() // A.shape[-1] | ||
|
|
||
| assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 | ||
| N, K = B.shape | ||
| assert triton.cdiv(N, block_n) == Bs.shape[0], f"{N}, {block_n}, {Bs.shape}" | ||
| assert triton.cdiv(K, block_k) == Bs.shape[1], f"{K}, {block_k}, {Bs.shape}" | ||
| assert B.ndim == 2 and B.is_contiguous() | ||
| if Bs.numel() != 1: | ||
| assert Bs.ndim == 2 | ||
| assert triton.cdiv(N, block_n) == Bs.shape[0], f"{N}, {block_n}, {Bs.shape}" | ||
| assert triton.cdiv(K, block_k) == Bs.shape[1], f"{K}, {block_k}, {Bs.shape}" | ||
|
|
||
| C_shape = A.shape[:-1] + (N,) | ||
| C = A.new_empty(C_shape, dtype=output_dtype) | ||
|
|
@@ -207,32 +289,56 @@ def w8a8_block_fp8_matmul_triton( | |
| def grid(META): | ||
| return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),) | ||
|
|
||
| _w8a8_block_fp8_matmul[grid]( | ||
| A, | ||
| B, | ||
| C, | ||
| As, | ||
| Bs, | ||
| M, | ||
| N, | ||
| K, | ||
| block_n, | ||
| block_k, | ||
| A.stride(-2), | ||
| A.stride(-1), | ||
| B.stride(1), | ||
| B.stride(0), | ||
| C.stride(-2), | ||
| C.stride(-1), | ||
| As.stride(-2), | ||
| As.stride(-1), | ||
| Bs.stride(1), | ||
| Bs.stride(0), | ||
| BLOCK_SIZE_M=BLOCK_SIZE_M, | ||
| BLOCK_SIZE_N=BLOCK_SIZE_N, | ||
| BLOCK_SIZE_K=BLOCK_SIZE_K, | ||
| GROUP_SIZE_M=8, | ||
| ) | ||
| if As.numel() == 1 and Bs.numel() == 1: | ||
| _w8a8_block_fp8_matmul_per_tensor[grid]( | ||
| A, | ||
| B, | ||
| C, | ||
| As, | ||
| Bs, | ||
| M, | ||
| N, | ||
| K, | ||
| block_n, | ||
| block_k, | ||
| A.stride(-2), | ||
| A.stride(-1), | ||
| B.stride(1), | ||
| B.stride(0), | ||
| C.stride(-2), | ||
| C.stride(-1), | ||
| BLOCK_SIZE_M=BLOCK_SIZE_M, | ||
| BLOCK_SIZE_N=BLOCK_SIZE_N, | ||
| BLOCK_SIZE_K=BLOCK_SIZE_K, | ||
| GROUP_SIZE_M=8, | ||
| ) | ||
| else: | ||
| _w8a8_block_fp8_matmul[grid]( | ||
| A, | ||
| B, | ||
| C, | ||
| As, | ||
| Bs, | ||
| M, | ||
| N, | ||
| K, | ||
| block_n, | ||
| block_k, | ||
| A.stride(-2), | ||
| A.stride(-1), | ||
| B.stride(1), | ||
| B.stride(0), | ||
| C.stride(-2), | ||
| C.stride(-1), | ||
| As.stride(-2), | ||
| As.stride(-1), | ||
| Bs.stride(1), | ||
| Bs.stride(0), | ||
| BLOCK_SIZE_M=BLOCK_SIZE_M, | ||
| BLOCK_SIZE_N=BLOCK_SIZE_N, | ||
| BLOCK_SIZE_K=BLOCK_SIZE_K, | ||
| GROUP_SIZE_M=8, | ||
| ) | ||
|
|
||
| return C | ||
|
|
||
|
|
@@ -360,23 +466,19 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: | |
| if self.activation_scheme == "dynamic": | ||
| qinput, scale = act_quant(input, self.block_size[1]) | ||
| elif self.activation_scheme == "static": | ||
| scale = self.activation_scale | ||
| scale = self.activation_scale.to(torch.float32) | ||
| qinput = (input / scale).to(torch.float8_e4m3fn) | ||
| else: | ||
| raise NotImplementedError("Not supported") | ||
| # TODO: fix this later to use the triton kernel | ||
| if self.activation_scheme == "static": | ||
| output = F.linear(qinput.to(torch.bfloat16), weight.to(torch.bfloat16), None) * scale_inv * scale | ||
| output = output.to(input.dtype) | ||
| else: | ||
| output = w8a8_block_fp8_matmul_triton( | ||
| qinput, | ||
| weight, | ||
| scale, | ||
| scale_inv, | ||
| self.block_size, | ||
| output_dtype=input.dtype, | ||
| ) | ||
| output = w8a8_block_fp8_matmul_triton( | ||
| qinput, | ||
| weight, | ||
| scale, | ||
| scale_inv, | ||
| self.block_size, | ||
| output_dtype=input.dtype, | ||
| ) | ||
|
|
||
| # Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the | ||
| # preceding operations are ready before proceeding | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
update