Commit 6931402

refactor: optimize fp8_per_block_linear performance
1 parent ac25f1d · commit 6931402

File tree

4 files changed: +388 -298 lines changed


rtp_llm/models_py/modules/fp8_linear.py

Lines changed: 87 additions & 139 deletions
@@ -2,185 +2,133 @@
 from typing import Optional

 import torch
+import torch.nn.functional as F
 from torch import nn

 from rtp_llm.config.quant_config import QuantizationConfig
-from rtp_llm.models_py.modules import utils
-
-logger = logging.getLogger(__name__)
-
 from rtp_llm.models_py.modules.fp8_kernel import (
     scaled_fp8_per_tensor_quant,
     sgl_per_token_group_quant_fp8,
 )
-from rtp_llm.models_py.modules.quantization.deepgemm_wrapper import (
-    fp8_gemm_nt,
-    has_deep_gemm,
-)
+from rtp_llm.models_py.modules.quantization.deepgemm_wrapper import fp8_gemm_nt

+logger = logging.getLogger(__name__)

-class Fp8DeepGEMMLinear(nn.Module):
-    """FP8 Linear layer with DeepGEMM quantized matrix multiplication."""

+class Fp8PerBlockLinear(nn.Module):
     def __init__(
         self,
         weight: torch.Tensor,
-        weight_scales: torch.Tensor,
+        weight_scale: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
-        config=None,
-    ) -> None:
+    ):
         super().__init__()
-        if not has_deep_gemm():
-            raise RuntimeError(
-                "DeepGEMM is not available. Please install the `deep_gemm` package to enable DeepGEMM kernels."
-            )
-        self.hidden_size = weight.shape[0]  # k
-        self.output_size = weight.shape[1]  # n
-        self.weight = weight.reshape([weight.shape[1], weight.shape[0]])
-        self.weight_scales = weight_scales.reshape(
-            [weight_scales.shape[1], weight_scales.shape[0]]
-        )
+        # Initialize attributes
+        self.weight = weight
+        self.weight_scale = weight_scale
         self.bias = bias
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # Get input dimensions
-        input_m = input.shape[0]
-        input_k = input.shape[1]
-        output_n = self.output_size
-
-        # Check input dtype - only accept BF16
-        if input.dtype != torch.bfloat16:
+        # Check weight and weight scale dimensions
+        if self.weight.dim() != 2 or self.weight_scale.dim() != 2:
+            error_msg = f"Weight and weight scale must be 2D tensors, but got weight dim: {self.weight.dim()} and weight scale dim: {self.weight_scale.dim()}"
+            logger.error(error_msg)
+            raise ValueError(error_msg)
+        # Reshape weight and weight scale
+        self.K, self.N = self.weight.shape
+        self.scale_K, self.scale_N = self.weight_scale.shape
+        self.weight = self.weight.reshape(self.N, self.K)
+        self.weight_scale = self.weight_scale.reshape(self.scale_N, self.scale_K)
+        # Check weight scale sizes
+        if self.scale_N * 128 != self.N or self.scale_K * 128 != self.K:
+            error_msg = f"Weight scale dimension mismatch! Expected N: {self.N}, got {self.scale_N * 128}, expected K: {self.K}, got {self.scale_K * 128}"
+            logger.error(error_msg)
+            raise ValueError(error_msg)
+        # Check weight and weight scale dtypes
+        if self.weight.dtype != torch.float8_e4m3fn:
+            error_msg = f"Weight dtype must be float8_e4m3fn, got {self.weight.dtype}"
+            logger.error(error_msg)
+            raise ValueError(error_msg)
+        if self.weight_scale.dtype != torch.float32:
             error_msg = (
-                f"Fp8DeepGEMMLinear only accepts bfloat16 input, but got {input.dtype}. "
-                "Please convert input to bfloat16 before calling Fp8DeepGEMMLinear."
+                f"Weight scale dtype must be float32, got {self.weight_scale.dtype}"
             )
             logger.error(error_msg)
             raise ValueError(error_msg)
-
-        input_bf16 = input
-
-        # Quantize input to FP8
-
-        alignment = self._get_padding_size(input_m)
-        target_m = (input_m + alignment - 1) // alignment * alignment
-        need_padding = target_m > input_m
-
-        if need_padding:
-            input_for_quant = torch.zeros(
-                target_m, input_k, dtype=torch.bfloat16, device=input.device
-            )
-            input_for_quant[:input_m, :] = input_bf16
-        else:
-            input_for_quant = input_bf16
-
-        # Quantize using sgl_per_token_group_quant_fp8
-        quantization_eps = 1e-4
-        input_fp8, input_scales = sgl_per_token_group_quant_fp8(
-            input_for_quant,
+        # Check bias
+        if self.bias is not None:
+            if self.bias.dim() != 1 and self.bias.dim() != 2:
+                error_msg = f"Bias dimension must be 1 or 2, got {self.bias.dim()}"
+                logger.error(error_msg)
+                raise ValueError(error_msg)
+            if self.bias.shape[-1] != self.N:
+                error_msg = (
+                    f"Bias last dimension must be {self.N}, got {self.bias.shape[-1]}"
+                )
+                logger.error(error_msg)
+                raise ValueError(error_msg)
+            if self.bias.dim() == 2 and self.bias.shape[0] != 1:
+                error_msg = f"Bias first dimension must be 1, got {self.bias.shape[0]}"
+                logger.error(error_msg)
+                raise ValueError(error_msg)
+            if self.bias.dtype != torch.bfloat16:
+                error_msg = f"Bias dtype must be bfloat16, got {self.bias.dtype}"
+                logger.error(error_msg)
+                raise ValueError(error_msg)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Check input dtype - only accept bfloat16
+        if x.dtype != torch.bfloat16:
+            error_msg = f"Input tensor dtype must be bfloat16, got {x.dtype}"
+            logger.error(error_msg)
+            raise ValueError(error_msg)
+        # Check input tensor dimension
+        if x.dim() != 2:
+            error_msg = f"Input tensor dimension must be 2, got {x.dim()}D tensor"
+            logger.error(error_msg)
+            raise ValueError(error_msg)
+        M, K = x.shape
+        # Check input tensor inner dimension expected to be K
+        if K != self.K:
+            error_msg = f"Input tensor inner dimension expected to be {self.K}, got {K}"
+            logger.error(error_msg)
+            raise ValueError(error_msg)
+        # Quantize x to FP8
+        x_fp8, x_scales = sgl_per_token_group_quant_fp8(
+            x,
             group_size=128,
-            eps=quantization_eps,
-            column_major_scales=False,
+            eps=1e-4,
+            column_major_scales=True,
+            scale_tma_aligned=True,
         )
-
-        FP8_E4M3_MAX = 448.0
-        min_scale_threshold = 1e-4 / FP8_E4M3_MAX
-        input_scales = torch.clamp(input_scales, min=min_scale_threshold)
-        input_scales = input_scales.to(torch.float32)
-        output_m = input_for_quant.shape[0]
-        output = torch.zeros(
-            output_m, output_n, dtype=torch.bfloat16, device=input.device
+        # Prepare output tensor
+        output = torch.empty(M, self.N, dtype=torch.bfloat16, device=x.device)
+        # Invoke DeepGEMM
+        fp8_gemm_nt(
+            (x_fp8, x_scales),
+            (self.weight, self.weight_scale),
+            output,
+            c=None,
+            disable_ue8m0_cast=True,
         )
-
-        # Call DeepGEMM
-        deepgemm_input_scales = input_scales
-        input_fp8 = input_fp8.contiguous()
-        deepgemm_input_scales = deepgemm_input_scales.contiguous()
-        weight = self.weight.contiguous()
-        weight_scales = self.weight_scales.contiguous()
-        output = output.contiguous()
-        try:
-            fp8_gemm_nt(
-                (input_fp8, deepgemm_input_scales),
-                (weight, weight_scales),
-                output,
-                c=None,
-                disable_ue8m0_cast=True,
-            )
-        except Exception as e:
-            # DeepGEMM call failed - log error and re-raise
-            error_msg = f"DeepGEMM fp8_gemm_nt call failed: {type(e).__name__}: {e}"
-            logger.error(error_msg)
-            import traceback
-
-            logger.error(f"Traceback: {traceback.format_exc()}")
-            raise RuntimeError(error_msg) from e
-        if need_padding:
-            output = output[:input_m, :]
         if self.bias is not None:
-            output = output + self.bias.to(output.dtype)
+            output = output + self.bias
         return output

-    def _get_padding_size(self, m):
-        """Calculate padding size based on DeepGEMM requirements."""
-        if self._gemm_swap_ab_heuristic(m):
-            if m < 16:
-                return 16
-            else:
-                return 8
-        else:
-            return 64
-
-    def _gemm_swap_ab_heuristic(self, m):
-        return False
-
-    def _expand_input_scales(self, input_scales, target_shape):
-        """Expand input scales to target shape."""
-        # input_scales: [m, k/128] - always row-major
-        # target_shape: [m, k]
-        m, k = target_shape
-        expected_scales_shape = (m, (k + 127) // 128)
-        if input_scales.shape != expected_scales_shape:
-            raise ValueError(
-                f"Input scales shape mismatch! Expected {expected_scales_shape}, got {input_scales.shape}"
-            )
-        expanded = torch.zeros(
-            target_shape, dtype=input_scales.dtype, device=input_scales.device
-        )
-        for i in range(input_scales.shape[0]):  # m tokens
-            for j in range(input_scales.shape[1]):  # k/128 groups
-                k_start = j * 128
-                k_end = min((j + 1) * 128, k)
-                expanded[i, k_start:k_end] = input_scales[i, j]
-        return expanded
-
-    def _expand_weight_scales(self):
-        """Expand weight scales to weight tensor shape."""
-        expanded = torch.zeros_like(self.weight, dtype=torch.float32)
-        for i in range(self.weight_scales.shape[0]):  # output_size blocks (60)
-            for j in range(self.weight_scales.shape[1]):  # hidden_size blocks (20)
-                h_start = i * 128  # output_size dimension
-                h_end = min((i + 1) * 128, self.weight.shape[0])
-                w_start = j * 128  # hidden_size dimension
-                w_end = min((j + 1) * 128, self.weight.shape[1])
-                expanded[h_start:h_end, w_start:w_end] = self.weight_scales[i, j]
-        return expanded
-

 class Fp8PerTensorLinear(nn.Module):
     def __init__(
         self,
-        quant_config: QuantizationConfig,
         weight: torch.Tensor,
         weight_scale: torch.Tensor,
         input_scale: Optional[torch.Tensor] = None,
         bias: Optional[torch.Tensor] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.weight = weight.T
         self.weight_scale = weight_scale
         self.input_scale = input_scale
         self.bias = bias
-        self.block_quant = None
+        self.quant_config = quant_config

     def forward(self, input: torch.Tensor) -> torch.Tensor:
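The core of the change in this file: the old Fp8DeepGEMMLinear padded the activation rows up to a DeepGEMM-friendly size, clamped and cast the per-token-group scales on the host, and wrapped fp8_gemm_nt in a try/except, whereas the new Fp8PerBlockLinear validates shapes and dtypes once in __init__, asks sgl_per_token_group_quant_fp8 for column-major, TMA-aligned scales, and writes the GEMM result straight into an empty bfloat16 buffer. A minimal usage sketch follows, assuming only the layout the new checks enforce (a [K, N] float8_e4m3fn weight, a [K/128, N/128] float32 block-scale tensor, and a 2D bfloat16 input); the sizes, random tensors, and CUDA device are illustrative, and running it requires the deep_gemm backend.

# Illustrative sketch only: shapes and dtypes mirror the checks in Fp8PerBlockLinear;
# real weights come from a block-quantized checkpoint, not random data.
import torch

from rtp_llm.models_py.modules.fp8_linear import Fp8PerBlockLinear

K, N, M = 512, 1024, 8  # example hidden size, output size, and token count

weight = torch.randn(K, N, device="cuda").to(torch.float8_e4m3fn)  # per-block FP8 weight, [K, N]
weight_scale = torch.rand(K // 128, N // 128, dtype=torch.float32, device="cuda")  # one scale per 128x128 block
bias = torch.zeros(N, dtype=torch.bfloat16, device="cuda")  # optional; bf16 with last dim == N

layer = Fp8PerBlockLinear(weight, weight_scale, bias)

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")  # forward() accepts only 2D bf16 input
y = layer(x)  # per-token-group quantization -> fp8_gemm_nt -> optional bias add
assert y.shape == (M, N) and y.dtype == torch.bfloat16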

rtp_llm/models_py/modules/linear_factory.py

Lines changed: 17 additions & 22 deletions
@@ -10,20 +10,13 @@
 from rtp_llm.config.gpt_init_model_parameters import GptInitModelParameters
 from rtp_llm.models_py.modules import utils
 from rtp_llm.models_py.modules.linear import Linear
+from rtp_llm.models_py.modules.quantization.deepgemm_wrapper import has_deep_gemm

 if utils.is_cuda():
-    from rtp_llm.models_py.modules.fp8_linear import Fp8PerTensorLinear
-
-    try:
-        from rtp_llm.models_py.modules.fp8_linear import Fp8DeepGEMMLinear
-
-        FP8_LINEAR_AVAILABLE = True
-    except ImportError:
-        Fp8DeepGEMMLinear = None
-        FP8_LINEAR_AVAILABLE = False
-else:
-    Fp8DeepGEMMLinear = None
-    FP8_LINEAR_AVAILABLE = False
+    from rtp_llm.models_py.modules.fp8_linear import (
+        Fp8PerBlockLinear,
+        Fp8PerTensorLinear,
+    )


 class LinearFactory:
@@ -36,9 +29,6 @@ def should_use_fp8_linear(
         weight_key: str,
     ) -> bool:
         """Check if FP8 linear layer should be used."""
-        if not FP8_LINEAR_AVAILABLE:
-            return False
-
         if not hasattr(config, "quant_config") or config.quant_config is None:
             return False

@@ -80,19 +70,24 @@ def create_linear(
                 raise ValueError("FP8 linear layer requires config")
             else:
                 quant_config = config.quant_config
-                if quant_config.get_method() in [
+                if quant_config.get_method() == "FP8_PER_BLOCK":
+                    if has_deep_gemm():
+                        return Fp8PerBlockLinear(weight, weight_scales, bias)
+                    else:
+                        raise RuntimeError(
+                            "No available fp8 gemm backend for Fp8PerBlockLinear"
+                        )
+                elif quant_config.get_method() in [
                     "FP8_PER_TENSOR_COMPRESSED",
                     "FP8_DYNAMIC_PER_TENSOR",
                 ]:
                     return Fp8PerTensorLinear(
-                        config.quant_config, weight, weight_scales, input_scales, bias
+                        weight, weight_scales, input_scales, bias, config.quant_config
                     )
                 else:
-                    if not FP8_LINEAR_AVAILABLE:
-                        raise RuntimeError(
-                            "FP8 DeepGEMMLinear layer requested but not available"
-                        )
-                    return Fp8DeepGEMMLinear(weight, weight_scales, bias, config)
+                    raise ValueError(
+                        f"Unsupported quantization method: {quant_config.get_method()}"
+                    )
         else:
             return Linear(weight, bias)
9893
