2 | 2 | from typing import Optional |
3 | 3 |
4 | 4 | import torch |
| 5 | +import torch.nn.functional as F |
5 | 6 | from torch import nn |
6 | 7 |
7 | 8 | from rtp_llm.config.quant_config import QuantizationConfig |
8 | | -from rtp_llm.models_py.modules import utils |
9 | | - |
10 | | -logger = logging.getLogger(__name__) |
11 | | - |
12 | 9 | from rtp_llm.models_py.modules.fp8_kernel import ( |
13 | 10 | scaled_fp8_per_tensor_quant, |
14 | 11 | sgl_per_token_group_quant_fp8, |
15 | 12 | ) |
16 | | -from rtp_llm.models_py.modules.quantization.deepgemm_wrapper import ( |
17 | | - fp8_gemm_nt, |
18 | | - has_deep_gemm, |
19 | | -) |
| 13 | +from rtp_llm.models_py.modules.quantization.deepgemm_wrapper import fp8_gemm_nt |
20 | 14 |
| 15 | +logger = logging.getLogger(__name__) |
21 | 16 |
22 | | -class Fp8DeepGEMMLinear(nn.Module): |
23 | | - """FP8 Linear layer with DeepGEMM quantized matrix multiplication.""" |
24 | 17 |
| 18 | +class Fp8PerBlockLinear(nn.Module): |
25 | 19 | def __init__( |
26 | 20 | self, |
27 | 21 | weight: torch.Tensor, |
28 | | - weight_scales: torch.Tensor, |
| 22 | + weight_scale: torch.Tensor, |
29 | 23 | bias: Optional[torch.Tensor] = None, |
30 | | - config=None, |
31 | | - ) -> None: |
| 24 | + ): |
32 | 25 | super().__init__() |
33 | | - if not has_deep_gemm(): |
34 | | - raise RuntimeError( |
35 | | - "DeepGEMM is not available. Please install the `deep_gemm` package to enable DeepGEMM kernels." |
36 | | - ) |
37 | | - self.hidden_size = weight.shape[0] # k |
38 | | - self.output_size = weight.shape[1] # n |
39 | | - self.weight = weight.reshape([weight.shape[1], weight.shape[0]]) |
40 | | - self.weight_scales = weight_scales.reshape( |
41 | | - [weight_scales.shape[1], weight_scales.shape[0]] |
42 | | - ) |
| 26 | + # Store the FP8 weight, its per-block scales, and the optional bias |
| 27 | + self.weight = weight |
| 28 | + self.weight_scale = weight_scale |
43 | 29 | self.bias = bias |
44 | | - |
45 | | - def forward(self, input: torch.Tensor) -> torch.Tensor: |
46 | | - # Get input dimensions |
47 | | - input_m = input.shape[0] |
48 | | - input_k = input.shape[1] |
49 | | - output_n = self.output_size |
50 | | - |
51 | | - # Check input dtype - only accept BF16 |
52 | | - if input.dtype != torch.bfloat16: |
| 30 | + # Check weight and weight scale dimensions |
| 31 | + if self.weight.dim() != 2 or self.weight_scale.dim() != 2: |
| 32 | + error_msg = f"Weight and weight scale must be 2D tensors, but got weight dim: {self.weight.dim()} and weight scale dim: {self.weight_scale.dim()}" |
| 33 | + logger.error(error_msg) |
| 34 | + raise ValueError(error_msg) |
| 35 | + # Reshape weight and weight scale |
| 36 | + self.K, self.N = self.weight.shape |
| 37 | + self.scale_K, self.scale_N = self.weight_scale.shape |
| 38 | + self.weight = self.weight.reshape(self.N, self.K) |
| 39 | + self.weight_scale = self.weight_scale.reshape(self.scale_N, self.scale_K) |
| 40 | + # Check weight scale sizes |
| 41 | + if self.scale_N * 128 != self.N or self.scale_K * 128 != self.K: |
| 42 | + error_msg = f"Weight scale shape mismatch: scales imply N={self.scale_N * 128}, K={self.scale_K * 128}, but weight has N={self.N}, K={self.K}" |
| 43 | + logger.error(error_msg) |
| 44 | + raise ValueError(error_msg) |
| 45 | + # Check weight and weight scale dtypes |
| 46 | + if self.weight.dtype != torch.float8_e4m3fn: |
| 47 | + error_msg = f"Weight dtype must be float8_e4m3fn, got {self.weight.dtype}" |
| 48 | + logger.error(error_msg) |
| 49 | + raise ValueError(error_msg) |
| 50 | + if self.weight_scale.dtype != torch.float32: |
53 | 51 | error_msg = ( |
54 | | - f"Fp8DeepGEMMLinear only accepts bfloat16 input, but got {input.dtype}. " |
55 | | - "Please convert input to bfloat16 before calling Fp8DeepGEMMLinear." |
| 52 | + f"Weight scale dtype must be float32, got {self.weight_scale.dtype}" |
56 | 53 | ) |
57 | 54 | logger.error(error_msg) |
58 | 55 | raise ValueError(error_msg) |
59 | | - |
60 | | - input_bf16 = input |
61 | | - |
62 | | - # Quantize input to FP8 |
63 | | - |
64 | | - alignment = self._get_padding_size(input_m) |
65 | | - target_m = (input_m + alignment - 1) // alignment * alignment |
66 | | - need_padding = target_m > input_m |
67 | | - |
68 | | - if need_padding: |
69 | | - input_for_quant = torch.zeros( |
70 | | - target_m, input_k, dtype=torch.bfloat16, device=input.device |
71 | | - ) |
72 | | - input_for_quant[:input_m, :] = input_bf16 |
73 | | - else: |
74 | | - input_for_quant = input_bf16 |
75 | | - |
76 | | - # Quantize using sgl_per_token_group_quant_fp8 |
77 | | - quantization_eps = 1e-4 |
78 | | - input_fp8, input_scales = sgl_per_token_group_quant_fp8( |
79 | | - input_for_quant, |
| 56 | + # Check bias |
| 57 | + if self.bias is not None: |
| 58 | + if self.bias.dim() != 1 and self.bias.dim() != 2: |
| 59 | + error_msg = f"Bias dimension must be 1 or 2, got {self.bias.dim()}" |
| 60 | + logger.error(error_msg) |
| 61 | + raise ValueError(error_msg) |
| 62 | + if self.bias.shape[-1] != self.N: |
| 63 | + error_msg = ( |
| 64 | + f"Bias last dimension must be {self.N}, got {self.bias.shape[-1]}" |
| 65 | + ) |
| 66 | + logger.error(error_msg) |
| 67 | + raise ValueError(error_msg) |
| 68 | + if self.bias.dim() == 2 and self.bias.shape[0] != 1: |
| 69 | + error_msg = f"Bias first dimension must be 1, got {self.bias.shape[0]}" |
| 70 | + logger.error(error_msg) |
| 71 | + raise ValueError(error_msg) |
| 72 | + if self.bias.dtype != torch.bfloat16: |
| 73 | + error_msg = f"Bias dtype must be bfloat16, got {self.bias.dtype}" |
| 74 | + logger.error(error_msg) |
| 75 | + raise ValueError(error_msg) |
| 76 | + |
| 77 | + def forward(self, x: torch.Tensor) -> torch.Tensor: |
| 78 | + # Check input dtype - only accept bfloat16 |
| 79 | + if x.dtype != torch.bfloat16: |
| 80 | + error_msg = f"Input tensor dtype must be bfloat16, got {x.dtype}" |
| 81 | + logger.error(error_msg) |
| 82 | + raise ValueError(error_msg) |
| 83 | + # Check input tensor dimension |
| 84 | + if x.dim() != 2: |
| 85 | + error_msg = f"Input tensor must be 2D, got a {x.dim()}D tensor" |
| 86 | + logger.error(error_msg) |
| 87 | + raise ValueError(error_msg) |
| 88 | + M, K = x.shape |
| 90 | + # Check that the input's inner dimension matches the weight's K |
| 90 | + if K != self.K: |
| 91 | + error_msg = f"Input tensor inner dimension expected to be {self.K}, got {K}" |
| 92 | + logger.error(error_msg) |
| 93 | + raise ValueError(error_msg) |
| 94 | + # Quantize x to FP8 with per-token group-128 scales (column-major, TMA-aligned) |
| 95 | + x_fp8, x_scales = sgl_per_token_group_quant_fp8( |
| 96 | + x, |
80 | 97 | group_size=128, |
81 | | - eps=quantization_eps, |
82 | | - column_major_scales=False, |
| 98 | + eps=1e-4, |
| 99 | + column_major_scales=True, |
| 100 | + scale_tma_aligned=True, |
83 | 101 | ) |
84 | | - |
85 | | - FP8_E4M3_MAX = 448.0 |
86 | | - min_scale_threshold = 1e-4 / FP8_E4M3_MAX |
87 | | - input_scales = torch.clamp(input_scales, min=min_scale_threshold) |
88 | | - input_scales = input_scales.to(torch.float32) |
89 | | - output_m = input_for_quant.shape[0] |
90 | | - output = torch.zeros( |
91 | | - output_m, output_n, dtype=torch.bfloat16, device=input.device |
| 102 | + # Prepare output tensor |
| 103 | + output = torch.empty(M, self.N, dtype=torch.bfloat16, device=x.device) |
| 104 | + # Invoke DeepGEMM |
| 105 | + fp8_gemm_nt( |
| 106 | + (x_fp8, x_scales), |
| 107 | + (self.weight, self.weight_scale), |
| 108 | + output, |
| 109 | + c=None, |
| 110 | + disable_ue8m0_cast=True, |
92 | 111 | ) |
93 | | - |
94 | | - # Call DeepGEMM |
95 | | - deepgemm_input_scales = input_scales |
96 | | - input_fp8 = input_fp8.contiguous() |
97 | | - deepgemm_input_scales = deepgemm_input_scales.contiguous() |
98 | | - weight = self.weight.contiguous() |
99 | | - weight_scales = self.weight_scales.contiguous() |
100 | | - output = output.contiguous() |
101 | | - try: |
102 | | - fp8_gemm_nt( |
103 | | - (input_fp8, deepgemm_input_scales), |
104 | | - (weight, weight_scales), |
105 | | - output, |
106 | | - c=None, |
107 | | - disable_ue8m0_cast=True, |
108 | | - ) |
109 | | - except Exception as e: |
110 | | - # DeepGEMM call failed - log error and re-raise |
111 | | - error_msg = f"DeepGEMM fp8_gemm_nt call failed: {type(e).__name__}: {e}" |
112 | | - logger.error(error_msg) |
113 | | - import traceback |
114 | | - |
115 | | - logger.error(f"Traceback: {traceback.format_exc()}") |
116 | | - raise RuntimeError(error_msg) from e |
117 | | - if need_padding: |
118 | | - output = output[:input_m, :] |
119 | 112 | if self.bias is not None: |
120 | | - output = output + self.bias.to(output.dtype) |
| 113 | + output = output + self.bias |
121 | 114 | return output |
122 | 115 |
123 | | - def _get_padding_size(self, m): |
124 | | - """Calculate padding size based on DeepGEMM requirements.""" |
125 | | - if self._gemm_swap_ab_heuristic(m): |
126 | | - if m < 16: |
127 | | - return 16 |
128 | | - else: |
129 | | - return 8 |
130 | | - else: |
131 | | - return 64 |
132 | | - |
133 | | - def _gemm_swap_ab_heuristic(self, m): |
134 | | - return False |
135 | | - |
136 | | - def _expand_input_scales(self, input_scales, target_shape): |
137 | | - """Expand input scales to target shape.""" |
138 | | - # input_scales: [m, k/128] - always row-major |
139 | | - # target_shape: [m, k] |
140 | | - m, k = target_shape |
141 | | - expected_scales_shape = (m, (k + 127) // 128) |
142 | | - if input_scales.shape != expected_scales_shape: |
143 | | - raise ValueError( |
144 | | - f"Input scales shape mismatch! Expected {expected_scales_shape}, got {input_scales.shape}" |
145 | | - ) |
146 | | - expanded = torch.zeros( |
147 | | - target_shape, dtype=input_scales.dtype, device=input_scales.device |
148 | | - ) |
149 | | - for i in range(input_scales.shape[0]): # m tokens |
150 | | - for j in range(input_scales.shape[1]): # k/128 groups |
151 | | - k_start = j * 128 |
152 | | - k_end = min((j + 1) * 128, k) |
153 | | - expanded[i, k_start:k_end] = input_scales[i, j] |
154 | | - return expanded |
155 | | - |
156 | | - def _expand_weight_scales(self): |
157 | | - """Expand weight scales to weight tensor shape.""" |
158 | | - expanded = torch.zeros_like(self.weight, dtype=torch.float32) |
159 | | - for i in range(self.weight_scales.shape[0]): # output_size blocks (60) |
160 | | - for j in range(self.weight_scales.shape[1]): # hidden_size blocks (20) |
161 | | - h_start = i * 128 # output_size dimension |
162 | | - h_end = min((i + 1) * 128, self.weight.shape[0]) |
163 | | - w_start = j * 128 # hidden_size dimension |
164 | | - w_end = min((j + 1) * 128, self.weight.shape[1]) |
165 | | - expanded[h_start:h_end, w_start:w_end] = self.weight_scales[i, j] |
166 | | - return expanded |
167 | | - |
168 | 116 |
169 | 117 | class Fp8PerTensorLinear(nn.Module): |
170 | 118 | def __init__( |
171 | 119 | self, |
172 | | - quant_config: QuantizationConfig, |
173 | 120 | weight: torch.Tensor, |
174 | 121 | weight_scale: torch.Tensor, |
175 | 122 | input_scale: Optional[torch.Tensor] = None, |
176 | 123 | bias: Optional[torch.Tensor] = None, |
| 124 | + quant_config: Optional[QuantizationConfig] = None, |
177 | 125 | ) -> None: |
178 | 126 | super().__init__() |
179 | 127 | self.weight = weight.T |
180 | 128 | self.weight_scale = weight_scale |
181 | 129 | self.input_scale = input_scale |
182 | 130 | self.bias = bias |
183 | | - self.block_quant = None |
| 131 | + self.quant_config = quant_config |
184 | 132 |
185 | 133 | def forward(self, input: torch.Tensor) -> torch.Tensor: |
186 | 134 |