[Quantization] fix fbgemm #42561
Merged
Commits (9)
3cb1e7b  initial commit (MekkCyber)
0f60898  Merge remote-tracking branch 'upstream/HEAD' into fix-fbgemm (MekkCyber)
43e4d1d  passing tests (MekkCyber)
08d6ff0  fix replace_linear (MekkCyber)
0a3d11b  style (MekkCyber)
7510720  rm list (MekkCyber)
494c9bd  Merge branch 'main' into fix-fbgemm (MekkCyber)
737aaa8  fix (MekkCyber)
ee64fac  style (MekkCyber)
Diff
@@ -12,7 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import re
+from typing import Optional
+
 from ..activations import ACT2FN
+from ..core_model_loading import ConversionOps
+from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
 from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
@@ -29,18 +34,85 @@
 logger = logging.get_logger(__name__)


+class FbgemmFp8Quantize(ConversionOps):
+    def __init__(self, hf_quantizer):
+        self.hf_quantizer = hf_quantizer
+
+    def convert(
+        self,
+        input_dict: dict[str, torch.Tensor | list[torch.Tensor]],
+        model: Optional[torch.nn.Module] = None,
+        **kwargs,
+    ) -> dict[str, torch.Tensor]:
+        target_key, value = tuple(input_dict.items())[0]
+        value = value[0]
+
+        from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
+
+        module, tensor_name = get_module_from_name(model, target_key)
+
+        # Sanity checks
+        if isinstance(module, FbgemmFp8Linear):
+            if tensor_name == "weight" and value.dtype == torch.float8_e4m3fn:
+                raise ValueError("Expect unquantized weights but got a quantized weight")
+            if tensor_name == "weight_scale":
+                raise ValueError("Expect unquantized weights but got a weight_scale")
+        if isinstance(module, FbgemmFp8Llama4TextExperts):
+            if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
+                raise ValueError("Expect unquantized weights but got a quantized weight_scale")
+
+        if isinstance(module, FbgemmFp8Llama4TextExperts):
+            if tensor_name == "gate_up_proj":
+                # Process each expert separately
+                # Transpose the second and third dimension
+                transposed_param = value.transpose(1, 2)
+
+                # Reshape to 2D for quantization
+                original_shape = transposed_param.shape
+                flattened_param = transposed_param.reshape(-1, original_shape[-1])
+
+                # Quantize using per row instead of per column
+                new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
+
+                # Reshape back to original dimensions
+                new_value = new_value_flat.reshape(original_shape)
+                new_value = new_value.transpose(1, 2)
+                weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
+            elif tensor_name == "down_proj":
+                # Process each expert separately
+                # Transpose the weights for proper quantization
+                transposed_param = value.transpose(1, 2)
+
+                # Reshape to 2D for quantization
+                original_shape = transposed_param.shape
+                flattened_param = transposed_param.reshape(-1, original_shape[-1])
+
+                # Quantize using per column
+                new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
+
+                # Reshape back to original dimensions
+                new_value = new_value_flat.reshape(original_shape)
+                new_value = new_value.transpose(1, 2)
+                weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
+        else:
+            new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(value)
+            weight_scale = torch.nn.Parameter(weight_scale.view(weight_scale.shape[0], 1))
+
+        return {target_key: torch.nn.Parameter(new_value), f"{target_key}_scale": weight_scale}
+
+
 class FbgemmFp8Linear(torch.nn.Linear):
-    def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
+    def __init__(self, in_features, out_features, bias, dtype=torch.float8_e4m3fn):
         super().__init__(in_features, out_features, bias)
         self.in_features = in_features
         self.out_features = out_features

-        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
-        self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype))
+        self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=dtype))
+        self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=torch.float32))
         self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)

         if bias:
-            self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype))
+            self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=torch.float32))
         else:
             self.bias = None
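The gate_up_proj and down_proj branches of FbgemmFp8Quantize.convert above are mostly shape bookkeeping: transpose the expert weight so the quantized dimension is last, flatten to 2D, quantize one row at a time, then restore the original layout. Below is a minimal plain-PyTorch sketch of that bookkeeping; the toy shapes are made up, and the per-row scale computation is only a stand-in for torch.ops.fbgemm.quantize_fp8_per_row, not the fbgemm kernel. It assumes a PyTorch build that provides torch.float8_e4m3fn (2.1 or later).

import torch

# Toy dimensions; the real shapes come from the Llama4 text config (hypothetical values here).
num_experts, hidden, intermediate = 2, 8, 4
value = torch.randn(num_experts, hidden, 2 * intermediate)

# Same bookkeeping as the gate_up_proj branch: transpose, then flatten to 2D.
transposed_param = value.transpose(1, 2)                             # (E, 2*I, H)
original_shape = transposed_param.shape
flattened_param = transposed_param.reshape(-1, original_shape[-1])   # (E * 2*I, H)

# Stand-in for torch.ops.fbgemm.quantize_fp8_per_row: one scale per row so each
# row fits the float8_e4m3fn range (max normal value is roughly 448).
row_max = flattened_param.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
weight_scale_flat = row_max / 448.0
new_value_flat = (flattened_param / weight_scale_flat).to(torch.float8_e4m3fn)

# Restore the original layout, exactly as convert() does.
new_value = new_value_flat.reshape(original_shape).transpose(1, 2)                  # (E, H, 2*I)
weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])   # (E, 1, 2*I)

print(new_value.shape, new_value.dtype, weight_scale.shape)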
@@ -154,87 +226,9 @@ def forward(self, hidden_states):
         return next_states.view(-1, self.hidden_size)


-def _replace_with_fbgemm_fp8_linear(
-    model,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    quantization_config=None,
-    has_been_replaced=False,
-    pre_quantized=False,
-    config=None,
-    tp_plan=None,
-):
-    """
-    Private method that wraps the recursion for module replacement.
-
-    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
-    """
-
-    import re
-
-    if current_key_name is None:
-        current_key_name = []
-
-    for name, module in model.named_children():
-        current_key_name.append(name)
-
-        if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            current_key_name_str = ".".join(current_key_name)
-            if not any(
-                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-            ):
-                with init_empty_weights(include_buffers=True):
-                    in_features = module.in_features
-                    out_features = module.out_features
-                    model._modules[name] = FbgemmFp8Linear(
-                        in_features,
-                        out_features,
-                        module.bias is not None,
-                    )
-                    has_been_replaced = True
-
-                    # Force requires grad to False to avoid unexpected errors
-                    model._modules[name].requires_grad_(False)
-                # set non persistent buffer outside of init_empty_weights
-                model._modules[name].input_scale_ub = torch.tensor(
-                    [quantization_config.activation_scale_ub],
-                    dtype=torch.float,
-                )
-        if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
-            current_key_name_str = ".".join(current_key_name)
-            if not any(
-                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-            ):
-                with init_empty_weights(include_buffers=True):
-                    tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
-                    model._modules[name] = FbgemmFp8Llama4TextExperts(
-                        config.text_config,
-                    )
-                model._modules[name].input_scale_ub = torch.tensor(
-                    [quantization_config.activation_scale_ub], dtype=torch.float
-                )
-
-        if len(list(module.children())) > 0:
-            _, has_been_replaced = _replace_with_fbgemm_fp8_linear(
-                module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
-                has_been_replaced=has_been_replaced,
-                pre_quantized=pre_quantized,
-                config=config,
-                tp_plan=tp_plan,
-            )
-        # Remove the last key for recursion
-        current_key_name.pop(-1)
-    return model, has_been_replaced
-
-
 def replace_with_fbgemm_fp8_linear(
     model,
     modules_to_not_convert=None,
     current_key_name=None,
     quantization_config=None,
     pre_quantized=False,
     config=None,
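The removed _replace_with_fbgemm_fp8_linear above skipped modules by string-matching their dotted path against modules_to_not_convert; in the new code that decision moves into should_convert_module. The matching rule itself is simple, and the hypothetical helper below (not the transformers implementation) mirrors the check that was deleted:

def keeps_full_precision(module_path: str, modules_to_not_convert: list[str]) -> bool:
    # Mirrors the removed `(key + "." in path) or (key == path)` test: a key matches the
    # module itself or any module nested under it.
    return any((key + "." in module_path) or (key == module_path) for key in modules_to_not_convert)


skip = ["lm_head", "model.layers.0.mlp"]
print(keeps_full_precision("lm_head", skip))                       # True: exact match
print(keeps_full_precision("model.layers.0.mlp.gate_proj", skip))  # True: nested under a skipped module
print(keeps_full_precision("model.layers.1.mlp.gate_proj", skip))  # False: gets converted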
@@ -260,20 +254,42 @@ def replace_with_fbgemm_fp8_linear(
             `disk`).
     """

     modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

     if quantization_config.modules_to_not_convert is not None:
         modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
     modules_to_not_convert = list(set(modules_to_not_convert))
-    model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
-        model,
-        modules_to_not_convert,
-        current_key_name,
-        quantization_config,
-        pre_quantized=pre_quantized,
-        config=config,
-        tp_plan=tp_plan,
-    )
+    has_been_replaced = False
+    module_kwargs = {} if pre_quantized else {"dtype": None}
+
+    for module_name, module in model.named_modules():
+        if not should_convert_module(module_name, modules_to_not_convert):
+            continue
+
+        new_module = None
+        with init_empty_weights(include_buffers=True):
+            if module.__class__.__name__ == "Llama4TextExperts":
+                if tp_plan is not None:
+                    tp_key = re.sub(r"\d+", "*", f"{module_name}.down_proj_scale")
+                    tp_plan[tp_key] = None
+
+                text_config = getattr(config, "text_config", config)
+                new_module = FbgemmFp8Llama4TextExperts(text_config or model.config)
+            elif isinstance(module, nn.Linear):
+                new_module = FbgemmFp8Linear(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    **module_kwargs,
+                )
+                new_module.requires_grad_(False)
+
+        if new_module is None:
+            continue
+
+        if hasattr(new_module, "input_scale_ub"):
+            new_module.input_scale_ub = torch.tensor(
+                [quantization_config.activation_scale_ub],
+                dtype=torch.float,
+            )
+
+        model.set_submodule(module_name, new_module)
+        has_been_replaced = True

     if not has_been_replaced:
         logger.warning(
             "You are loading your model using FP8 quantization but no linear modules were found in your model."
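The rewritten replace_with_fbgemm_fp8_linear above drops the recursion entirely: it walks model.named_modules() once, builds each replacement under init_empty_weights, and swaps it in by dotted name via model.set_submodule. The self-contained sketch below shows that flat replacement pattern on a toy model; Fp8LinearStub and swap_submodule are illustrations only (set_submodule ships with recent PyTorch, so the helper just keeps the sketch version-independent), and meta-device initialization is left out.

import torch
from torch import nn


class Fp8LinearStub(nn.Module):
    # Stand-in for FbgemmFp8Linear: an fp8 weight plus a per-row float32 scale.
    def __init__(self, in_features: int, out_features: int, bias: bool):
        super().__init__()
        self.weight = nn.Parameter(
            torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn), requires_grad=False
        )
        self.weight_scale = nn.Parameter(torch.zeros(out_features, 1), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False) if bias else None


def swap_submodule(model: nn.Module, dotted_name: str, new_module: nn.Module) -> None:
    # Same effect as model.set_submodule(dotted_name, new_module).
    parent_name, _, child_name = dotted_name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, new_module)


def replace_linears(model: nn.Module, modules_to_not_convert: list[str]) -> bool:
    has_been_replaced = False
    for module_name, module in model.named_modules():
        if not isinstance(module, nn.Linear) or module_name in modules_to_not_convert:
            continue
        new_module = Fp8LinearStub(module.in_features, module.out_features, module.bias is not None)
        swap_submodule(model, module_name, new_module)
        has_been_replaced = True
    return has_been_replaced


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
print(replace_linears(model, modules_to_not_convert=["2"]))  # True: only the first Linear is swapped
print(model)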
Review comment: let's remove those checks, this shouldn't be possible here.