diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py
index a0acc36715a8..35f725f9b696 100644
--- a/src/transformers/integrations/finegrained_fp8.py
+++ b/src/transformers/integrations/finegrained_fp8.py
@@ -156,6 +156,79 @@ def _w8a8_block_fp8_matmul(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
+@triton.jit
+def _w8a8_block_fp8_matmul_per_tensor(
+    # Pointers to inputs and output
+    A,
+    B,
+    C,
+    As,
+    Bs,
+    # Shape for matmul
+    M,
+    N,
+    K,
+    # Block size for block-wise quantization (unused here; kept for call-site parity)
+    group_n,
+    group_k,
+    # Stride for inputs and output
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+):
+    """Triton-accelerated function used to perform linear operations (dot
+    product) on input tensors `A` and `B` with per-tensor quantization, and
+    store the result in output tensor `C`.
+    """
+
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    scale_a = tl.load(As)
+    scale_b = tl.load(Bs)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        accumulator += tl.dot(a, b) * scale_a * scale_b
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if C.dtype.element_ty == tl.bfloat16:
+        c = accumulator.to(tl.bfloat16)
+    elif C.dtype.element_ty == tl.float16:
+        c = accumulator.to(tl.float16)
+    else:
+        c = accumulator.to(tl.float32)
+
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
 def w8a8_block_fp8_matmul_triton(
     A: torch.Tensor,
     B: torch.Tensor,
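The per-tensor kernel above differs from the existing `_w8a8_block_fp8_matmul` only in how scales are handled: `As` and `Bs` are single scalars loaded once (`tl.load(As)` / `tl.load(Bs)`) instead of per-tile scale blocks, so every partial `tl.dot` is rescaled by the same constants. A minimal PyTorch sketch of the computation it implements (the reference function name and shapes are illustrative, not part of the patch):

```python
import torch

def per_tensor_fp8_matmul_reference(a_q, b_q, a_s, b_s, out_dtype=torch.bfloat16):
    # a_q: (M, K) fp8, b_q: (N, K) fp8 -- the kernel reads B transposed via its strides.
    # a_s, b_s: scalar scales; rescaling each k-block partial sum by constants is
    # equivalent to rescaling the final fp32 accumulator once.
    acc = a_q.to(torch.float32) @ b_q.to(torch.float32).T
    return (acc * a_s * b_s).to(out_dtype)
```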
""" - assert len(block_size) == 2 - block_n, block_k = block_size[0], block_size[1] + if block_size is None: + block_n, block_k = 128, 128 + else: + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + # if we have per-tensor quantization, we use 128x128 block size for tiled matmul multiplication + if block_n == B.shape[-2] and block_k == B.shape[-1]: + block_n = 128 + block_k = 128 assert A.shape[-1] == B.shape[-1] - assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() - assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + if As.numel() != 1: + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] - assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 N, K = B.shape - assert triton.cdiv(N, block_n) == Bs.shape[0], f"{N}, {block_n}, {Bs.shape}" - assert triton.cdiv(K, block_k) == Bs.shape[1], f"{K}, {block_k}, {Bs.shape}" + assert B.ndim == 2 and B.is_contiguous() + if Bs.numel() != 1: + assert Bs.ndim == 2 + assert triton.cdiv(N, block_n) == Bs.shape[0], f"{N}, {block_n}, {Bs.shape}" + assert triton.cdiv(K, block_k) == Bs.shape[1], f"{K}, {block_k}, {Bs.shape}" C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) @@ -207,32 +292,56 @@ def w8a8_block_fp8_matmul_triton( def grid(META): return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),) - _w8a8_block_fp8_matmul[grid]( - A, - B, - C, - As, - Bs, - M, - N, - K, - block_n, - block_k, - A.stride(-2), - A.stride(-1), - B.stride(1), - B.stride(0), - C.stride(-2), - C.stride(-1), - As.stride(-2), - As.stride(-1), - Bs.stride(1), - Bs.stride(0), - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, - GROUP_SIZE_M=8, - ) + if As.numel() == 1 and Bs.numel() == 1: + _w8a8_block_fp8_matmul_per_tensor[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + ) + else: + _w8a8_block_fp8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=8, + ) return C @@ -356,23 +465,18 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if self.activation_scheme == "dynamic": qinput, scale = act_quant(input, self.block_size[1]) elif self.activation_scheme == "static": - scale = self.activation_scale + scale = self.activation_scale.to(torch.float32) qinput = (input / scale).to(torch.float8_e4m3fn) else: raise NotImplementedError("Not supported") - # TODO: fix this later to use the triton kernel - if self.activation_scheme == "static": - output = F.linear(qinput.to(torch.bfloat16), weight.to(torch.bfloat16), None) * scale_inv * scale - output = output.to(input.dtype) - else: - output = w8a8_block_fp8_matmul_triton( - qinput, - weight, - scale, - scale_inv, - self.block_size, - output_dtype=input.dtype, - ) + output = w8a8_block_fp8_matmul_triton( + qinput, + weight, + scale, + scale_inv, + self.block_size, + output_dtype=input.dtype, + ) # Blocks the CPU until all accelerator operations on the specified device are complete. 
@@ -356,23 +465,18 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.activation_scheme == "dynamic":
             qinput, scale = act_quant(input, self.block_size[1])
         elif self.activation_scheme == "static":
-            scale = self.activation_scale
+            scale = self.activation_scale.to(torch.float32)
             qinput = (input / scale).to(torch.float8_e4m3fn)
         else:
             raise NotImplementedError("Not supported")
-        # TODO: fix this later to use the triton kernel
-        if self.activation_scheme == "static":
-            output = F.linear(qinput.to(torch.bfloat16), weight.to(torch.bfloat16), None) * scale_inv * scale
-            output = output.to(input.dtype)
-        else:
-            output = w8a8_block_fp8_matmul_triton(
-                qinput,
-                weight,
-                scale,
-                scale_inv,
-                self.block_size,
-                output_dtype=input.dtype,
-            )
+        output = w8a8_block_fp8_matmul_triton(
+            qinput,
+            weight,
+            scale,
+            scale_inv,
+            self.block_size,
+            output_dtype=input.dtype,
+        )
 
         # Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
         # preceding operations are ready before proceeding
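With the bf16 fallback removed, the static activation scheme now takes the same Triton path as the dynamic one: the stored per-tensor `activation_scale` is cast to fp32, used to quantize the input to fp8, and passed to the kernel as a scalar scale. A small sketch of just that quantize/dequantize step, with illustrative values:

```python
import torch

x = torch.randn(4, 512, dtype=torch.bfloat16)
activation_scale = torch.tensor([0.05], dtype=torch.bfloat16)  # stand-in for self.activation_scale

scale = activation_scale.to(torch.float32)    # the cast this patch adds
qinput = (x / scale).to(torch.float8_e4m3fn)  # per-tensor static quantization

# Dequantizing recovers x up to fp8 e4m3 rounding (3 mantissa bits, ~2^-4 relative error).
recovered = qinput.to(torch.float32) * scale
max_abs_err = (recovered - x.to(torch.float32)).abs().max()
```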