Skip to content

Commit c77e654

Browse files
authored
add NVFP4 formal document (#2321)
1 parent 5864c7a commit c77e654

File tree

2 files changed

+109
-6
lines changed

2 files changed

+109
-6
lines changed

README.md

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,17 +126,22 @@ model = load(
126126
</thead>
127127
<tbody>
128128
<tr>
129-
<td colspan="2" align="center"><a href="./docs/source/3x/PyTorch.md">Overview</a></td>
130-
<td colspan="2" align="center"><a href="./docs/source/3x/PT_DynamicQuant.md">Dynamic Quantization</a></td>
131-
<td colspan="2" align="center"><a href="./docs/source/3x/PT_StaticQuant.md">Static Quantization</a></td>
129+
<td colspan="8" align="center"><a href="./docs/source/3x/PyTorch.md">Overview</a></td>
130+
</tr>
131+
<tr>
132+
<td colspan="3" align="center"><a href="./docs/source/3x/PT_DynamicQuant.md">Dynamic Quantization</a></td>
133+
<td colspan="3" align="center"><a href="./docs/source/3x/PT_StaticQuant.md">Static Quantization</a></td>
132134
<td colspan="2" align="center"><a href="./docs/source/3x/PT_SmoothQuant.md">Smooth Quantization</a></td>
133135
</tr>
134136
<tr>
135-
<td colspan="2" align="center"><a href="./docs/source/3x/PT_WeightOnlyQuant.md">Weight-Only Quantization</a></td>
136-
<td colspan="2" align="center"><a href="./docs/source/3x/PT_FP8Quant.md">FP8 Quantization</a></td>
137-
<td colspan="2" align="center"><a href="./docs/source/3x/PT_MXQuant.md">MX Quantization</a></td>
137+
<td colspan="3" align="center"><a href="./docs/source/3x/PT_WeightOnlyQuant.md">Weight-Only Quantization</a></td>
138+
<td colspan="3" align="center"><a href="./docs/source/3x/PT_FP8Quant.md">FP8 Quantization</a></td>
138139
<td colspan="2" align="center"><a href="./docs/source/3x/PT_MixedPrecision.md">Mixed Precision</a></td>
139140
</tr>
141+
<tr>
142+
<td colspan="4" align="center"><a href="./docs/source/3x/PT_MXQuant.md">MX Quantization</a></td>
143+
<td colspan="4" align="center"><a href="./docs/source/3x/PT_NVFP4Quant.md">NVFP4 Quantization</a></td>
144+
</tr>
140145
</tbody>
141146
<thead>
142147
<tr>

docs/source/3x/PT_NVFP4Quant.md

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
NVFP4 Quantization
2+
==================
3+
4+
1. [Introduction](#introduction)
5+
2. [Get Started with NVFP4 Quantization API](#get-started-with-nvfp4-quantization-api)
6+
3. [Reference](#reference)
7+
8+
## Introduction
9+
10+
Large language models (LLMs) have revolutionized fields such as natural language understanding, generation, and multimodal processing. As these models grow, their computational and memory requirements increase, making efficient deployment challenging. To address these issues, quantization methods are employed to reduce model size and accelerate inference with minimal loss in accuracy.
11+
12+
NVFP4 is a specialized 4-bit floating-point format (FP4) developed by NVIDIA for deep learning workloads. Compared to traditional INT8 or FP16 formats, NVFP4 offers further reductions in memory footprint and computational resource use, enabling efficient inference for LLMs and other neural networks on supported hardware.
13+
14+
The following table summarizes the NVFP4 quantization format:
15+
16+
<table>
17+
<tr>
18+
<th>Format Name</th>
19+
<th>Element Data type</th>
20+
<th>Element Bits</th>
21+
<th>Scaling Block Size</th>
22+
<th>Scale Data Type</th>
23+
<th>Scale Bits</th>
24+
<th>Global Tensor-Wise Scale Data Type</th>
25+
<th>Global Tensor-Wise Scale Bits</th>
26+
</tr>
27+
<tr>
28+
<td>NVFP4</td>
29+
<td>E2M1</td>
30+
<td>4</td>
31+
<td>16</td>
32+
<td>UE4M3</td>
33+
<td>8</td>
34+
<td>FP32</td>
35+
<td>32</td>
36+
</tr>
37+
</table>
38+
39+
> Note: UE4M3 is the same data type as normal FP8 E4M3, here UE4M3 is named to remind that the sign bit remains 0 and scale is always positive.
40+
41+
### Understanding the Scaling Mechanism
42+
43+
NVFP4 uses a two-level scaling approach to maintain accuracy while reducing precision:
44+
45+
- **Block-wise Scale**: The quantized tensor is divided into blocks of size 16 (the Scaling Block Size). Each block has its own scale factor stored in UE4M3 format (8 bits), which is used to convert the 4-bit E2M1 quantized values back to a higher precision representation. This fine-grained scaling helps preserve local variations in the data.
46+
47+
- **Global Tensor-Wise Scale**: In addition to the block-wise scales, a single FP32 (32-bit) scale factor is applied to the entire tensor. This global scale provides an additional level of normalization for the whole weight or activation tensor. For activations, this global scale is static (computed during calibration and fixed during inference) to optimize performance.
48+
49+
The dequantization formula can be expressed as:
50+
51+
$$\text{dequantized\_value} = \text{quantized\_value} \times \text{block\_scale} \times \text{global\_scale}$$
52+
53+
This hierarchical scaling strategy balances compression efficiency with numerical accuracy, enabling NVFP4 to maintain model performance while significantly reducing memory footprint.
54+
55+
At similar accuracy levels, NVFP4 can deliver lower memory usage and improved compute efficiency for multiply-accumulate operations compared to higher-precision formats. Neural Compressor supports post-training quantization to NVFP4, providing recipes and APIs for users to quantize LLMs easily.
56+
57+
## Get Started with NVFP4 Quantization API
58+
59+
To quantize a model to the NVFP4 format, use the AutoRound Quantization API as shown below.
60+
61+
```python
62+
from neural_compressor.torch.quantization import AutoRoundConfig, prepare, convert
63+
from transformers import AutoModelForCausalLM, AutoTokenizer
64+
65+
fp32_model = AutoModelForCausalLM.from_pretrained(
66+
"facebook/opt-125m",
67+
device_map="auto",
68+
)
69+
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", trust_remote_code=True)
70+
output_dir = "./saved_inc"
71+
72+
# quantization configuration
73+
quant_config = AutoRoundConfig(
74+
tokenizer=tokenizer,
75+
nsamples=32,
76+
seqlen=32,
77+
iters=20,
78+
scheme="NVFP4", # NVFP4 format
79+
export_format="llm_compressor",
80+
output_dir=output_dir, # default is "temp_auto_round"
81+
)
82+
83+
# quantize the model and save to output_dir
84+
model = prepare(model=fp32_model, quant_config=quant_config)
85+
model = convert(model)
86+
87+
# loading
88+
model = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype="auto", device_map="auto")
89+
90+
# inference
91+
text = "There is a girl who likes adventure,"
92+
inputs = tokenizer(text, return_tensors="pt").to(model.device)
93+
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0]))
94+
```
95+
96+
## Reference
97+
98+
[1]: NVIDIA, Introducing NVFP4 for efficient and accurate low-precision inference,NVIDIA Developer Blog, Jun. 2025. [Online]. Available: https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/

0 commit comments

Comments
 (0)