Commit e058ea0

Merge branch 'main' into main
2 parents 0d1d35e + 3cd0300 commit e058ea0

25 files changed: +578 additions, −264 deletions

README.md

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,7 @@ Big updates have landed in LLM Compressor! To get a more in-depth look, check ou

 Some of the exciting new features include:

+* **AutoRound Quantization Support**: Added [`AutoRoundModifier`](examples/autoround/llama3_example.py) for quantization using [AutoRound](https://aclanthology.org/2024.findings-emnlp.662.pdf), an advanced post-training algorithm that optimizes rounding and clipping ranges through sign-gradient descent. This approach combines the efficiency of post-training quantization with the adaptability of parameter tuning, delivering robust compression for large language models while maintaining strong performance.
 * **Qwen3 Next and Qwen3 VL MoE Quantization Support**: Quantize the Qwen3 Next and Qwen3 VL MoE models and seamlessly run them in vLLM. Examples of [NVFP4](examples/quantization_w4a4_fp4/qwen3_next_example.py) and [FP8](examples/quantization_w8a8_fp8/qwen3_next_example.py) quantization have been added for Qwen3-Next-80B-A3B-Instruct. For Qwen3 VL MoE, support has been added for the data-free pathway, specifically [FP8 quantization](examples/quantization_w8a8_fp8/qwen3_vl_moe_fp8_example.py) (e.g., channel-wise and block-wise quantization). NOTE: these models are not supported in transformers<=4.56.2; you may need to install transformers from source.
 * **Quantization with Multiple Modifiers**: Multiple quantization modifiers can now be applied to the same model for mixed-precision quantization, for example applying AWQ W4A16 to a model's `self_attn` layers and GPTQ W8A8 to its `mlp` layers. This is an advanced usage of `llm-compressor` and an active area of research. See the [non-uniform quantization support](examples/quantization_non_uniform) section for more detail and [example usage](examples/quantization_non_uniform/quantization_multiple_modifiers.py); a minimal sketch also follows this file's diff below.
 * **QuIP and SpinQuant-style Transforms**: The newly added [`QuIPModifier`](examples/transform/quip_example.py) and [`SpinQuantModifier`](examples/transform/spinquant_example.py) allow users to quantize their models after injecting Hadamard weights into the computation graph, reducing quantization error and greatly improving accuracy recovery for low-bit weight and activation quantization.

@@ -55,6 +56,7 @@ Some of the exciting new features include:
 * AWQ
 * SmoothQuant
 * SparseGPT
+* AutoRound

 ### When to Use Which Optimization

@@ -78,6 +80,7 @@ Applying quantization with `llmcompressor`:
 * [Weight only quantization to `fp4`](examples/quantization_w4a16_fp4/llama3_example.py)
 * [Weight only quantization to `int4` using GPTQ](examples/quantization_w4a16/README.md)
 * [Weight only quantization to `int4` using AWQ](examples/awq/README.md)
+* [Weight only quantization to `int4` using AutoRound](examples/autoround/README.md)
 * [Quantizing MoE LLMs](examples/quantizing_moe/README.md)
 * [Quantizing Vision-Language Models](examples/multimodal_vision/README.md)
 * [Quantizing Audio-Language Models](examples/multimodal_audio/README.md)
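
A minimal sketch of the mixed-precision, multiple-modifier recipe mentioned above, assuming AWQ on the attention projections and GPTQ on the MLP projections. The regex targets, dataset name, and sample counts are illustrative; see `examples/quantization_non_uniform/quantization_multiple_modifiers.py` for the maintained example.

```python
from transformers import AutoModelForCausalLM

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier

# Illustrative model; any decoder-only causal LM with Llama-style module names works.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype="auto"
)

# Mixed precision: W4A16 (AWQ) on attention projections, W8A8 (GPTQ) on MLP projections.
recipe = [
    AWQModifier(
        targets=[r"re:.*self_attn\.(q|k|v|o)_proj$"],
        scheme="W4A16",
        ignore=["lm_head"],
    ),
    GPTQModifier(
        targets=[r"re:.*mlp\.(gate|up|down)_proj$"],
        scheme="W8A8",
        ignore=["lm_head"],
    ),
]

oneshot(
    model=model,
    dataset="open_platypus",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=256,
)
```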

docs/getting-started/compress.md

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ Compression schemes use quantization methods including the following:
 | **AWQ** | Uses channelwise scaling to better preserve important outliers in weights and activations | Better accuracy recovery with faster runtime than GPTQ |
 | **SmoothQuant** | Smooths outliers in activations by folding them into weights, ensuring better accuracy for weight and activation quantized models | Good accuracy recovery with minimal calibration time; composable with other methods |
 | **Round-To-Nearest (RTN)** | Simple quantization technique that rounds each value to the nearest representable level in the target precision. | Provides moderate accuracy recovery in most scenarios. Computationally cheap and fast to implement, making it suitable for real-time or resource-constrained environments. |
+| **AutoRound** | Optimizes rounding and clipping ranges via sign-gradient descent. | Delivers leading 4-bit and superior sub-4-bit accuracy compared to GPTQ/AWQ, with runtime faster than GPTQ and on par with AWQ. |

 For this guide, we'll use `GPTQ` composed with `SmoothQuant` to create an `INT W8A8` quantized model. This combination provides a good balance of performance, accuracy, and compatibility across a wide range of hardware.
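
A minimal sketch of the `GPTQ` composed with `SmoothQuant` flow referenced above, following the project's getting-started examples; the model ID, dataset, and calibration sizes are placeholders to adjust for your own run.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # illustrative model
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# SmoothQuant folds activation outliers into the weights, then GPTQ quantizes
# weights and activations to INT8 (W8A8).
recipe = [
    SmoothQuantModifier(smoothing_strength=0.8),
    GPTQModifier(targets="Linear", scheme="W8A8", ignore=["lm_head"]),
]

oneshot(
    model=model,
    dataset="open_platypus",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)

SAVE_DIR = MODEL_ID.split("/")[-1] + "-W8A8"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```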

examples/autoround/README.md

Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@

# `AutoRound` Quantization

`llm-compressor` supports [AutoRound](https://aclanthology.org/2024.findings-emnlp.662.pdf), an advanced quantization technique that delivers **high-accuracy**, **low-bit quantization**. The quantized results are fully compatible with `compressed-tensors` and can be served directly with vLLM.

AutoRound introduces three trainable parameters (V, α, and β) to optimize rounding values and clipping ranges during quantization. The method processes each decoder layer sequentially, using block-wise output reconstruction error as the training objective to fine-tune these parameters. This approach combines the efficiency of post-training quantization with the adaptability of parameter tuning, delivering robust compression for large language models while maintaining strong performance.
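
Schematically, and simplifying the paper's notation (this restatement is an editorial sketch, not text from the example):

$$
\tilde{W} = s \cdot \mathrm{clip}\!\left(\left\lfloor \tfrac{W}{s} + V \right\rceil,\ N_{\min},\ N_{\max}\right),
\qquad
\min_{V,\,\alpha,\,\beta}\ \bigl\lVert \mathcal{F}(W, X) - \mathcal{F}(\tilde{W}, X) \bigr\rVert_F^{2}
$$

where $\mathcal{F}(\cdot, X)$ is the output of the current decoder block on calibration inputs $X$, $V$ perturbs the rounding decision, and $\alpha$, $\beta$ rescale the clipping range (and therefore the scale $s$); all three are updated with signed gradient descent.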

## Installation

To get started, install:

```bash
git clone https://github.com/vllm-project/llm-compressor.git
cd llm-compressor
pip install -e .
```

## Quickstart

The example includes an end-to-end script for applying the AutoRound quantization algorithm.

```bash
python3 llama3_example.py
```

The resulting model, `Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound`, is ready to be loaded into vLLM.

## Code Walkthrough

Now, we will step through the code in the example. There are four steps:

1) Load model
2) Prepare calibration data
3) Apply quantization
4) Evaluate accuracy in vLLM

### 1) Load Model

Load the model using `AutoModelForCausalLM`, which handles quantized saving and loading.

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
```

### 2) Prepare Calibration Data

When quantizing model weights with AutoRound, you'll need a small set of sample data to run the algorithm. By default, we use [NeelNanda/pile-10k](https://huggingface.co/datasets/NeelNanda/pile-10k) as the calibration dataset.

Recommended starting points:

- 128 samples — typically sufficient for stable calibration (increase if accuracy degrades).
- 2048 sequence length — a good baseline for most LLMs.
- 200 tuning steps — usually enough to converge (increase if accuracy drops).

```python
# Select calibration dataset.
from auto_round.calib_dataset import get_dataset

NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048

# Get aligned calibration dataset.
ds = get_dataset(
    tokenizer=tokenizer,
    seqlen=MAX_SEQUENCE_LENGTH,
    nsamples=NUM_CALIBRATION_SAMPLES,
)
```

### 3) Apply Quantization

With the dataset ready, we will now apply AutoRound quantization to the model.

```python
from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier

# Configure the quantization algorithm to run.
recipe = AutoRoundModifier(
    targets="Linear", scheme="W4A16", ignore=["lm_head"], iters=200
)

# Apply quantization.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    # disable shuffling to get a slightly better MMLU score
    shuffle_calibration_samples=False,
)

# Save the compressed model to disk.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W4A16-G128-AutoRound"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

We have successfully created an `int4` model!

### 4) Evaluate Accuracy

With the model created, we can now load and run it in vLLM (after installing vLLM).

```python
from vllm import LLM

model = LLM("./Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound")
```

We can evaluate accuracy with `lm_eval` (`pip install lm-eval==0.4.9.1`):

> Note: quantized models can be sensitive to the presence of the `bos` token. `lm_eval` does not add a `bos` token by default, so make sure to include the `add_bos_token=True` argument when running your evaluations.

Run the following to test accuracy on GSM-8K:

```bash
lm_eval --model vllm \
  --model_args pretrained="./Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound",add_bos_token=true \
  --tasks gsm8k \
  --num_fewshot 5 \
  --limit 1000 \
  --batch_size 'auto'
```

The resulting scores look good:

```bash
| Tasks | Version | Filter           | n-shot | Metric      |     | Value |     | Stderr |
| ----- | ------: | ---------------- | -----: | ----------- | --- | ----: | --- | -----: |
| gsm8k |       3 | flexible-extract |      5 | exact_match |     | 0.737 |  ±  | 0.0139 |
|       |         | strict-match     |      5 | exact_match |     | 0.736 |  ±  | 0.0139 |
```

> Note: quantized model accuracy may vary slightly due to nondeterminism.

### Known Issues

Currently, `llm-compressor` supports applying AutoRound only to `wNa16` quantization schemes. Support for additional schemes is planned; you can follow progress in the [RFC](https://github.com/vllm-project/llm-compressor/issues/1968).

### Questions or Feature Requests?

Please open an issue on [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor) or [intel/auto-round](https://github.com/intel/auto-round).

setup.py

Lines changed: 1 addition & 2 deletions
@@ -144,8 +144,7 @@ def localversion_func(version: ScmVersion) -> str:
             if BUILD_TYPE == "release"
             else "compressed-tensors>=0.12.3a2"
         ),
-        # TODO: replace it with the release version
-        ("auto_round @ git+https://github.com/intel/auto-round.git@llmc"),
+        ("auto-round==0.9.1"),
     ],
     extras_require={
         "dev": [

src/llmcompressor/entrypoints/utils.py

Lines changed: 1 addition & 1 deletion
@@ -29,12 +29,12 @@
 from llmcompressor.pytorch.model_load.helpers import parse_dtype
 from llmcompressor.transformers.compression.compressed_tensors_utils import (
     modify_save_pretrained,
-    untie_word_embeddings,
 )
 from llmcompressor.transformers.utils.helpers import (
     is_model_ct_quantized_from_path,
 )
 from llmcompressor.typing import Processor
+from llmcompressor.utils import untie_word_embeddings
 from llmcompressor.utils.fsdp.helpers import is_fsdp_model

src/llmcompressor/modifiers/autoround/base.py

Lines changed: 8 additions & 11 deletions
@@ -20,10 +20,8 @@
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.quantization.calibration import apply_calibration_status
 from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
-from llmcompressor.transformers.compression.compressed_tensors_utils import (
-    untie_if_target_shared_embedding,
-)
-from llmcompressor.utils.pytorch.module import get_no_split_params
+from llmcompressor.utils import targets_embeddings, untie_word_embeddings
+from llmcompressor.utils.pytorch import get_no_split_params

 __all__ = ["AutoRoundModifier"]

@@ -111,9 +109,9 @@ class AutoRoundModifier(Modifier, QuantizationMixin):
     # AutoRound modifier arguments
     iters: int = 200
     enable_torch_compile: bool = True
+    batch_size: int = 8

     # private variables
-    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
     _all_module_input: Dict[str, List[Tuple]] = PrivateAttr(default_factory=dict)
     _q_input: Optional[torch.Tensor] = PrivateAttr(default=None)

@@ -128,10 +126,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         QuantizationMixin.initialize_quantization(self, state.model)

         # prepare module names
-        self._module_names = {
-            m: name
-            for name, m in match_named_modules(state.model, self.targets, self.ignore)
-        }
         self._add_temporary_names(state.model)
         # freeze all model parameters
         for _, param in state.model.named_parameters():

@@ -146,7 +140,9 @@ def start_calibration(self, model: torch.nn.Module):

         :param model: model to prepare for calibration
         """
-        untie_if_target_shared_embedding(model, self._module_names.values())
+        targets = match_named_modules(model, self.targets, self.ignore)
+        if targets_embeddings(model, targets):
+            untie_word_embeddings(model)

         for _, module in match_named_modules(model, self.targets, self.ignore):
             # Note: No need to register observers for auto-round

@@ -227,6 +223,7 @@ def apply_autoround(self, state, subgraph):
             scheme=ar_quant_scheme,
             iters=self.iters,
             enable_torch_compile=self.enable_torch_compile,
+            batch_size=self.batch_size,
         )
         # TODO: configure layer-wise config based on self.resolved_config
         ar.configure_layer_config(enable_gguf_official_mixed=False)

@@ -240,7 +237,7 @@ def apply_autoround(self, state, subgraph):
             block=decoding_layer,
             inputs=cur_inputs,
             q_input=self._q_input,
-            device=device,
+            device=str(device),
             # Leave offload for LLMC
             auto_offload=False,
         )
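
A small, hedged sketch of the new `batch_size` knob added to `AutoRoundModifier` above; the value and surrounding arguments are illustrative and mirror `examples/autoround/README.md`:

```python
from llmcompressor.modifiers.autoround import AutoRoundModifier

# batch_size (added in this diff) sets how many calibration samples auto-round
# processes per tuning step; larger values can speed up tuning at the cost of
# memory. The remaining arguments follow the AutoRound example README.
recipe = AutoRoundModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=["lm_head"],
    iters=200,
    batch_size=8,
)
```

The recipe is then passed to `oneshot(...)` exactly as in the example walkthrough earlier in this commit.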

src/llmcompressor/modifiers/awq/mappings.py

Lines changed: 1 addition & 0 deletions
@@ -166,6 +166,7 @@ class AWQMapping:
     "Llama4ForConditionalGeneration": _default_mappings,
     "Mistral3ForConditionalGeneration": _default_mappings,
     "MistralForCausalLM": _default_mappings,
+    "Olmo3ForCausalLM": _exaone4_mappings,
     "Phi3ForCausalLM": _phi_mappings,
     "Phi3VForCausalLM": _phi_mappings,
     "Qwen2ForCausalLM": _default_mappings,

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 4 additions & 8 deletions
@@ -34,9 +34,7 @@
     reset_quantization_status,
 )
 from llmcompressor.modifiers.utils.hooks import HooksMixin
-from llmcompressor.transformers.compression.compressed_tensors_utils import (
-    untie_if_target_shared_embedding,
-)
+from llmcompressor.utils import targets_embeddings, untie_word_embeddings

 __all__ = ["QuantizationMixin"]

@@ -184,11 +182,9 @@ def start_calibration(self, model: torch.nn.Module):

         :param model: model to prepare for calibration
         """
-
-        matched_module_generator = (
-            x[1] for x in match_named_modules(model, self.resolved_targets, self.ignore)
-        )
-        untie_if_target_shared_embedding(model, matched_module_generator)
+        targets = match_named_modules(model, self.resolved_targets, self.ignore)
+        if targets_embeddings(model, targets):
+            untie_word_embeddings(model)

         for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             self._initialize_observers(module)

src/llmcompressor/modifiers/smoothquant/base.py

Lines changed: 26 additions & 27 deletions
@@ -2,7 +2,7 @@
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
-from compressed_tensors.utils import align_module_device
+from compressed_tensors.utils import align_module_device, match_named_modules
 from loguru import logger
 from pydantic import ConfigDict, Field
 from torch.nn import Module

@@ -14,11 +14,7 @@
     handle_mapping_resolution_errors,
 )
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
-from llmcompressor.utils.pytorch.module import (
-    get_layers,
-    get_matching_layer,
-    match_targets,
-)
+from llmcompressor.utils.pytorch.module import get_layer_by_name

 MINIMUM_SMOOTHING_SCALE = 1e-5

@@ -196,31 +192,34 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]:
         Transforms the list of activations to smooth and their corresponding weights
         into SmoothQuantMapping objects, resolving regular expressions.

-        For each activation in the mapping list, we find the corresponding weight to
-        balance by searching for the longest substring. For instance, if our balance
-        weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
-        would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and
-        repeat for model.layer.1 and so on
+        For each activation in the mapping list, we find ALL corresponding weights to
+        balance by matching within the parent scope. This ensures all matching layers
+        are included, which is critical for MoE models where multiple experts need to
+        be balanced.
         """
         resolved_mappings = []
         for to_balance, to_smooth in self.mappings:
-            to_smooth_layers = get_layers(to_smooth, model)
-            for layer_name, smooth_layer in to_smooth_layers.items():
-                if not match_targets(layer_name, self.ignore)[0]:
-                    balance_layers = []
-                    for balance_suffix in to_balance:
-                        # find the submodule that matches the activation layer
-                        _, balance_layer = get_matching_layer(
-                            balance_suffix, layer_name, model
-                        )
-                        if balance_layer:
-                            balance_layers.append(balance_layer)
-                    # each mapping can contain multiple layers to balance, but only
-                    # one layer to smooth
-                    mapping = SmoothQuantMapping(
-                        layer_name, smooth_layer, balance_layers
-                    )
-                    resolved_mappings.append(mapping)
+            to_smooth_list = [to_smooth] if isinstance(to_smooth, str) else to_smooth
+
+            for smooth_name, smooth_layer in match_named_modules(
+                model, to_smooth_list, self.ignore
+            ):
+                # Search for balance layers within the parent scope
+                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
+                smooth_parent = get_layer_by_name(smooth_parent_name, model)
+
+                balance_layers = [
+                    balance_layer
+                    for _, balance_layer in match_named_modules(
+                        smooth_parent, to_balance, self.ignore
+                    )
+                ]
+
+                if balance_layers:
+                    resolved_mappings.append(
+                        SmoothQuantMapping(smooth_name, smooth_layer, balance_layers)
+                    )
+
         return resolved_mappings

     def _setup_scale_hooks(self):
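
A hedged sketch of why the parent-scope matching described in the new docstring matters for MoE models: a balance pattern that matches every expert's projection now resolves to all experts under the smoothed norm's parent layer, rather than to a single longest-substring match. The module names below are illustrative of a Mixtral-style layout and are not taken from this commit.

```python
from llmcompressor.modifiers.smoothquant import SmoothQuantModifier

# For each decoder layer, the post-attention norm is smoothed and ALL experts'
# w1/w3 projections in that layer are balanced, since match_named_modules is
# applied within the shared parent scope.
recipe = SmoothQuantModifier(
    smoothing_strength=0.8,
    mappings=[
        (
            [
                r"re:.*block_sparse_moe\.experts\.\d+\.w1$",
                r"re:.*block_sparse_moe\.experts\.\d+\.w3$",
            ],
            r"re:.*post_attention_layernorm$",
        ),
    ],
)
```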
