
Commit 15b79ea

[Quantization] fix fbgemm (#42561)
* initial commit
* passing tests
* fix replace_linear
* style
* rm list
* fix
* style
1 parent 377a8ee commit 15b79ea

4 files changed: +133 −120 lines changed
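
For context, a minimal end-to-end sketch of the FBGEMM FP8 path these files implement, assuming a CUDA device and fbgemm-gpu installed; the checkpoint name is a placeholder, not taken from this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder checkpoint, any decoder-only LM works
quantization_config = FbgemmFp8Config()  # quantize linear weights to FP8 with per-row scales on load

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="cuda", quantization_config=quantization_config
)

inputs = tokenizer("What are we having for dinner?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=9, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))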

src/transformers/integrations/fbgemm_fp8.py

Lines changed: 102 additions & 101 deletions
@@ -12,7 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
 from ..activations import ACT2FN
+from ..core_model_loading import ConversionOps
+from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
 from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
 
 
@@ -29,18 +33,75 @@
 logger = logging.get_logger(__name__)
 
 
+class FbgemmFp8Quantize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict: dict[str, torch.Tensor | list[torch.Tensor]],
+        model: Optional[torch.nn.Module] = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        target_key, value = tuple(input_dict.items())[0]
+        value = value[0]
+
+        from ..integrations import FbgemmFp8Llama4TextExperts
+
+        module, tensor_name = get_module_from_name(model, target_key)
+
+        if isinstance(module, FbgemmFp8Llama4TextExperts):
+            if tensor_name == "gate_up_proj":
+                # Process each expert separately
+                # Transpose the second and third dimension
+                transposed_param = value.transpose(1, 2)
+
+                # Reshape to 2D for quantization
+                original_shape = transposed_param.shape
+                flattened_param = transposed_param.reshape(-1, original_shape[-1])
+
+                # Quantize using per row instead of per column
+                new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
+
+                # Reshape back to original dimensions
+                new_value = new_value_flat.reshape(original_shape)
+                new_value = new_value.transpose(1, 2)
+                weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
+            elif tensor_name == "down_proj":
+                # Process each expert separately
+                # Transpose the weights for proper quantization
+                transposed_param = value.transpose(1, 2)
+
+                # Reshape to 2D for quantization
+                original_shape = transposed_param.shape
+                flattened_param = transposed_param.reshape(-1, original_shape[-1])
+
+                # Quantize using per column
+                new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
+
+                # Reshape back to original dimensions
+                new_value = new_value_flat.reshape(original_shape)
+                new_value = new_value.transpose(1, 2)
+                weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
+        else:
+            new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(value)
+            weight_scale = torch.nn.Parameter(weight_scale.view(weight_scale.shape[0], 1))
+
+        return {target_key: torch.nn.Parameter(new_value), f"{target_key}_scale": weight_scale}
+
+
 class FbgemmFp8Linear(torch.nn.Linear):
-    def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
+    def __init__(self, in_features, out_features, bias, dtype=torch.float8_e4m3fn):
         super().__init__(in_features, out_features, bias)
         self.in_features = in_features
         self.out_features = out_features
 
-        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
-        self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype))
+        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=dtype))
+        self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=torch.float32))
         self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
 
         if bias:
-            self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype))
+            self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=torch.float32))
         else:
             self.bias = None
 
@@ -154,90 +215,11 @@ def forward(self, hidden_states):
         return next_states.view(-1, self.hidden_size)
 
 
-def _replace_with_fbgemm_fp8_linear(
-    model,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    quantization_config=None,
-    has_been_replaced=False,
-    pre_quantized=False,
-    config=None,
-    tp_plan=None,
-):
-    """
-    Private method that wraps the recursion for module replacement.
-
-    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
-    """
-
-    import re
-
-    if current_key_name is None:
-        current_key_name = []
-
-    for name, module in model.named_children():
-        current_key_name.append(name)
-
-        if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            current_key_name_str = ".".join(current_key_name)
-            if not any(
-                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-            ):
-                with init_empty_weights(include_buffers=True):
-                    in_features = module.in_features
-                    out_features = module.out_features
-                    model._modules[name] = FbgemmFp8Linear(
-                        in_features,
-                        out_features,
-                        module.bias is not None,
-                    )
-                    has_been_replaced = True
-
-                # Force requires grad to False to avoid unexpected errors
-                model._modules[name].requires_grad_(False)
-                # set non persistent buffer outside of init_empty_weights
-                model._modules[name].input_scale_ub = torch.tensor(
-                    [quantization_config.activation_scale_ub],
-                    dtype=torch.float,
-                )
-        if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
-            current_key_name_str = ".".join(current_key_name)
-            if not any(
-                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-            ):
-                with init_empty_weights(include_buffers=True):
-                    tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
-                    model._modules[name] = FbgemmFp8Llama4TextExperts(
-                        config.text_config,
-                    )
-                model._modules[name].input_scale_ub = torch.tensor(
-                    [quantization_config.activation_scale_ub], dtype=torch.float
-                )
-
-        if len(list(module.children())) > 0:
-            _, has_been_replaced = _replace_with_fbgemm_fp8_linear(
-                module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
-                has_been_replaced=has_been_replaced,
-                pre_quantized=pre_quantized,
-                config=config,
-                tp_plan=tp_plan,
-            )
-        # Remove the last key for recursion
-        current_key_name.pop(-1)
-    return model, has_been_replaced
-
-
 def replace_with_fbgemm_fp8_linear(
     model,
     modules_to_not_convert=None,
-    current_key_name=None,
     quantization_config=None,
     pre_quantized=False,
-    config=None,
     tp_plan=None,
 ):
     """
@@ -254,26 +236,45 @@ def replace_with_fbgemm_fp8_linear(
         modules_to_not_convert (`list[`str`]`, *optional*, defaults to `["lm_head"]`):
             Names of the modules to not convert in `FP8Linear`. In practice we keep the `lm_head` in full precision
             for numerical stability reasons.
-        current_key_name (`list[`str`]`, *optional*):
-            An array to track the current key of the recursion. This is used to check whether the current key (part of
-            it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
-            `disk`).
     """
 
-    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
-
-    if quantization_config.modules_to_not_convert is not None:
-        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
-    modules_to_not_convert = list(set(modules_to_not_convert))
-    model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
-        model,
-        modules_to_not_convert,
-        current_key_name,
-        quantization_config,
-        pre_quantized=pre_quantized,
-        config=config,
-        tp_plan=tp_plan,
-    )
+    has_been_replaced = False
+    module_kwargs = {} if pre_quantized else {"dtype": None}
+
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+
+        new_module = None
+        with init_empty_weights(include_buffers=True):
+            if module.__class__.__name__ == "Llama4TextExperts":
+                # TODO: make sure tp works later
+                # if tp_plan is not None:
+                #     tp_key = re.sub(r"\d+", "*", f"{module_name}.down_proj_scale")
+                #     tp_plan[tp_key] = None
+                text_config = getattr(model.config, "text_config", model.config)
+                new_module = FbgemmFp8Llama4TextExperts(text_config or model.config)
+            elif isinstance(module, nn.Linear):
+                new_module = FbgemmFp8Linear(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    **module_kwargs,
+                )
+                new_module.requires_grad_(False)
+
+        if new_module is None:
+            continue
+
+        if hasattr(new_module, "input_scale_ub"):
+            new_module.input_scale_ub = torch.tensor(
+                [quantization_config.activation_scale_ub],
+                dtype=torch.float,
+            )
+
+        model.set_submodule(module_name, new_module)
+        has_been_replaced = True
+
     if not has_been_replaced:
         logger.warning(
             "You are loading your model using FP8 quantization but no linear modules were found in your model."

src/transformers/quantizers/base.py

Lines changed: 2 additions & 3 deletions
@@ -61,10 +61,9 @@ def get_keys_to_not_convert(model):
         for name, module in model.named_modules()
         if output_emb_module is not None and id(module) == id(output_emb_module)
     }
-    candidates = tied_keys | last_module_key | output_emb_keys
+    modules_to_not_convert = tied_keys | last_module_key | output_emb_keys
 
-    modules_to_not_convert = {name.replace(suffix, "") for name in candidates for suffix in [".weight", ".bias"]}
-    return modules_to_not_convert
+    return list(modules_to_not_convert)
 
 
 class HfQuantizer(ABC):
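
A hypothetical illustration (set contents invented) of the return-value change above: the union of candidate module names is now returned directly as a list, without stripping ".weight"/".bias" suffixes:

tied_keys = {"lm_head"}
last_module_key = {"model.norm"}
output_emb_keys = {"lm_head"}

modules_to_not_convert = tied_keys | last_module_key | output_emb_keys
print(sorted(modules_to_not_convert))  # ['lm_head', 'model.norm']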

src/transformers/quantizers/quantizer_fbgemm_fp8.py

Lines changed: 5 additions & 0 deletions
@@ -285,3 +285,8 @@ def is_serializable(self, safe_serialization=None):
     @property
     def is_trainable(self) -> bool:
         return False
+
+    def get_quantize_ops(self):
+        from ..integrations.fbgemm_fp8 import FbgemmFp8Quantize
+
+        return FbgemmFp8Quantize(self)
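
A rough sketch of the conversion-op contract hooked up here, based on the convert() shown earlier in this diff; the tiny module and key names are invented, and fbgemm-gpu FP8 kernels plus a CUDA device are assumed:

import torch

from transformers.integrations.fbgemm_fp8 import FbgemmFp8Quantize


class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 32, bias=False)


model = Tiny().to("cuda", dtype=torch.bfloat16)
op = FbgemmFp8Quantize(hf_quantizer=None)  # the quantizer is stored but unused in the plain-Linear branch

converted = op.convert({"proj.weight": [model.proj.weight.data]}, model=model)
print(sorted(converted))  # ['proj.weight', 'proj.weight_scale'] per the return statement above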

tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py

Lines changed: 24 additions & 16 deletions
@@ -15,6 +15,7 @@
 import gc
 import tempfile
 import unittest
+from typing import Any
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config, OPTForCausalLM
 from transformers.testing_utils import (
@@ -71,7 +72,12 @@ class FbgemmFp8Test(unittest.TestCase):
     input_text = "What are we having for dinner?"
     max_new_tokens = 9
 
-    EXPECTED_OUTPUT = "What are we having for dinner?\nI'm having a steak and a salad"
+    EXPECTED_OUTPUT = set[Any](
+        [
+            "What are we having for dinner?\nI'm having a steak and a salad",
+            "What are we having for dinner? I don’t know. What are we having",
+        ]
+    )
 
     device_map = "cuda"
 
@@ -155,27 +161,29 @@ def test_quantized_model_conversion(self):
             if isinstance(module, FbgemmFp8Linear):
                 nb_fbgemm_linear += 1
 
-        self.assertEqual(nb_linears - 1, nb_fbgemm_linear)
+        self.assertEqual(nb_linears, nb_fbgemm_linear)
 
         with init_empty_weights():
             model = OPTForCausalLM(config)
         quantization_config = FbgemmFp8Config(modules_to_not_convert=["fc1"])
-        model = replace_with_fbgemm_fp8_linear(model, quantization_config=quantization_config)
+        model = replace_with_fbgemm_fp8_linear(
+            model, modules_to_not_convert=["fc1"], quantization_config=quantization_config
+        )
         nb_fbgemm_linear = 0
         for module in model.modules():
             if isinstance(module, FbgemmFp8Linear):
                 nb_fbgemm_linear += 1
 
-        self.assertEqual(nb_linears - 25, nb_fbgemm_linear)
+        self.assertEqual(nb_linears - 24, nb_fbgemm_linear)
 
     def test_quantized_model(self):
         """
         Simple test that checks if the quantized model is working properly
         """
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
 
-        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
-        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
+        self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)
 
     def test_save_pretrained(self):
         """
@@ -188,8 +196,8 @@ def test_save_pretrained(self):
 
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
 
-        output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
-        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+        output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
+        self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)
 
     def test_change_loading_attributes(self):
         """
@@ -208,8 +216,8 @@ def test_change_loading_attributes(self):
 
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
 
-        output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
-        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+        output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
+        self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)
 
     @require_torch_multi_gpu
     def test_quantized_model_multi_gpu(self):
@@ -224,8 +232,8 @@ def test_quantized_model_multi_gpu(self):
         )
         self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})
 
-        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
-        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
+        self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)
 
     def test_quantized_model_offload(self):
         """
@@ -250,8 +258,8 @@ def test_save_pretrained_offload(self):
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
 
         quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.offload_device_map)
-        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
-        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
+        self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)
 
     @require_torch_multi_gpu
     def test_save_pretrained_multi_gpu(self):
@@ -266,8 +274,8 @@ def test_save_pretrained_multi_gpu(self):
 
         input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
 
-        output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
-        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+        output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
+        self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)
 
 
 @require_torch_gpu
