# Add quantize_ nn.Parameter support (#3083)
This PR adds support to `quantize_` for quantizing `nn.Parameter` directly.
### bc-breaking changes
The top level `quantize_` API has the following bc-breaking changes:
1) Passing in both `filter_fn` and `ModuleFqnToConfig` is no longer supported and will now raise a `ValueError` if both are specified. Previously, we would quantize all modules that were both matched by `filter_fn` and specified in `ModuleFqnToConfig`. Users should now explicitly pass `filter_fn=None` when using `ModuleFqnToConfig`/`FqnToConfig`.
2) The semantics of `filter_fn=None` have changed. Previously, passing `None` would default to `_is_linear` when running `quantize_`. Now, when `filter_fn=None` is specified we ignore `filter_fn` completely and rely only on `FqnToConfig` to quantize the model. Note that this is equivalent to passing `filter_fn=lambda mod, fqn: True` in the previous API.
3) The default `filter_fn` has changed from `None` to `_is_linear`, and `_default` in `ModuleFqnToConfig` now only applies to linear layers. Previously, `_default` would apply to all modules that passed `filter_fn`. We plan to deprecate `_default` in the future; please see #3229 for more details.
Before:
```python
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.Linear(128, 128),
    torch.nn.Conv2d(128, 128, 3, 1, 1),
).cuda().to(torch.bfloat16)
config = ModuleFqnToConfig({
    "0": Float8DynamicActivationFloat8WeightConfig(),
})
# these are equivalent
quantize_(model, config, filter_fn=_is_linear)
quantize_(model, config, filter_fn=None)
quantize_(model, config)
```
```
> Sequential(
(0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
(1): Linear(in_features=128, out_features=128, bias=True)
(2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
```
After:
```python
# users must now explicitly pass filter_fn=None
quantize_(model, config, filter_fn=None)
```
```
> Sequential(
(0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
(1): Linear(in_features=128, out_features=128, bias=True)
(2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
```
After (error cases):
```python
# these now error
quantize_(model, config, filter_fn=_is_linear)
quantize_(model, config)
```
```
> ValueError: Custom filter_fn and FqnToConfig were both specified. Only filter_fn=None is supported when FqnToConfig is specified.
```
#### Example of the `_default` changes
Before:
```python
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("weight", torch.nn.Parameter(torch.randn(128, 128)))

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.Linear(128, 128),
    MyModule(),
).cuda().to(torch.bfloat16)
config = ModuleFqnToConfig({
    "_default": Float8DynamicActivationFloat8WeightConfig(),
})
quantize_(model, config, filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear) or isinstance(mod, MyModule))
```
```
> Sequential(
(0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
(1): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
(2): MyModule(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
)
```
After:
```python
# _default now only applies to linear layers; MyModule is left unquantized
quantize_(model, config, filter_fn=None)
```
```
> Sequential(
(0): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
(1): Linear(in_features=128, out_features=128, weight=Float8Tensor(self.act_quant_kwargs=QuantizeTensorToFloat8Kwargs(float8_dtype=torch.float8_e4m3fn, granularity=PerTensor(), mm_config=None, hp_value_lb=None, hp_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>), self.block_size=[128, 128], self.mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), self.scale.shape=torch.Size([1, 1]), self.kernel_preference=<KernelPreference.AUTO: 'auto'>))
(2): MyModule()
)
```
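Under the new semantics, one possible way to still quantize `MyModule`'s parameter is to target its parameter FQN directly with the new `FqnToConfig` (described in the Summary below). This is a sketch that continues the example above and reuses `model`; it assumes parameter-FQN support for this config as described in this PR:

```python
# reuses `model` from the example above; module "2" is MyModule, its parameter is "2.weight"
config = FqnToConfig({
    "_default": Float8DynamicActivationFloat8WeightConfig(),   # applies to linears "0" and "1"
    "2.weight": Float8DynamicActivationFloat8WeightConfig(),   # targets MyModule's parameter
})
quantize_(model, config, filter_fn=None)
```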
### Summary
`ModuleFqnToConfig` has been renamed to `FqnToConfig`, which now accepts both module FQNs and parameter FQNs; `ModuleFqnToConfig` is kept as an alias to maintain BC. The keys of `FqnToConfig` can be one of the following, in order of precedence (a combined sketch follows the list):
1) exact parameter FQN
```python
quant_config = FqnToConfig({
    "linear1.weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```
2) exact module FQN
```python
quant_config = FqnToConfig({
    "linear1": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```
3) regex that matches a parameter FQN (prefixed with `re:`)
```python
quant_config = FqnToConfig({
    "re:linear*.weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```
4) regex that matches a module FQN (prefixed with `re:`)
```python
quant_config = FqnToConfig({
    "re:linear*": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```
5) `_default`, which only applies to `nn.Linear` layers
```python
quant_config = FqnToConfig({
    "_default": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```
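Putting these together, here is a minimal end-to-end sketch mixing an exact parameter FQN with an exact module FQN. The module names `"0"`/`"1"` and the granularities are illustrative, and the imports assume `FqnToConfig` is exported from the public `torchao.quantization` namespace after this PR:

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    FqnToConfig,
    PerRow,
    PerTensor,
    quantize_,
)

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),  # module fqn "0", parameter fqn "0.weight"
    torch.nn.Linear(128, 128),  # module fqn "1"
).cuda().to(torch.bfloat16)

quant_config = FqnToConfig({
    # exact parameter FQN: quantize only this parameter
    "0.weight": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
    # exact module FQN: quantize this module's weight
    "1": Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()),
})

# filter_fn must be None when an FqnToConfig is passed (see the bc-breaking notes above)
quantize_(model, quant_config, filter_fn=None)
```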
To enable parameter-FQN support for a particular config, we need to add the `parameter_name` kwarg to the config signature and update `CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS`. See the changes [here](https://github.com/pytorch/ao/pull/3083/files#diff-bf4d50867e3d649de2d89146592bf47d2f258c4c19126c8acf0e120ee904b726R1874) for more details.
Parameter-FQN support is enabled for `Float8DynamicActivationFloat8WeightConfig` in this PR; other configs will raise a `NotImplementedError`.
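For illustration only, here is a self-contained sketch of the shape of that change. Everything in it is a placeholder except the `parameter_name` kwarg and the registry name, which come from the description above; the registry's concrete type is assumed to be a set here:

```python
from dataclasses import dataclass

# Placeholder stand-in for the real registry in torchao (concrete type assumed).
CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS = set()

@dataclass
class MyNewWeightConfig:  # hypothetical config class, not a real torchao config
    granularity: str = "per_row"
    # New kwarg: which parameter on the owning module this config targets.
    parameter_name: str = "weight"

# Registering the config class marks it as supporting parameter-FQN keys in FqnToConfig.
CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS.add(MyNewWeightConfig)
```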
### Test Plan
1) Unit tests for the new config:
```
pytest test/quantization/test_quant_api.py::TestFqnToConfig
```
2) Regression tests for `ModuleFqnToConfig`:
```
pytest test/quantization/test_quant_api.py -k test_module_fqn_to_config
```
3) Make sure that we can load old HF checkpoints to maintain BC by running [this](https://huggingface.co/torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev#test-loading)
4) Make sure that this doesn't break BC with transformers:
```
pytest tests/quantization/torchao_integration/test_torchao.py -k test_module_fqn_to_config
```
5) Make sure that this doesn't break BC in vLLM:
```
pytest tests/quantization/test_torchao.py
```
___
## How do our configs translate for MoEs?
Currently, we define a number of configs for dense `nn.Linear` modules. How do these configs translate to MoE inference?
### Some background on MoE inference
There are two ways that the forward pass is implemented for MoE (a short sketch of both styles follows this list):
- For loop of `nn.Linear` - In this case, we break down the 3d weight x activation matmul into a for loop of 2d weight x activation matmuls. This can be seen [here](https://github.com/huggingface/transformers/blob/6cade29278c4aee3f174f8950f97a3873bdb212f/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L117).
**In this case, I argue that the semantics of the configs do not change at all from the normal `nn.Linear` case, as we are just doing a bunch of normal 2d linear matmuls.**
- bmm/grouped mm on the 3d weights / activations directly.
**For this case, we'd need to add additional op support (bmm) for forwards. Depending on whether the subclass is an AQT subclass or non AQT subclass this will be added differently.**
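For concreteness, here is a small sketch of the two forward styles for a stacked expert weight of shape (E, N, K); the shapes and names are illustrative only:

```python
import torch

E, N, K, T = 4, 256, 128, 32   # experts, out_features, in_features, tokens per expert
w = torch.randn(E, N, K)       # stacked expert weights
x = torch.randn(E, T, K)       # activations already routed to each expert

# Style 1: a for-loop of ordinary 2d linear matmuls, one per expert.
out_loop = torch.stack([x[e] @ w[e].t() for e in range(E)])

# Style 2: a single batched matmul (bmm / grouped mm) over the 3d tensors directly.
out_bmm = torch.bmm(x, w.transpose(-2, -1))

print(out_loop.shape, out_bmm.shape)              # both (E, T, N)
print(torch.allclose(out_loop, out_bmm, atol=1e-4))
```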
I plan to only support parameter quantization for non-AQT subclasses; my reasoning is that those are the most popular / important configs anyway (Float8Dynamic, Int4WeightOnly).
Below is a breakdown of which configs map to AQT vs. non-AQT subclasses:
| Not using AQT | Using AffineQuantizedTensor (AQT) |
|-----------|---------------|
| Float8DynamicActivationFloat8WeightConfig | FPXWeightOnlyConfig |
| Float8DynamicActivationInt4WeightConfig | Float8WeightOnlyConfig |
| Float8StaticActivationFloat8WeightConfig | Float8DynamicActivationFloat8SemiSparseWeightConfig |
| Int4WeightOnlyConfig (v2) | GemliteUIntXWeightOnlyConfig |
| | Int4DynamicActivationInt4WeightConfig |
| | Int8DynamicActivationInt4WeightConfig |
| | Int8DynamicActivationInt8WeightConfig |
| | Int8WeightOnlyConfig |
| | IntxWeightOnlyConfig |
| | UIntXWeightOnlyConfig |
For these, the majority of the semantics remain the same; the only semantic that really changes is `PerRow` granularity, and there's a very natural extension of `PerRow` to the 3d case (apply it along the last dimension).
I took a look at the keys of the non-AQT configs below and what they would mean for MoEs.
#### Float8DynamicActivationFloat8WeightConfig
```
[('activation_dtype', <class 'torch.dtype'>),
('weight_dtype', <class 'torch.dtype'>),
('granularity',
typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.List[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
('activation_value_lb', typing.Optional[float]),
('activation_value_ub', typing.Optional[float]),
('kernel_preference', <enum 'KernelPreference'>),
('set_inductor_config', <class 'bool'>),
('version', <class 'int'>)]
```
`activation_dtype`, `weight_dtype`, `activation_value_lb`, and `activation_value_ub` do not change meaning semantically.
`granularity=PerTensor()` does not change semantic meaning - we still use a single scale for the entire weight tensor.
`granularity=PerRow()` does change meaning - we now calculate a scale per row along the last dimension (-1), i.e. for a weight of shape (E, N, K) we would expect `PerRow` to create scales with block size (1, 1, K).
`mm_config`, `kernel_preference`, and `set_inductor_config` stay the same as well.
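As a minimal sketch of that extension, assuming a simple absmax scale for fp8 e4m3 (the actual choose-qparams logic in torchao may differ in its details):

```python
import torch

E, N, K = 4, 256, 128
w = torch.randn(E, N, K, dtype=torch.bfloat16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max

# PerTensor: a single scale for the whole 3d weight (semantics unchanged).
per_tensor_scale = w.abs().amax() / fp8_max
print(per_tensor_scale.shape)  # torch.Size([])

# PerRow extended to 3d: reduce only over the last dim, i.e. block size (1, 1, K),
# giving one scale per (expert, output row).
per_row_scale = w.abs().amax(dim=-1, keepdim=True) / fp8_max
print(per_row_scale.shape)  # torch.Size([4, 256, 1])
```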
#### Float8StaticActivationFloat8WeightConfig
```
[('scale', <class 'torch.Tensor'>),
('activation_dtype', <class 'torch.dtype'>),
('weight_dtype', <class 'torch.dtype'>),
('granularity',
typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.Tuple[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')], typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
('set_inductor_config', <class 'bool'>)]
```
`scale` should be passed in as a 3d tensor instead of a 2d tensor in the case of `PerRow` granularity.
#### Float8DynamicActivationInt4WeightConfig
```
[('int4_packing_format', <enum 'Int4PackingFormat'>)]
```
`int4_packing_format` - only `"preshuffled"` is supported, and `Int4PreshuffledTensor` [supports](https://github.com/pytorch/ao/blob/895573980e085b02a2c6abbc82239bae7f1318d6/torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py#L154) 3d weights.
#### Int4WeightOnlyConfig
```
[('group_size', <class 'int'>),
('layout',
typing.Optional[torchao.dtypes.uintx.tensor_core_tiled_layout.TensorCoreTiledLayout]),
('use_hqq', <class 'bool'>),
('zero_point_domain',
typing.Optional[torchao.quantization.quant_primitives.ZeroPointDomain]),
('set_inductor_config', <class 'bool'>),
('preserve_zero', typing.Optional[bool]),
('int4_packing_format', <enum 'Int4PackingFormat'>),
('int4_choose_qparams_algorithm', <enum 'Int4ChooseQParamsAlgorithm'>),
('version', <class 'int'>)]
```
`group_size`, `int4_packing_format`, `int4_choose_qparams_algorithm`, and `set_inductor_config` are the only fields used by the v2 config.
I don't think the semantics of these change, although there are some packing formats that do not support 3d weights; it looks like `Int4PackingFormat.PLAIN_INT32` and `Int4PackingFormat.MARLIN_SPARSE` are among them.