Merged

Changes from 7 commits
208 changes: 112 additions & 96 deletions src/transformers/integrations/fbgemm_fp8.py
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Optional

from ..activations import ACT2FN
from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging


@@ -29,18 +34,85 @@
logger = logging.get_logger(__name__)


class FbgemmFp8Quantize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer

def convert(
self,
input_dict: dict[str, torch.Tensor | list[torch.Tensor]],
model: Optional[torch.nn.Module] = None,
**kwargs,
) -> dict[str, torch.Tensor]:
target_key, value = tuple(input_dict.items())[0]
value = value[0]

from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts

module, tensor_name = get_module_from_name(model, target_key)

# Sanity checks
if isinstance(module, FbgemmFp8Linear):
if tensor_name == "weight" and value.dtype == torch.float8_e4m3fn:
raise ValueError("Expect unquantized weights but got a quantized weight")
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a weight_scale")
if isinstance(module, FbgemmFp8Llama4TextExperts):
if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
Review comment (Member):

let's remove those checks, this shouldn't be possible here.

Suggested change (removal):
# Sanity checks
if isinstance(module, FbgemmFp8Linear):
if tensor_name == "weight" and value.dtype == torch.float8_e4m3fn:
raise ValueError("Expect unquantized weights but got a quantized weight")
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a weight_scale")
if isinstance(module, FbgemmFp8Llama4TextExperts):
if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")


if isinstance(module, FbgemmFp8Llama4TextExperts):
if tensor_name == "gate_up_proj":
# Process each expert separately
# Transpose the second and third dimension
transposed_param = value.transpose(1, 2)

# Reshape to 2D for quantization
original_shape = transposed_param.shape
flattened_param = transposed_param.reshape(-1, original_shape[-1])

# Quantize using per row instead of per column
new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)

# Reshape back to original dimensions
new_value = new_value_flat.reshape(original_shape)
new_value = new_value.transpose(1, 2)
weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
elif tensor_name == "down_proj":
# Process each expert separately
# Transpose the weights for proper quantization
transposed_param = value.transpose(1, 2)

# Reshape to 2D for quantization
original_shape = transposed_param.shape
flattened_param = transposed_param.reshape(-1, original_shape[-1])

# Quantize using per column
new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)

# Reshape back to original dimensions
new_value = new_value_flat.reshape(original_shape)
new_value = new_value.transpose(1, 2)
weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
else:
new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(value)
weight_scale = torch.nn.Parameter(weight_scale.view(weight_scale.shape[0], 1))

return {target_key: torch.nn.Parameter(new_value), f"{target_key}_scale": weight_scale}
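
For context on the new op, here is a short sketch that is not part of the change set: it assumes fbgemm-gpu is installed and a CUDA device is available, and it illustrates the contract FbgemmFp8Quantize.convert implements for a plain FbgemmFp8Linear weight, one high-precision checkpoint tensor in, an fp8 weight plus a per-row scale out (the checkpoint key is invented for illustration).

# Minimal sketch (assumption: fbgemm-gpu + CUDA available); mirrors the `else` branch of convert() above.
import torch

weight = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")  # (out_features, in_features)
fp8_weight, row_scale = torch.ops.fbgemm.quantize_fp8_per_row(weight)
# fp8_weight keeps the (128, 64) shape but is stored as torch.float8_e4m3fn; row_scale holds
# one scale per output row and is viewed as (out_features, 1) before being returned under the
# "_scale" key, e.g.
# {"model.layers.0.self_attn.q_proj.weight": fp8_weight,
#  "model.layers.0.self_attn.q_proj.weight_scale": row_scale.view(-1, 1)}
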


class FbgemmFp8Linear(torch.nn.Linear):
def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
def __init__(self, in_features, out_features, bias, dtype=torch.float8_e4m3fn):
super().__init__(in_features, out_features, bias)
self.in_features = in_features
self.out_features = out_features

self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype))
self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=dtype))
self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=torch.float32))
self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)

if bias:
self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype))
self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=torch.float32))
else:
self.bias = None
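
A quick illustration of the signature change above (a sketch, not from the PR, assuming the layer is constructed eagerly on CPU): the dtype argument now controls the weight storage and defaults to fp8, while the scale and bias are always float32 instead of following a weight_dtype argument.

# Sketch only: shows the parameter dtypes produced by the new __init__.
layer = FbgemmFp8Linear(in_features=64, out_features=128, bias=True)
print(layer.weight.dtype)        # torch.float8_e4m3fn (the default `dtype`)
print(layer.weight_scale.dtype)  # torch.float32, shape (128, 1)
print(layer.bias.dtype)          # torch.float32, no longer tied to a weight_dtype argument
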

@@ -154,87 +226,9 @@ def forward(self, hidden_states):
return next_states.view(-1, self.hidden_size)


def _replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
pre_quantized=False,
config=None,
tp_plan=None,
):
"""
Private method that wraps the recursion for module replacement.

Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""

import re

if current_key_name is None:
current_key_name = []

for name, module in model.named_children():
current_key_name.append(name)

if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)
if not any(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
with init_empty_weights(include_buffers=True):
in_features = module.in_features
out_features = module.out_features
model._modules[name] = FbgemmFp8Linear(
in_features,
out_features,
module.bias is not None,
)
has_been_replaced = True

# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
# set non persistent buffer outside of init_empty_weights
model._modules[name].input_scale_ub = torch.tensor(
[quantization_config.activation_scale_ub],
dtype=torch.float,
)
if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
current_key_name_str = ".".join(current_key_name)
if not any(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
with init_empty_weights(include_buffers=True):
tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
model._modules[name] = FbgemmFp8Llama4TextExperts(
config.text_config,
)
model._modules[name].input_scale_ub = torch.tensor(
[quantization_config.activation_scale_ub], dtype=torch.float
)

if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fbgemm_fp8_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
pre_quantized=pre_quantized,
config=config,
tp_plan=tp_plan,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced


def replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
pre_quantized=False,
config=None,
Review comment (Member):

let's use model.config directly

@@ -260,20 +254,42 @@ def replace_with_fbgemm_fp8_linear(
`disk`).
"""

modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
modules_to_not_convert = list(set(modules_to_not_convert))
model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert,
current_key_name,
quantization_config,
pre_quantized=pre_quantized,
config=config,
tp_plan=tp_plan,
)
has_been_replaced = False
module_kwargs = {} if pre_quantized else {"dtype": None}

for module_name, module in model.named_modules():
if not should_convert_module(module_name, modules_to_not_convert):
continue

new_module = None
with init_empty_weights(include_buffers=True):
if module.__class__.__name__ == "Llama4TextExperts":
if tp_plan is not None:
tp_key = re.sub(r"\d+", "*", f"{module_name}.down_proj_scale")
tp_plan[tp_key] = None
Review comment (Member):

comment this for now

text_config = getattr(config, "text_config", config)
new_module = FbgemmFp8Llama4TextExperts(text_config or model.config)
elif isinstance(module, nn.Linear):
new_module = FbgemmFp8Linear(
module.in_features,
module.out_features,
module.bias is not None,
**module_kwargs,
)
new_module.requires_grad_(False)

if new_module is None:
continue

if hasattr(new_module, "input_scale_ub"):
new_module.input_scale_ub = torch.tensor(
[quantization_config.activation_scale_ub],
dtype=torch.float,
)

model.set_submodule(module_name, new_module)
has_been_replaced = True

if not has_been_replaced:
logger.warning(
"You are loading your model using FP8 quantization but no linear modules were found in your model."
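
To make the new flat replacement loop concrete, a minimal usage sketch (not part of the diff; the OPT checkpoint name and the explicit modules_to_not_convert list are illustrative assumptions):

# Sketch: run the rewritten replacement pass on an empty-weights model and count the
# swapped-in FbgemmFp8Linear modules, similar to test_quantized_model_conversion below.
from accelerate import init_empty_weights
from transformers import AutoConfig, OPTForCausalLM, FbgemmFp8Config
from transformers.integrations import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear

config = AutoConfig.from_pretrained("facebook/opt-350m")
with init_empty_weights():
    model = OPTForCausalLM(config)

model = replace_with_fbgemm_fp8_linear(
    model,
    modules_to_not_convert=["lm_head"],   # assumed caller-supplied, as in the updated tests
    quantization_config=FbgemmFp8Config(),
)
print(sum(isinstance(m, FbgemmFp8Linear) for m in model.modules()))
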
5 changes: 2 additions & 3 deletions src/transformers/quantizers/base.py
@@ -61,10 +61,9 @@ def get_keys_to_not_convert(model):
for name, module in model.named_modules()
if output_emb_module is not None and id(module) == id(output_emb_module)
}
candidates = tied_keys | last_module_key | output_emb_keys
modules_to_not_convert = tied_keys | last_module_key | output_emb_keys

modules_to_not_convert = {name.replace(suffix, "") for name in candidates for suffix in [".weight", ".bias"]}
return modules_to_not_convert
return list(modules_to_not_convert)


class HfQuantizer(ABC):
5 changes: 5 additions & 0 deletions src/transformers/quantizers/quantizer_fbgemm_fp8.py
@@ -285,3 +285,8 @@ def is_serializable(self, safe_serialization=None):
@property
def is_trainable(self) -> bool:
return False

def get_quantize_ops(self):
from ..integrations.fbgemm_fp8 import FbgemmFp8Quantize

return FbgemmFp8Quantize(self)
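
And a hedged sketch of how the new hook is presumably consumed; the loader-side wiring is an assumption and is not shown in this diff:

# Hypothetical consumer side: core model loading asks the quantizer for its
# ConversionOps and routes unquantized checkpoint entries through convert().
from transformers import FbgemmFp8Config
from transformers.quantizers.quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer

quantizer = FbgemmFp8HfQuantizer(quantization_config=FbgemmFp8Config())
op = quantizer.get_quantize_ops()  # FbgemmFp8Quantize bound to this quantizer
# op.convert({"model.layers.0.mlp.down_proj.weight": [bf16_weight]}, model=model)
# would return the fp8 weight plus "model.layers.0.mlp.down_proj.weight_scale".
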
40 changes: 24 additions & 16 deletions tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py
@@ -15,6 +15,7 @@
import gc
import tempfile
import unittest
from typing import Any

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config, OPTForCausalLM
from transformers.testing_utils import (
@@ -71,7 +72,12 @@ class FbgemmFp8Test(unittest.TestCase):
input_text = "What are we having for dinner?"
max_new_tokens = 9

EXPECTED_OUTPUT = "What are we having for dinner?\nI'm having a steak and a salad"
EXPECTED_OUTPUT = set[Any](
[
"What are we having for dinner?\nI'm having a steak and a salad",
"What are we having for dinner? I don’t know. What are we having",
]
)

device_map = "cuda"

@@ -155,27 +161,29 @@ def test_quantized_model_conversion(self):
if isinstance(module, FbgemmFp8Linear):
nb_fbgemm_linear += 1

self.assertEqual(nb_linears - 1, nb_fbgemm_linear)
self.assertEqual(nb_linears, nb_fbgemm_linear)

with init_empty_weights():
model = OPTForCausalLM(config)
quantization_config = FbgemmFp8Config(modules_to_not_convert=["fc1"])
model = replace_with_fbgemm_fp8_linear(model, quantization_config=quantization_config)
model = replace_with_fbgemm_fp8_linear(
model, modules_to_not_convert=["fc1"], quantization_config=quantization_config
)
nb_fbgemm_linear = 0
for module in model.modules():
if isinstance(module, FbgemmFp8Linear):
nb_fbgemm_linear += 1

self.assertEqual(nb_linears - 25, nb_fbgemm_linear)
self.assertEqual(nb_linears - 24, nb_fbgemm_linear)

def test_quantized_model(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)

def test_save_pretrained(self):
"""
@@ -188,8 +196,8 @@ def test_save_pretrained(self):

input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)

def test_change_loading_attributes(self):
"""
@@ -208,8 +216,8 @@

input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)

@require_torch_multi_gpu
def test_quantized_model_multi_gpu(self):
@@ -224,8 +232,8 @@ def test_quantized_model_multi_gpu(self):
)
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)

def test_quantized_model_offload(self):
"""
@@ -250,8 +258,8 @@ def test_save_pretrained_offload(self):
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.offload_device_map)
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)

@require_torch_multi_gpu
def test_save_pretrained_multi_gpu(self):
@@ -266,8 +274,8 @@ def test_save_pretrained_multi_gpu(self):

input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)


@require_torch_gpu