Merged
11 changes: 7 additions & 4 deletions docs/pytorch_converter/README.md
@@ -120,13 +120,11 @@ edge_model(*inputs_2, signature_name="input2")
 ## Quantization
 
 Following is the code snippet to quantize a model with [PT2E
-quantization](https://pytorch.org/tutorials/prototype/quantization_in_pytorch_2_0_export_tutorial.html)
+quantization](https://docs.pytorch.org/ao/stable/tutorials_source/pt2e_quant_ptq.html)
 using the `ai_edge_torch` backend.
 
 ```python
 from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
-from torch._export import capture_pre_autograd_graph
-
 from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
 from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
 from ai_edge_torch.quantize.quant_config import QuantConfig
@@ -135,7 +133,12 @@ pt2e_quantizer = PT2EQuantizer().set_global(
     get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
 )
 
-pt2e_torch_model = capture_pre_autograd_graph(torch_model, sample_args)
+# > For pytorch 2.6+
+pt2e_torch_model = torch.export.export(torch_model, sample_args).module()
+# > For pytorch 2.5 and before
+# from torch._export import capture_pre_autograd_graph
+# pt2e_torch_model = capture_pre_autograd_graph(torch_model, sample_args)
+
 pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)
 
 # Run the prepared model with sample input data to ensure that internal observers are populated with correct values
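The diff above gates the graph-capture step on the PyTorch version: `torch.export.export(...).module()` for 2.6+, and the older private `torch._export.capture_pre_autograd_graph` for 2.5 and earlier. As a minimal illustrative sketch of that selection logic (the `pick_capture_api` helper is hypothetical, not part of `ai_edge_torch` or PyTorch), the same branching can be written as a pure version check:

```python
def pick_capture_api(torch_version: str) -> str:
    """Return the dotted name of the graph-capture entry point for a given
    PyTorch version string. Hypothetical helper, for illustration only."""
    major, minor = (int(part) for part in torch_version.split(".")[:2])
    if (major, minor) >= (2, 6):
        # torch.export.export(...) returns an ExportedProgram; calling
        # .module() on it yields the captured nn.Module for prepare_pt2e.
        return "torch.export.export"
    # Older releases exposed the private pre-autograd capture API instead.
    return "torch._export.capture_pre_autograd_graph"

print(pick_capture_api("2.6.1"))  # torch.export.export
print(pick_capture_api("2.5.0"))  # torch._export.capture_pre_autograd_graph
```

In real code you would compare `torch.__version__` rather than a string literal, and call the chosen API directly as shown in the snippet above.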