From 1c74b8b59ee539e191bc0bab845d1eba754672e3 Mon Sep 17 00:00:00 2001 From: realAsma <86726418+realAsma@users.noreply.github.com> Date: Tue, 25 Nov 2025 21:23:40 -0800 Subject: [PATCH 1/8] [1/N] Refactored AutoQuantizeSearcher to _AutoQuantizeBaseSearcher & AutoQuantizeGradientSearcher; seperated quant modules and score modules (#586) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What does this PR do? **Type of change:** Refactor; Minor new feature **Overview:** ? 1. Refactored AutoQuantizeSearcher to _AutoQuantizeBaseSearcher & AutoQuantizeGradientSearcher - Prepares architecture for additional search methods. 2. Separated quant modules and score modules - separate quantization modules from scoring modules, enabling auto-quantization to measure sensitivity at parent layers (e.g., MLP output for MoE experts) rather than individual ops. 3. Also see https://github.com/NVIDIA/TensorRT-Model-Optimizer/pull/592 and https://github.com/NVIDIA/TensorRT-Model-Optimizer/pull/588 ## Testing See unittests; `tests/unit/torch/quantization/test_autoquant.py` and `tests/unit/torch/quantization/plugins/test_huggingface.py` ## Before your PR is "*Ready for review*" - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes - **Did you write any new necessary tests?**: Yes - **Did you add or update any necessary documentation?**: Yes - **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Not Required ## Additional Information ## Summary by CodeRabbit * **New Features** * Added support for score modules in quantization workflows. * Added optional naming for quantization recipes. * **Bug Fixes** * Improved quantization grouping rules documentation with clearer configuration examples. 
* **Refactor** * Renamed quantization module parameters for improved clarity. * Enhanced quantization search architecture for better scalability. ✏️ Tip: You can customize this high-level summary in your review settings. --------- Signed-off-by: realAsma Co-authored-by: Asma Kuriparambil Thekkumpate Signed-off-by: inisis --- CHANGELOG.rst | 2 + examples/llm_eval/gen_model_answer.py | 35 +- examples/llm_eval/lm_eval_hf.py | 39 +- examples/llm_eval/mmlu.py | 6 + examples/llm_eval/quantization_utils.py | 65 +- examples/llm_ptq/hf_ptq.py | 73 +- .../llm_ptq/scripts/huggingface_example.sh | 22 + examples/llm_ptq/scripts/parser.sh | 8 +- modelopt/torch/opt/hparam.py | 10 +- modelopt/torch/quantization/algorithms.py | 1075 ++++++++++++----- modelopt/torch/quantization/model_quant.py | 78 +- .../torch/quantization/plugins/huggingface.py | 4 +- ...rt_weight.py => test_export_weight_gpu.py} | 0 .../quantization/plugins/test_huggingface.py | 28 +- .../unit/torch/quantization/test_autoquant.py | 66 +- 15 files changed, 1170 insertions(+), 341 deletions(-) rename tests/gpu/torch/export/{test_export_weight.py => test_export_weight_gpu.py} (100%) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 899b14009..beb01abf0 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,6 +14,8 @@ Model Optimizer Changelog (Linux) - Add MoE (e.g. Qwen3-30B-A3B, gpt-oss-20b) pruning support for ``num_moe_experts``, ``moe_ffn_hidden_size`` and ``moe_shared_expert_intermediate_size`` parameters in Minitron pruning (``mcore_minitron``). - Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md `_ for more details. - Add FP8/NVFP4 KV cache quantization support for Megatron Core models. +- Add KL Divergence loss based auto_quantize method. See `auto_quantize API docs `_ for more details. +- Add support for saving and resuming auto_quantize search state. 
This speeds up the auto_quantize process by skipping the score estimation step if the search state is provided. - Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow. - Add support for PyTorch Geometric quantization. - Add per tensor and per channel MSE calibrator support. diff --git a/examples/llm_eval/gen_model_answer.py b/examples/llm_eval/gen_model_answer.py index 86504db62..42a7eaac9 100644 --- a/examples/llm_eval/gen_model_answer.py +++ b/examples/llm_eval/gen_model_answer.py @@ -201,8 +201,11 @@ def get_model_answers( tokenizer, args.calib_batch_size, args.calib_size, - args.auto_quantize_bits, test_generated=False, + auto_quantize_bits=args.auto_quantize_bits, + auto_quantize_method=args.auto_quantize_method, + auto_quantize_score_size=args.auto_quantize_score_size, + auto_quantize_checkpoint=args.auto_quantize_checkpoint, ) for question in tqdm(questions): @@ -450,6 +453,36 @@ def reorg_answer_file(answer_file): "regular quantization without auto_quantize search will be applied." ), ) + parser.add_argument( + "--auto_quantize_method", + type=str, + default="gradient", + choices=["gradient", "kl_div"], + help=( + "Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method " + "(requires labels in dataset). 'kl_div' uses KL divergence between original and " + "quantized model outputs (no labels required). Default: 'gradient'" + ), + ) + parser.add_argument( + "--auto_quantize_score_size", + type=int, + default=128, + help=( + "Number of samples to use for auto_quantize scoring. Most of auto_quantize time is spent on " + "sensitivity score estimation, so reducing this speeds it up while only minimally affecting " + "final model accuracy compared to lowering --calib_size (the number of samples used for calibration)." 
+ ), + ) + parser.add_argument( + "--auto_quantize_checkpoint", + type=str, + default=None, + help=( + "Path to checkpoint file for saving/restoring auto_quantize search state " + "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified." + ), + ) parser.add_argument( "--trust_remote_code", help="Set trust_remote_code for Huggingface models and tokenizers", diff --git a/examples/llm_eval/lm_eval_hf.py b/examples/llm_eval/lm_eval_hf.py index e980a376e..31103ff86 100755 --- a/examples/llm_eval/lm_eval_hf.py +++ b/examples/llm_eval/lm_eval_hf.py @@ -53,6 +53,9 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | quant_cfg = arg_dict.pop("quant_cfg", None) auto_quantize_bits = arg_dict.pop("auto_quantize_bits", None) + auto_quantize_method = arg_dict.pop("auto_quantize_method", "gradient") + auto_quantize_score_size = arg_dict.pop("auto_quantize_score_size", 128) + auto_quantize_checkpoint = arg_dict.pop("auto_quantize_checkpoint", None) calib_batch_size = arg_dict.pop("calib_batch_size", None) calib_size = arg_dict.pop("calib_size", 512) compress = arg_dict.pop("compress", False) @@ -81,8 +84,11 @@ def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | batch_size=calib_batch_size, calib_size=calib_size, auto_quantize_bits=auto_quantize_bits, + auto_quantize_method=auto_quantize_method, + auto_quantize_score_size=auto_quantize_score_size, test_generated=False, compress=compress, + auto_quantize_checkpoint=auto_quantize_checkpoint, ) return model_obj @@ -101,6 +107,12 @@ def setup_parser_with_modelopt_args(): "comma-separated list of quantization quantization formats that will be searched by `auto_quantize`" ), ) + parser.add_argument( + "--calib_batch_size", type=int, help="Batch size for quantization calibration" + ) + parser.add_argument( + "--calib_size", type=int, help="Calibration size for quantization", default=512 + ) parser.add_argument( "--auto_quantize_bits", type=float, @@ 
-110,10 +122,30 @@ def setup_parser_with_modelopt_args(): ), ) parser.add_argument( - "--calib_batch_size", type=int, help="Batch size for quantization calibration" + "--auto_quantize_method", + type=str, + default="gradient", + choices=["gradient", "kl_div"], + help=( + "Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method " + "(requires labels in dataset). 'kl_div' uses KL divergence between original and " + "quantized model outputs (no labels required). Default: 'gradient'" + ), ) parser.add_argument( - "--calib_size", type=int, help="Calibration size for quantization", default=512 + "--auto_quantize_score_size", + type=int, + default=128, + help=( + "Number of samples to use for auto_quantize scoring. Most of auto_quantize time is spent on " + "sensitivity score estimation, so reducing this speeds it up while only minimally affecting " + "final model accuracy compared to lowering --calib_size (the number of samples used for calibration)." + ), + ) + parser.add_argument( + "--auto_quantize_checkpoint", + type=str, + help=("Path to checkpoint file for saving/restoring auto_quantize search state. 
"), ) parser.add_argument( "--compress", @@ -139,6 +171,9 @@ def setup_parser_with_modelopt_args(): { "quant_cfg": args.quant_cfg, "auto_quantize_bits": args.auto_quantize_bits, + "auto_quantize_method": args.auto_quantize_method, + "auto_quantize_score_size": args.auto_quantize_score_size, + "auto_quantize_checkpoint": args.auto_quantize_checkpoint, "calib_batch_size": args.calib_batch_size, "calib_size": args.calib_size, "compress": args.compress, diff --git a/examples/llm_eval/mmlu.py b/examples/llm_eval/mmlu.py index 6a2f70ce4..ca244052b 100755 --- a/examples/llm_eval/mmlu.py +++ b/examples/llm_eval/mmlu.py @@ -227,6 +227,9 @@ def main( batch_size: int = 0, calib_size: int = 512, dtype: str = "bfloat16", + auto_quantize_method: str = "gradient", + auto_quantize_score_size: int = 128, + auto_quantize_checkpoint: str | None = None, **kwargs, ): random.seed(RAND_SEED) @@ -281,6 +284,9 @@ def main( batch_size=batch_size, calib_size=calib_size, auto_quantize_bits=auto_quantize_bits, + auto_quantize_method=auto_quantize_method, + auto_quantize_score_size=auto_quantize_score_size, + auto_quantize_checkpoint=auto_quantize_checkpoint, ) for subject in tqdm(subjects): diff --git a/examples/llm_eval/quantization_utils.py b/examples/llm_eval/quantization_utils.py index 2f43c93e0..3df44115a 100644 --- a/examples/llm_eval/quantization_utils.py +++ b/examples/llm_eval/quantization_utils.py @@ -66,8 +66,11 @@ def _quantize_model_with_dataset( quant_cfg: str | list[str], calib_dataset, auto_quantize_bits=None, + auto_quantize_method="gradient", + auto_quantize_score_size=128, batch_size=1, compress=False, + auto_quantize_checkpoint=None, ): if hasattr(lm, "gpt2"): net = lm.gpt2 @@ -81,23 +84,42 @@ def _quantize_model_with_dataset( getattr(mtq, quant_fmt) for quant_fmt in quant_cfg if quant_fmt != "NONE" ] - def loss_func(output, data): - # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast` - # which contains the loss attribute. 
- return output.loss + # Configure forward_step and loss_func based on method + if auto_quantize_method == "gradient": + # For gradient-based method, return full output with loss + def forward_step(model, batch): + return model(**batch) + + def loss_func(output, data): + # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast` + # which contains the loss attribute. + return output.loss + elif auto_quantize_method == "kl_div": + # For KL divergence method, return only logits + def forward_step(model, batch): + return model(**batch).logits + + loss_func = None # KL divergence doesn't need a custom loss function + else: + raise ValueError( + f"Invalid auto_quantize_method: {auto_quantize_method}. " + "Must be 'gradient' or 'kl_div'" + ) net, _ = mtq.auto_quantize( net, constraints={"effective_bits": auto_quantize_bits}, quantization_formats=quant_cfg_for_search, data_loader=calib_dataset, - forward_step=lambda model, batch: model(**batch), + forward_step=forward_step, loss_func=loss_func, num_calib_steps=len(calib_dataset), - num_score_steps=min( - len(calib_dataset), 128 // batch_size - ), # Limit the number of score steps to avoid long calibration time + # Most time is spent on score estimation; fewer samples speed it up with little accuracy impact. + num_score_steps=min(len(calib_dataset), max(auto_quantize_score_size // batch_size, 1)), verbose=True, + method=auto_quantize_method, + # disabled_layers=["*lm_head*", "*mlp.gate.*"], + checkpoint=auto_quantize_checkpoint, ) else: mtq_cfg = CUSTOM_CONFIG.get(quant_cfg) # type: ignore [arg-type] @@ -141,10 +163,13 @@ def quantize_model( tokenizer, batch_size, calib_size, - auto_quantize_bits=None, data="cnn_dailymail", test_generated=True, compress=False, + auto_quantize_bits=None, + auto_quantize_method="gradient", + auto_quantize_score_size=128, + auto_quantize_checkpoint=None, ): """Quantizes the model with the provided calibration dataset. 
@@ -155,10 +180,14 @@ def quantize_model( tokenizer: the tokenizer. batch_size: the calibration batch size for each calibration inference run. calib_size: the total calibration dataset size. - auto_quantize_bits: The effective bits constraint for auto_quantize. data: the name of the calibration dataset. test_generated: If ``True``, test the generated text before and after quantization. compress: If ``True``, compress the model after quantization. + auto_quantize_bits: The effective bits constraint for auto_quantize. + auto_quantize_method: The method for auto_quantize ('gradient' or 'kl_div'). + auto_quantize_score_size: Number of samples used for auto_quantize scoring. + auto_quantize_checkpoint: Path to checkpoint file for saving/restoring auto_quantize search state + (sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified. """ if "AWQ" in quant_cfg: print( @@ -170,8 +199,10 @@ def quantize_model( if hasattr(model, "model"): device = model.model.device + is_gradient_based = auto_quantize_bits is not None and auto_quantize_method == "gradient" + if batch_size == 0: - if auto_quantize_bits is not None or torch.distributed.is_initialized(): + if is_gradient_based or torch.distributed.is_initialized(): raise ValueError("We dont support automatic batch size inference for this case.") net = model.gpt2 if hasattr(model, "gpt2") else model.model @@ -186,7 +217,7 @@ def quantize_model( batch_size=batch_size, num_samples=calib_size, device=device, - include_labels=auto_quantize_bits is not None, + include_labels=is_gradient_based, ) if test_generated: @@ -194,7 +225,15 @@ def quantize_model( generated_str_before_ptq = model.run(input_str) _quantize_model_with_dataset( - model, quant_cfg, calib_dataloader, auto_quantize_bits, batch_size, compress + model, + quant_cfg, + calib_dataloader, + auto_quantize_bits, + auto_quantize_method, + auto_quantize_score_size, + batch_size, + compress, + auto_quantize_checkpoint, ) if test_generated: diff --git 
a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 7bb8d0f28..f0bb56caa 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -95,7 +95,15 @@ def auto_quantize( - model, qformat, auto_quantize_bits, calib_dataloader, calibrate_loop, batch_size=1 + model, + qformat, + calib_dataloader, + calibrate_loop, + auto_quantize_bits, + batch_size=1, + auto_quantize_method="gradient", + auto_quantize_score_size=128, + auto_quantize_checkpoint=None, ): qformat_list = qformat.split(",") assert qformat_list, "No quantization formats provided" @@ -122,18 +130,34 @@ def loss_func(output, data): # which contains the loss attribute. return output.loss + if auto_quantize_method == "gradient": + # For gradient-based method, return full output with loss + def forward_step(model, batch): + return model(**batch) + elif auto_quantize_method == "kl_div": + # For KL divergence method, return only logits + def forward_step(model, batch): + return model(**batch).logits + else: + raise ValueError( + f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'" + ) + model, _ = mtq.auto_quantize( model, constraints={"effective_bits": auto_quantize_bits}, data_loader=calib_dataloader, - forward_step=lambda model, batch: model(**batch), - loss_func=loss_func, + forward_step=forward_step, + loss_func=loss_func, # Only used for gradient-based method # TRTLLM only support one quantization format or None (do not quantize, internally supported) quantization_formats=[QUANT_CFG_CHOICES[format] for format in qformat_list], num_calib_steps=len(calib_dataloader), - num_score_steps=len(calib_dataloader), + # AutoQuantize scoring is the costly phase; allow smaller sample counts than calibration. 
+ num_score_steps=min(len(calib_dataloader), max(auto_quantize_score_size // batch_size, 1)), verbose=True, disabled_layers=["*lm_head*"], + method=auto_quantize_method, + checkpoint=auto_quantize_checkpoint, ) # We need to explicitly calibrate for kv cache quantization @@ -191,10 +215,13 @@ def quantize_model(model, quant_cfg, args, calib_dataloader=None, calibration_on model = auto_quantize( model, args.qformat, - args.auto_quantize_bits, calib_dataloader, calibrate_loop, + args.auto_quantize_bits, args.batch_size, + args.auto_quantize_method, + args.auto_quantize_score_size, + args.auto_quantize_checkpoint, ) elif calibration_only: model = mtq.calibrate(model, quant_cfg["algorithm"], forward_loop=calibrate_loop) @@ -444,13 +471,17 @@ def main(args): assert tokenizer is not None and isinstance( tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast) ), "The PreTrainedTokenizer must be set" + # Labels are only needed for gradient-based auto_quantize + include_labels = ( + args.auto_quantize_bits is not None and args.auto_quantize_method == "gradient" + ) calib_dataloader = get_dataset_dataloader( dataset_name=args.dataset, tokenizer=tokenizer, batch_size=args.batch_size, num_samples=args.calib_size, device=device, - include_labels=args.auto_quantize_bits is not None, + include_labels=include_labels, ) quant_cfg = build_quant_cfg( @@ -803,6 +834,36 @@ def output_decode(generated_ids, input_shape): default=None, type=str, ) + parser.add_argument( + "--auto_quantize_method", + type=str, + default="gradient", + choices=["gradient", "kl_div"], + help=( + "Method for auto_quantize sensitivity analysis. 'gradient' uses gradient-based method " + "(requires labels in dataset). 'kl_div' uses KL divergence between original and " + "quantized model outputs (no labels required). Default: 'gradient'" + ), + ) + parser.add_argument( + "--auto_quantize_score_size", + type=int, + default=128, + help=( + "Number of samples to use for auto_quantize scoring. 
Most of auto_quantize time is spent on " + "sensitivity score estimation, so reducing this speeds it up while only minimally affecting " + "final model accuracy compared to lowering --calib_size (the number of samples used for calibration)." + ), + ) + parser.add_argument( + "--auto_quantize_checkpoint", + type=str, + default=None, + help=( + "Path to checkpoint file for saving/restoring auto_quantize search state " + "(sensitivity scores, costs, etc.). Only used when auto_quantize_bits is specified." + ), + ) args = parser.parse_args() diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index 7b7d6910e..043b690e5 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -93,6 +93,28 @@ fi if [ -n "$AUTO_QUANTIZE_BITS" ]; then PTQ_ARGS+=" --auto_quantize_bits=$AUTO_QUANTIZE_BITS " fi + +if [ -n "$AUTO_QUANTIZE_METHOD" ]; then + PTQ_ARGS+=" --auto_quantize_method=$AUTO_QUANTIZE_METHOD " +fi + +if [ -n "$AUTO_QUANTIZE_SCORE_SIZE" ]; then + PTQ_ARGS+=" --auto_quantize_score_size=$AUTO_QUANTIZE_SCORE_SIZE " +fi + +# Automatically generate auto_quantize checkpoint path if not provided +if [ -n "$AUTO_QUANTIZE_BITS" ] && [ -z "$AUTO_QUANTIZE_CHECKPOINT" ]; then + # Create a descriptive checkpoint name based on model and quantization settings + AQ_METHOD=${AUTO_QUANTIZE_METHOD:-gradient} + AUTO_QUANTIZE_CHECKPOINT="${ROOT_SAVE_PATH}/auto_quantize_checkpoints/${MODEL_NAME}_${AQ_METHOD}.pth" + mkdir -p $(dirname $AUTO_QUANTIZE_CHECKPOINT) + echo "Auto-generated auto_quantize checkpoint path: $AUTO_QUANTIZE_CHECKPOINT" +fi + +if [ -n "$AUTO_QUANTIZE_BITS" ]; then + PTQ_ARGS+=" --auto_quantize_checkpoint=$AUTO_QUANTIZE_CHECKPOINT " +fi + if [ -n "$CALIB_DATASET" ]; then PTQ_ARGS+=" --dataset=$CALIB_DATASET " fi diff --git a/examples/llm_ptq/scripts/parser.sh b/examples/llm_ptq/scripts/parser.sh index 7df601327..8db2fe131 100644 --- 
a/examples/llm_ptq/scripts/parser.sh +++ b/examples/llm_ptq/scripts/parser.sh @@ -36,7 +36,7 @@ parse_options() { USE_SEQ_DEVICE_MAP=false # Parse command-line options - ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:" -n "$0" -- "$@") + ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:" -n "$0" -- "$@") eval set -- "$ARGS" while true; do @@ -65,6 +65,9 @@ parse_options() { --low_memory_mode ) LOW_MEMORY_MODE=true; shift;; --calib_dataset ) CALIB_DATASET="$2"; shift 2;; --calib_seq ) CALIB_SEQ="$2"; shift 2;; + --auto_quantize_method ) AUTO_QUANTIZE_METHOD="$2"; shift 2;; + --auto_quantize_score_size ) AUTO_QUANTIZE_SCORE_SIZE="$2"; shift 2;; + --auto_quantize_checkpoint ) AUTO_QUANTIZE_CHECKPOINT="$2"; shift 2;; -- ) shift; break ;; * ) break ;; esac @@ -150,5 +153,8 @@ parse_options() { echo "low_memory_mode: $LOW_MEMORY_MODE" echo "calib_dataset: $CALIB_DATASET" echo "calib_seq: $CALIB_SEQ" + echo "auto_quantize_method: $AUTO_QUANTIZE_METHOD" + echo "auto_quantize_score_size: $AUTO_QUANTIZE_SCORE_SIZE" + echo "auto_quantize_checkpoint: $AUTO_QUANTIZE_CHECKPOINT" echo "=================" } diff --git a/modelopt/torch/opt/hparam.py b/modelopt/torch/opt/hparam.py index 13a9aef0d..60bde4c1e 100644 --- a/modelopt/torch/opt/hparam.py +++ b/modelopt/torch/opt/hparam.py @@ -48,7 +48,7 @@ def 
__eq__(self, other) -> bool: class Hparam: """A base hyperparameter of a DynamicModule. - An example of such a Hparam could be an hparam with identity dependencies. + Keeps track of hyperparameter values and their importance, which can be used for search algorithms. """ Importance = Union[torch.Tensor, None] # noqa: UP007 @@ -249,10 +249,14 @@ def _enforce_order(self, order: torch.Tensor | None = None) -> None: order = order.cpu() self._slice_order = order + @property + def attrs(self) -> list[str]: + """Return the attributes of the hparam for repr.""" + return ["choices", "active", "original"] + def __repr__(self) -> str: """Return string representation with relevant properties of the class.""" - attrs = ["choices", "active", "original"] - return f"{type(self).__name__}({', '.join(f'{x}={getattr(self, x)}' for x in attrs)})" + return f"{type(self).__name__}({', '.join(f'{x}={getattr(self, x)}' for x in self.attrs)})" def __iand__(self, hp: "Hparam"): """Merge another hparam into self.""" diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 3b99e40a9..16abc6ef4 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -19,6 +19,7 @@ import gc import types import warnings +from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Callable, Sequence from contextlib import nullcontext @@ -34,7 +35,7 @@ from modelopt.torch.opt.searcher import LPS, BaseSearcher, SearchConfig, SearchStateDict from modelopt.torch.opt.utils import get_hparam, named_hparams from modelopt.torch.utils import create_param_grad_clear_hook, print_rank_0, report_memory -from modelopt.torch.utils.distributed import DistributedProcessGroup, is_master +from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState, is_master from . import config as mtq_config from . 
import model_calib @@ -169,20 +170,29 @@ def fold_pqs_to_weights(model): class QuantRecipeHparam(Hparam): """An Hparam for quantization recipes. - In addition, this Hparam also: - 1. Keeps a link to its modules and sets the quantizers for the module based on the active recipe. - 2. Keeps track of the importance of each recipe in a dict instead of a tensor + See :class:`Hparam ` for more details. In addition, this Hparam also: + + * Keeps a link to its ``quant_modules`` and ``score_modules`` and sets the quantizers for the + ``quant_modules`` based on the active recipe. + * Provides ``get_score()`` and ``get_cost()`` methods to evaluate recipes. + * Registers itself with each ``score_module`` via the ``_hparams_for_scoring`` attribute. """ def __init__( self, choices: Sequence[QuantRecipe] | None = None, - nn_modules: list[nn.Module] | None = None, + quant_modules: list[nn.Module] | None = None, + score_modules: list[nn.Module] | None = None, + name: str | None = None, ) -> None: """Initializes Hparam with original value and choices.""" choices = sorted({*(choices if choices else []), QuantRecipe(quant_cfg=None)}) super().__init__(choices, original=choices[0]) - self.nn_modules = nn_modules if nn_modules else [] + + self.name = name + + self.quant_modules = list(set(quant_modules or [])) + self.score_modules = list(set(score_modules or self.quant_modules)) # This is a hack; We dont want to make the input_quantizer, weight_quantizer, output_quantizer # a dynamic attribute for backward compatibility with the model_calib.py @@ -191,17 +201,17 @@ def __init__( quant_recipe: QuantRecipe for quant_recipe in self.choices: - for nn_module in self.nn_modules: + for quant_module in self.quant_modules: for quantizer_attr_name in [ "input_quantizer", "weight_quantizer", "output_quantizer", ]: - setattr(nn_module, quantizer_attr_name, TensorQuantizer()) + setattr(quant_module, quantizer_attr_name, TensorQuantizer()) - set_quantizer_by_cfg(nn_module, 
quant_recipe.config.quant_cfg) - self._all_quantizer_choices[quant_recipe][nn_module] = { - quantizer_attr_name: getattr(nn_module, quantizer_attr_name) + set_quantizer_by_cfg(quant_module, quant_recipe.config.quant_cfg) + self._all_quantizer_choices[quant_recipe][quant_module] = { + quantizer_attr_name: getattr(quant_module, quantizer_attr_name) for quantizer_attr_name in [ "input_quantizer", "weight_quantizer", @@ -211,14 +221,17 @@ def __init__( self.active = self.original + # Importance dict is keyed by score_module (where the score is computed) self._importance_dict = { - quant_recipe: { - mod: torch.zeros((), device=mod.weight.device, dtype=torch.float32) - for mod in self.nn_modules - } - for quant_recipe in self.choices + quant_recipe: dict.fromkeys(self.score_modules) for quant_recipe in self.choices } + # Attach this hparam to each score_module's set of hparams it scores + for score_module in self.score_modules: + if not hasattr(score_module, "_hparams_for_scoring"): + score_module._hparams_for_scoring = set() + score_module._hparams_for_scoring.add(self) + @property def active(self) -> HPType: """Return the currently active value.""" @@ -240,53 +253,103 @@ def active(self, val: HPType | None): @property def importance(self) -> dict: - """Return the importance dict mapping recipe and importance.""" - return { - quant_recipe: sum(v.cpu().item() for v in importance_dict.values()) - for quant_recipe, importance_dict in self._importance_dict.items() - } + """Raises an error since this is not a useful abstraction for AutoQuantize.""" + raise NotImplementedError + + def get_score(self, recipe: QuantRecipe) -> float: + """Get the score for a given recipe.""" + total_score = 0 + for score_module in self.score_modules: + importance = self._importance_dict[recipe][score_module] + if importance is None: + continue + parallel_state = getattr(score_module, "parallel_state", None) -def _add_auto_quantize_score(grad_output, output_diff, score_tensor): - score_tensor += 
((grad_output.float() ** 2) * (output_diff.float() ** 2)).sum() + if parallel_state is None: + total_score += importance.cpu().item() + continue + if parallel_state.expert_model_parallel_group.is_initialized(): + # TODO: Support expert model parallelism for score estimation + warnings.warn("AutoQuantize does not support expert model parallelism yet.") + importance = DistributedProcessGroup.get_dist_syncd_obj( + importance, + [parallel_state.tensor_parallel_group, parallel_state.data_parallel_group], + sum, + ) + total_score += importance.cpu().item() + return total_score -class AutoQuantizeSearcher(BaseSearcher): - """A searcher for AutoQuantize algorithm. + def get_cost(self, recipe: QuantRecipe) -> float: + """Get the cost for a given recipe. - In AutoQuantize, we search for the best per-layer quantization configuration that minimizes the sum of per-layer - scores while meeting the specified constraint. AutoQuantize uses Linear Programming Solver to find the - optimal quantization configuration. + The cost is the total weight size of the quantizable modules multiplied by + the compression ratio of the recipe. + """ + cost = 0 + for quant_module in self.quant_modules: + weight_size = _AutoQuantizeBaseSearcher._get_total_weight_size([quant_module]) + parallel_state = getattr(quant_module, "parallel_state", None) - The auto_quantize score for a layer quantization configuration is an approximation of model loss change change due - to quantizing the particular layer with the particular configuration. - The approximation is based on taylor expansion of the loss function wrt to the quantized output of the layer and - substitution of Fisher information for Hessian. - This approximation is mathematically correct for models where the loss - is a log likelihood loss such as BERT, GPT, etc. However, the auto_quantize score can still be used as a proxy - for other models such as ResNet. 
- """ + if parallel_state is None: + cost += weight_size * recipe.compression + continue + + if parallel_state.expert_model_parallel_group.is_initialized(): + # TODO: Support expert model parallelism + warnings.warn("AutoQuantize does not support expert model parallelism yet.") + + weight_size = DistributedProcessGroup.get_dist_syncd_obj( + weight_size, + [parallel_state.tensor_parallel_group], + sum, + ) + + # Across data parallel groups, the weight size is the same for all the ranks. + weight_size = DistributedProcessGroup.get_dist_syncd_obj( + weight_size, + [parallel_state.data_parallel_group], + lambda a: a[0], + ) + cost += weight_size * recipe.compression + + return cost + + @property + def attrs(self) -> list[str]: + """Return the attributes of the hparam for repr.""" + return ["name", *super().attrs] + + +class _AutoQuantizeBaseSearcher(BaseSearcher, ABC): + """Base searcher for AutoQuantize algorithm.""" + + # This searcher finds optimal per-layer quantization by searching across quantization formats + # for each quantizable module (quant module). Optionally, quant grouping rules can restrict + # certain modules to share the same format. Sensitivity scores are computed from perturbations + # at score modules. See AutoQuantizeGradientSearcher for detailed documentation. 
candidate_stats: dict[str, dict[str, list[float]]] best: dict[str, Any] - custom_support: list[tuple[Callable, Callable, Callable]] = [] - rules = [ + quant_grouping_rules = [ r"^(.*?)\.(q_proj|k_proj|v_proj)$", # q_proj, k_proj, v_proj for llama like models + # gate_proj, up_proj, down_proj for Qwen3 like MoE models + r"^(.*?\.mlp\.experts)\.\d+\.(gate_proj|up_proj|down_proj)$", r"^(.*?)\.(gate_proj|up_proj)$", # gate_proj, up_proj for llama like models r"^(.*?)\.(\d+\.(w1|w2|w3))$", # mixtral experts r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$", # dbrx experts ] + score_module_rules = [] + @property def default_search_config(self): """Get the default config for the searcher.""" return { "quantization_formats": ["NVFP4_DEFAULT_CFG", "FP8_DEFAULT_CFG"], "data_loader": None, - "forward_step": None, - "loss_func": None, - "forward_backward_step": None, "num_calib_steps": 512, "num_score_steps": 128, "deployment": None, @@ -301,15 +364,11 @@ def default_state_dict(self) -> SearchStateDict: return { "candidate_stats": defaultdict(dict), "best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False}, - "constraints": {}, } def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: """Sanitize the search config dict.""" config = config or {} - if "score_func" in config: - warnings.warn("`score_func` is ignored for `auto_quantize`.") - config.pop("score_func") config = super().sanitize_search_config(config) assert config["data_loader"] is not None, ( "`data_loader` must be provided for `auto_quantize`." @@ -317,13 +376,6 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: assert config["forward_step"] is not None, ( "`forward_step` must be provided for `auto_quantize`." ) - - if config["forward_backward_step"] is None: - assert config["loss_func"] is not None, ( - "`loss_func` or `forward_backward_step` must be provided for `auto_quantize`." 
- ) - config["forward_backward_step"] = self._get_default_forward_backward_step() - return config @staticmethod @@ -343,148 +395,72 @@ def _get_search_recipes(quantization_formats): } ) - @classmethod - def register_custom_support( - cls, - is_supported_checker: Callable, - grad_ckpt_context: Callable, - is_param_grad_enabled: Callable, - ): - """Register custom support for `AutoQuantize` score estimation. + def _apply_quant_group_rule(self, name: str, rule) -> str | None: + """Apply a single quant_group_rule to a module name. - If the `is_supported_checker(model)` returns True, the `grad_ckpt_context(model)` will be - used to enable gradient checkpointing and `is_param_grad_enabled(pname, model)` - will be used to enable gradient for the parameter. - """ - cls.custom_support.append((is_supported_checker, grad_ckpt_context, is_param_grad_enabled)) + Args: + name: Module name + rule: Either a regex pattern string or a callable that returns a unique key; + If callable, it should take the model and the name as input and return the unique key - def _get_default_forward_backward_step(self): - def forward_backward_step(model, data): - output = self.config["forward_step"](model, data) - loss = self.config["loss_func"](output, data) - try: - loss.backward() - except RuntimeError as e: - raise RuntimeError( - "AutoQuantize: Error while calling `backward()` on the loss returned by `loss_func`. " - "Please fix this!" - ) from e - - return forward_backward_step - - @torch.enable_grad() - def _estimate_auto_quantize_scores(self, is_param_grad_enabled): - # TODO: remove the no-quant recipe - def auto_quantize_score_estimate_forward(module, input, *args, **kwargs): - module.quant_recipe = QuantRecipe(quant_cfg=None) - output = module._forward_original(input, *args, **kwargs) - - # If gradient checkpointing is enabled, gradient will not be enabled in the global forward pass. 
- # With gradient checkpointing, gradients are computed in the local forward pass during backward pass - - # Lets compute the output_diff and save it in memory only if gradient is enabled to be memory efficient - if not torch.is_grad_enabled(): - return output - - module.output_diff_dict = {} - with torch.no_grad(): - for recipe in module.get_hparam("quant_recipe").choices: - if recipe == QuantRecipe(quant_cfg=None): - continue - module.quant_recipe = recipe - output_diff = module._forward_original(input, *args, **kwargs) - - if isinstance(output_diff, tuple): - output_diff = output_diff[0] - output[0] - else: - output_diff -= output - module.output_diff_dict[recipe] = output_diff - - return output - - def backward_hook(module, grad_input, grad_output): - for recipe, output_diff in module.output_diff_dict.items(): - score_tensor = module.get_hparam("quant_recipe")._importance_dict[recipe][module] - _add_auto_quantize_score(grad_output[0], output_diff, score_tensor) - - del module.output_diff_dict + Returns: + The group key if the rule matches, None otherwise + """ + if callable(rule): + return rule(self.model, name) + else: + # Regex pattern + pattern = re.compile(rule) + match = pattern.match(name) + if match: + return match.group(1) + return None - def setup_params_for_score_estimation(name, param, params_metadata, enable_grad=True): - # Let us delete the gradient as soon as they are computed to save memory - # In addition, this method enables gradient for all parameters - # This is needed to make sure the re-entrant activation checkpointing works - params_metadata[name] = {"requires_grad": param.requires_grad} - param.requires_grad = enable_grad - if not enable_grad: - return - if self.config.get("verbose", False): - print_rank_0(f"AutoQuantize: Enabling gradient for param {name}.") - accum_grad, handle = create_param_grad_clear_hook(param) - params_metadata[name]["accum_grad"] = accum_grad # We need to keep the accum_grad alive - params_metadata[name]["handle"] 
= handle + def _apply_score_group_rule(self, name: str, rule) -> str | None: + """Apply a single score_group_rule to a module name. - def setup_module_for_score_estimation(module): - module._forward_original = module.forward - module.forward = types.MethodType(auto_quantize_score_estimate_forward, module) - module._backward_hook_handle = module.register_full_backward_hook(backward_hook) + Args: + name: Module name + rule: Either a regex pattern string or a callable that returns the score module name. + If callable, it should take the model and the name as input and return the score module name - def cleanup_module_after_score_estimation(module): - module.forward = module._forward_original - del module._forward_original + Returns: + The score module name if the rule matches, None otherwise + """ + if callable(rule): + return rule(self.model, name) + else: + # Regex pattern - return the matched name or full match + pattern = re.compile(rule) + match = pattern.match(name) + if match: + # For score rules, return the full match or first group + return match.group(0) if match.lastindex is None else match.group(1) + return None - module._backward_hook_handle.remove() + def _get_score_module_from_name( + self, model: nn.Module, score_module_name: str, quant_module: nn.Module + ) -> nn.Module: + """Get the actual score module object from its name. 
- def cleanup_params_after_score_estimation(name, param, params_metadata): - param.requires_grad = params_metadata[name]["requires_grad"] - handle = params_metadata[name].get("handle", None) - if handle is not None: - handle.remove() + Args: + model: The model containing all modules + score_module_name: The name of the score module to retrieve + quant_module: The quantized module for which the score is estimated - for name, module in self.model.named_modules(): - if ( - self._is_auto_quantize_module(module) - and module.get_hparam("quant_recipe").is_configurable - ): - # Monkey patch the forward methods to cache Y(Q(W), Q(X)) - Y(W,X) - setup_module_for_score_estimation(module) - - params_metadata = {} - - for name, param in self.model.named_parameters(): - setup_params_for_score_estimation( - name, param, params_metadata, is_param_grad_enabled(name, self.model) + Returns: + The score module object, or the quantized module itself if the score module is not found + """ + try: + score_module = model.get_submodule(score_module_name) + return score_module + except AttributeError: + warnings.warn( + f"Score module '{score_module_name}' not found. Score will estimated from the quantized module itself." 
) + return quant_module - gc.collect() - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - report_memory("AutoQuantize: starting score estimation, ") - - self._run_func( - self.config["forward_backward_step"], - num_iters=self.config["num_score_steps"], - desc="Estimating auto_quantize scores", - ) - - if torch.cuda.is_available(): - report_memory("AutoQuantize: After score estimation") - - for name, module in self.model.named_modules(): - if ( - self._is_auto_quantize_module(module) - and module.get_hparam("quant_recipe").is_configurable - ): - cleanup_module_after_score_estimation(module) - - for name, param in self.model.named_parameters(): - cleanup_params_after_score_estimation(name, param, params_metadata) - - # Delete the params_metadata - del params_metadata - gc.collect() - - @classmethod - def insert_hparams_after_merge_rules(cls, model, quant_recipes, disabled_layers=None): + def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers=None): """Restrict the search space using the merge rules and insert the hparams for the model.""" # TRTLLM fuses linear layers such as q_proj, k_proj, v_proj into same layer # Hence we need to restrict the search space so that all these layers share the same recipe @@ -495,9 +471,11 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes, disabled_layers= elif isinstance(disabled_layers, str): disabled_layers = [disabled_layers] - search_map: dict[str, list[tuple[nn.Module, bool]]] = {} + # Map from group key to list of (quant_module, name, disabled, score_module) + search_map: dict[str, list[tuple[nn.Module, str, bool, nn.Module]]] = {} + for name, module in model.named_modules(): - if not cls._is_auto_quantize_module(module): + if not self._is_auto_quantize_module(module): continue # Skip layers that match disabled_layers patterns @@ -507,28 +485,46 @@ def insert_hparams_after_merge_rules(cls, model, quant_recipes, disabled_layers= disabled = True break - prefix = name 
- for rule in cls.rules: - pattern = re.compile(rule) - match = pattern.match(name) - if match: - prefix = match.group(1) + # Apply quant_grouping_rules to determine the group key + group_key = name # Default: each module in its own group + for rule in self.quant_grouping_rules: + result = self._apply_quant_group_rule(name, rule) + if result is not None: + group_key = result # We support only one rule for matching per module break - if prefix not in search_map: - search_map[prefix] = [(module, disabled)] + + # Apply score_module_rules to determine the score module name, then get the actual module + score_module_name = name # Default: score from same module + for rule in self.score_module_rules: + result = self._apply_score_group_rule(name, rule) + if result is not None: + score_module_name = result + # We support only one rule for matching per module + break + + # Get the actual score module object immediately + score_module = self._get_score_module_from_name(model, score_module_name, module) + + if group_key not in search_map: + search_map[group_key] = [(module, name, disabled, score_module)] else: - search_map[prefix].append((module, disabled)) - - for prefix, module_info_list in search_map.items(): - modules = [module for module, _ in module_info_list] - disabled = any(disabled for _, disabled in module_info_list) - hparam = ( - QuantRecipeHparam(None, nn_modules=modules) - if disabled - else QuantRecipeHparam(quant_recipes, nn_modules=modules) + search_map[group_key].append((module, name, disabled, score_module)) + + for group_key, module_info_list in search_map.items(): + quant_modules = [module for module, _, _, _ in module_info_list] + disabled = any(disabled for _, _, disabled, _ in module_info_list) + score_modules = [score_module for _, _, _, score_module in module_info_list] + + quant_recipes = None if disabled else quant_recipes + hparam = QuantRecipeHparam( + quant_recipes, + quant_modules=quant_modules, + score_modules=score_modules, + 
name=str(group_key), ) - for module in modules: + + for module in quant_modules: module._register_hparam("quant_recipe", hparam) def _get_formatted_weight_compression_constraint(self): @@ -547,6 +543,33 @@ def _verify_constraint(self, search_recipes): f"{search_recipes[0]} whose num_bits = {search_recipes[0].num_bits}." ) + @abstractmethod + def estimate_sensitivity_scores(self) -> None: + """Estimate sensitivity scores and track them with Hparam.""" + + def initialize_candidate_stats(self): + """Initialize the candidate stats for the model.""" + for name, hparam in named_hparams(self.model, unique=True): + if not isinstance(hparam, QuantRecipeHparam): + continue + + formats, scores, costs = [], [], [] + prev_score = float("inf") + for recipe in hparam.choices: + formats.append(recipe) + + score = hparam.get_score(recipe) # type: ignore [arg-type] + cost = hparam.get_cost(recipe) # type: ignore [arg-type] + + score = min(score, prev_score) # TODO: Should we get rid of this? + scores.append(score) + costs.append(cost) + prev_score = score + + self.candidate_stats[name]["formats"] = formats + self.candidate_stats[name]["scores"] = scores + self.candidate_stats[name]["costs"] = costs + def _run_func(self, func, num_iters=1, desc=""): for i, data in tqdm( zip(range(num_iters), self.config["data_loader"]), @@ -572,7 +595,7 @@ def before_search(self): # Iterate over the search recipes and calibrate the quantizers for each recipe for recipe in search_recipes: - if recipe.compression >= 1.0: + if recipe == QuantRecipe(quant_cfg=None): # No-quant format continue # Lets reduce the number of calibration steps for AWQ since it takes longer @@ -605,6 +628,330 @@ def forward_loop(model): # TODO: This is a hack. We need to create a mode for auto_quantize to handle this in a clean way. 
ModeloptStateManager(self.model).state_dict().pop() + if self.candidate_stats: + if self.config["verbose"]: + print_rank_0("AutoQuantize: Restored from checkpoint, skipping scoring") + return + + self.estimate_sensitivity_scores() + self.initialize_candidate_stats() + # Save checkpoint after successful score estimation + self.save_search_checkpoint(verbose=self.config["verbose"]) + + @staticmethod + def _get_total_weight_size(modules): + return sum( + ( + module.weight.numel() + if _AutoQuantizeBaseSearcher._is_auto_quantize_module(module) + else 0 + ) + for module in modules + ) + + def _get_constraints_for_search(self, max_weight_size, lower_bound=None): + constraints = { + "weight_size_after_compression": ( + lower_bound * max_weight_size if lower_bound else lower_bound, + max_weight_size, + ) + } + return constraints, "weight_size_after_compression" + + @abstractmethod + def run_search_with_stats(self, max_weight_size, verbose=False): + """Run the search with stats to get the best recipe and whether the constraints are satisfied.""" + + def run_search(self): + """Search for the best per-layer quantization configuration and return the best model and configuration.""" + verbose = self.config["verbose"] + assert len(self.constraints) == 1 and "effective_bits" in self.constraints, ( + f"`constraints` must contain only 'effective_bits' constraint. 
" + f"Got {self.constraints.keys()}" + ) + + compression = self._get_formatted_weight_compression_constraint() + total_weight_size = self._get_total_weight_size(self.model.modules()) + max_weight_size = total_weight_size * compression + + # Run the search with stats to get the best recipe and whether the constraints are satisfied + best_recipe_info, is_satisfied = self.run_search_with_stats(max_weight_size, verbose) + self.best["is_satisfied"] = is_satisfied + + best_recipe = {} + best_constraints, best_scores = 0, 0 + for name, best_hparam_recipe_info in best_recipe_info.items(): + # Solvers could give different solutions for the same layer across DP/TP groups even though + # the scores and costs are the same. Lets make sure the same recipe is selected across DP/TP + _ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state + best_format = DistributedProcessGroup.get_dist_syncd_obj( + best_hparam_recipe_info["format"], + [_ps.data_parallel_group, _ps.tensor_parallel_group], + lambda a: a[0], + ) + + best_recipe[name] = best_format + get_hparam(self.model, name).active = best_format + best_constraints += best_hparam_recipe_info["costs"] + best_scores += best_hparam_recipe_info["scores"] + if verbose: + print_rank_0( + f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}" + ) + + effective_bits_from_search = (best_constraints / total_weight_size) * 16 + if verbose: + print_rank_0( + f"AutoQuantize effective bits from search: {effective_bits_from_search: .2f}" + ) + + self.best["recipe"] = best_recipe + self.best["constraints"] = {"effective_bits": effective_bits_from_search} + self.best["score"] = best_scores + + QuantRecipe.fold_pqs_to_weights(self.model) + + +def _get_auto_quantize_score(grad_output, output_diff): + return ((grad_output.float() ** 2) * (output_diff.float() ** 2)).sum() + + +def _add_auto_quantize_score(grad_output, output_diff, score_tensor): + score_tensor += 
_get_auto_quantize_score(grad_output, output_diff) + + +class AutoQuantizeGradientSearcher(_AutoQuantizeBaseSearcher): + """A searcher for AutoQuantize algorithm that uses gradient based score estimation. + + In AutoQuantize, we search for the best per-layer quantization configuration that minimizes the sum of per-layer + scores while meeting the specified constraint. AutoQuantize uses Linear Programming Solver to find the + optimal quantization configuration. + + The auto_quantize score for a layer quantization configuration is an approximation of model loss change due + to quantizing the particular layer with the particular configuration. + The approximation is based on taylor expansion of the loss function wrt to the quantized output of the layer and + substitution of Fisher information for Hessian. + This approximation is mathematically correct for models where the loss + is a log likelihood loss such as BERT, GPT, etc. However, the auto_quantize score can still be used as a proxy + for other models such as ResNet. + + **Quant Modules:** + + This searcher operates on quantizable modules (quant modules), which are typically Linear or Conv layers + that support quantization. Optionally, grouping rules can be applied to ensure certain layers share the same + quantization format (e.g., Q, K, V projections in the same attention layer). For details on quant_grouping_rules + and customization, see the :meth:`auto_quantize ` + API documentation. + + **Score Modules:** + + By default, for each quant module, its sensitivity score is estimated using that module's output perturbation. + However, the sensitivity can also be estimated by looking at perturbation at a separate point in the neural + network (score module). This is helpful in some cases such as MoEs for speed and lower memory consumption. 
+ Since all experts are already restricted to the same quant format by quant grouping rules, their sensitivity + can be estimated together at a single point (e.g., the MLP output level). + """ + + score_module_rules = [ + # Use MLP layer output for gate_proj, up_proj, down_proj for Qwen3 like MoE models (local and shared experts) + r"^(.*?\.mlp\.experts)\.\d+\.(gate_proj|up_proj|down_proj)$", + r"^(.*?)\.(\d+\.(w1|w2|w3))$", # mixtral experts + r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$", # dbrx experts + ] + + # See `register_custom_support` for details + _custom_support: list[tuple[Callable, Callable, Callable]] = [] + + @property + def default_search_config(self): + """Get the default config for the searcher.""" + config = super().default_search_config + config.update( + { + "forward_step": None, + "loss_func": None, + "forward_backward_step": None, + } + ) + return config + + def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: + """Sanitize the search config dict.""" + config = config or {} + if "score_func" in config: + warnings.warn("`score_func` is ignored for gradient based `auto_quantize`.") + config.pop("score_func") + config = super().sanitize_search_config(config) + if config["forward_backward_step"] is None: + assert config["loss_func"] is not None, ( + "`loss_func` or `forward_backward_step` must be provided for `auto_quantize`." + ) + config["forward_backward_step"] = self._get_default_forward_backward_step() + + return config + + @classmethod + def register_custom_support( + cls, + is_supported_checker: Callable, + grad_ckpt_context: Callable, + is_param_grad_enabled: Callable, + ) -> None: + """(Optional) Register custom support for `AutoQuantize` score estimation. + + This custom support is used to enable memory/compute efficient backward gradient propagation. 
This involves: + + - `grad_ckpt_context`: backward pass with gradient checkpointing enabled + - `is_param_grad_enabled`: AutoQuantize only needs activation gradients to be computed (not weight + gradients). `is_param_grad_enabled` is used to select which parameters should have gradients enabled, + limiting gradient computation to only what's needed for activation gradients. For LLMs, to trigger all + activation gradient computation, just enabling the embedding layer weight gradient is sufficient. This will + enable gradient computation for all the activation gradients downstream. + + If the `is_supported_checker(model)` returns True, the `grad_ckpt_context(model)` will be + used to enable gradient checkpointing and `is_param_grad_enabled(pname, model)` + will be used to select which parameters have gradients enabled to minimize gradient computation. + """ + cls._custom_support.append((is_supported_checker, grad_ckpt_context, is_param_grad_enabled)) + + def _get_default_forward_backward_step(self): + def forward_backward_step(model, data): + output = self.config["forward_step"](model, data) + loss = self.config["loss_func"](output, data) + try: + loss.backward() + except RuntimeError as e: + raise RuntimeError( + "AutoQuantize: Error while calling `backward()` on the loss returned by `loss_func`. " + "Please fix this!" + f"error: {e}" + ) from e + + return forward_backward_step + + @torch.enable_grad() + def _estimate_auto_quantize_scores(self, is_param_grad_enabled): + # TODO: remove the no-quant recipe + def auto_quantize_score_estimate_forward(module, input, *args, **kwargs): + for hparam in module._hparams_for_scoring: + if hparam.is_configurable: + hparam.active = QuantRecipe(quant_cfg=None) + + output = module._forward_original(input, *args, **kwargs) + + # If gradient checkpointing is enabled, gradient will not be enabled in the global forward pass. 
+ # With gradient checkpointing, gradients are computed in the local forward pass during backward pass + + # Lets compute the output_diff and save it in memory only if gradient is enabled to be memory efficient + if not torch.is_grad_enabled(): + return output + + module.output_diff_dict = {hparam: {} for hparam in module._hparams_for_scoring} + with torch.no_grad(): + for hparam in module._hparams_for_scoring: + if not hparam.is_configurable: + continue + for recipe in hparam.choices: + if recipe == QuantRecipe(quant_cfg=None): + continue + hparam.active = recipe + output_diff = module._forward_original(input, *args, **kwargs) + + if isinstance(output_diff, tuple): + output_diff = output_diff[0] - output[0] + else: + output_diff -= output + module.output_diff_dict[hparam][recipe] = output_diff.detach() + + # Disable the configurable hparam now that we have computed the diff + hparam.active = QuantRecipe(quant_cfg=None) + + return output + + def backward_hook(module, grad_input, grad_output): + for hparam, output_diff_dict in module.output_diff_dict.items(): + for recipe, output_diff in output_diff_dict.items(): + if hparam._importance_dict[recipe][module] is None: + hparam._importance_dict[recipe][module] = _get_auto_quantize_score( + grad_output[0], output_diff + ) + else: + _add_auto_quantize_score( + grad_output[0], output_diff, hparam._importance_dict[recipe][module] + ) + + def setup_params_for_score_estimation(name, param, params_metadata, enable_grad=True): + # Let us delete the gradient as soon as they are computed to save memory + params_metadata[name] = {"requires_grad": param.requires_grad} + param.requires_grad = enable_grad + if not enable_grad: + return + if self.config.get("verbose", False): + print_rank_0(f"AutoQuantize: Enabling gradient for param {name}.") + accum_grad, handle = create_param_grad_clear_hook(param) + params_metadata[name]["accum_grad"] = accum_grad # We need to keep the accum_grad alive + params_metadata[name]["handle"] = handle + 
+ def setup_module_for_score_estimation(module): + module._forward_original = module.forward + module.forward = types.MethodType(auto_quantize_score_estimate_forward, module) + module._backward_hook_handle = module.register_full_backward_hook(backward_hook) + + def cleanup_module_after_score_estimation(module): + module.forward = module._forward_original + del module._forward_original + + module._backward_hook_handle.remove() + + def cleanup_params_after_score_estimation(name, param, params_metadata): + param.requires_grad = params_metadata[name]["requires_grad"] + handle = params_metadata[name].get("handle", None) + if handle is not None: + handle.remove() + + score_modules = set() + for name, module in self.model.named_modules(): + if ( + hasattr(module, "_hparams_for_scoring") + and any(hparam.is_configurable for hparam in module._hparams_for_scoring) + and module not in score_modules + ): + # Monkey patch the forward methods to cache (Q(Y) - Y) + setup_module_for_score_estimation(module) + score_modules.add(module) + + params_metadata = {} + for name, param in self.model.named_parameters(): + setup_params_for_score_estimation( + name, param, params_metadata, is_param_grad_enabled(name, self.model) + ) + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + report_memory("AutoQuantize: starting score estimation, ") + + self._run_func( + self.config["forward_backward_step"], + num_iters=self.config["num_score_steps"], + desc="Estimating auto_quantize scores", + ) + + if torch.cuda.is_available(): + report_memory("AutoQuantize: After score estimation") + + for module in score_modules: + cleanup_module_after_score_estimation(module) + + for name, param in self.model.named_parameters(): + cleanup_params_after_score_estimation(name, param, params_metadata) + + # Delete the params_metadata + del params_metadata + gc.collect() + + def estimate_sensitivity_scores(self) -> None: + """Estimate sensitivity scores using hessian 
approximation.""" self.model.eval() def _default_is_param_grad_enabled(pname, model): @@ -612,7 +959,7 @@ def _default_is_param_grad_enabled(pname, model): grad_checkpointing_ctxt = None is_param_grad_enabled = _default_is_param_grad_enabled - for is_supported_checker, ctxt_candidate, grad_enabled_candidate in self.custom_support: + for is_supported_checker, ctxt_candidate, grad_enabled_candidate in self._custom_support: if is_supported_checker(self.model): grad_checkpointing_ctxt = ctxt_candidate is_param_grad_enabled = grad_enabled_candidate @@ -621,73 +968,21 @@ def _default_is_param_grad_enabled(pname, model): with grad_checkpointing_ctxt(self.model) if grad_checkpointing_ctxt else nullcontext(): self._estimate_auto_quantize_scores(is_param_grad_enabled) - def run_search(self): - """Search for the best per-layer quantization configuration and return the best model and configuration. + def run_search_with_stats(self, max_weight_size, verbose=False): + """Linear Programming Solve for gradient based auto_quantize. AutoQuantize uses Linear Programming Solver to find the optimal quantization configuration which minimizes the sum of per-layer auto_quantize scores while meeting the specified constraint. """ - - def get_total_weight_size(modules): - return sum( - (module.weight.numel() if self._is_auto_quantize_module(module) else 0) - for module in modules - ) - - def _get_constraints_for_search(max_weight_size, lower_bound=None): - constraints = { - "weight_size_after_compression": ( - lower_bound * max_weight_size if lower_bound else lower_bound, - max_weight_size, - ) - } - return constraints, "weight_size_after_compression" - - verbose = self.config["verbose"] - assert len(self.constraints) == 1 and "effective_bits" in self.constraints, ( - f"`constraints` must contain only 'effective_bits' constraint. 
" - f"Got {self.constraints.keys()}" - ) - - compression = self._get_formatted_weight_compression_constraint() - total_weight_size = get_total_weight_size(self.model.modules()) - weight_size_after_compression = total_weight_size * compression - - for name, hparam in named_hparams(self.model, unique=True): - if not isinstance(hparam, QuantRecipeHparam): - continue - - formats, scores, costs = [], [], [] - prev_score = float("inf") - for recipe in hparam.choices: - formats.append(recipe) - score = hparam.importance[recipe] - cost = get_total_weight_size(hparam.nn_modules) * recipe.compression # type: ignore [union-attr] - - # Lets get the score across Data Parallel (DP) and Tensor Parallel (TP) groups - # This way we constraint the same quantization format for the same layer across the DP/TP groups - # The cost we use here is weight size. They are the same across DP/TP groups. - _ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state - # The score is the sum of the scores across DP and TP groups - score = DistributedProcessGroup.get_dist_syncd_obj( - score, [_ps.data_parallel_group, _ps.tensor_parallel_group], sum - ) - - scores.append(min(score, prev_score)) - costs.append(cost) - prev_score = score - - self.candidate_stats[name]["formats"] = formats - self.candidate_stats[name]["scores"] = scores - self.candidate_stats[name]["costs"] = costs + # TODO: Do this only for rank 0 in the respective pipeline group for lower_bound in [None, 0.99, 0.90]: # The LP solver for auto_quantize sometimes fails to find a solution if a lower bound is not # specified. I dont know why this happens. # As a workaround, lets specify a lower bound for the weight compression if previous # search without lower bound fails. 
- constraints, constraint_name = _get_constraints_for_search( - weight_size_after_compression, lower_bound + constraints, constraint_name = self._get_constraints_for_search( + max_weight_size, lower_bound ) lps = LPS( @@ -708,47 +1003,251 @@ def _get_constraints_for_search(max_weight_size, lower_bound=None): if self.status == "Optimal": break - self.best = {} - if self.status != "Optimal": warnings.warn( "AutoQuantize FAILED to find a solution! The searched model might not meet all constraints. " ) - self.best["is_satisfied"] = False + is_satisfied = False else: - self.best["is_satisfied"] = True + is_satisfied = True - best_recipe = {} - best_constraints, best_scores = 0, 0 + best_recipes = {} for name, selected_idx in zip(self.candidate_stats.keys(), selections): - best_recipe_for_name = self.candidate_stats[name]["formats"][selected_idx] + best_recipes[name] = { + "format": self.candidate_stats[name]["formats"][selected_idx], + "costs": self.candidate_stats[name]["costs"][selected_idx], + "scores": self.candidate_stats[name]["scores"][selected_idx], + } - # LP solver could give different solutions for the same layer across DP/TP groups even though - # the scores and costs are the same. 
Lets make sure the same quantization format is selected across DP/TP - _ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state - best_recipe_for_name = DistributedProcessGroup.get_dist_syncd_obj( - best_recipe_for_name, - [_ps.data_parallel_group, _ps.tensor_parallel_group], - lambda a: a[0], + return best_recipes, is_satisfied + + +# TODO: Enable torch compile for this function +# Currently modelopt.onnx is breaking this +def _get_softmax_dist( + logits: torch.Tensor, tp_group, return_log_prob: bool = False +) -> torch.Tensor: + # TODO: test this + dtype = logits.dtype + max_logits = torch.amax(logits, dim=-1, keepdim=True) + torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=tp_group) + logits = (logits - max_logits).float() + sum_exp_logits = torch.exp(torch.logsumexp(logits, dim=-1, keepdim=True)) + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group) + logits = logits - torch.log(sum_exp_logits) + if return_log_prob: + return logits.to(dtype) + else: + return torch.exp(logits).to(dtype) + + +def _get_softmax(logits: torch.Tensor, return_log_prob: bool = False) -> torch.Tensor: + # TODO: do we need to do log_softmax in float32? 
+ # log_softmax is supposed to be numerically stable implementation + log_prob = torch.log_softmax(logits.float(), dim=-1) + if return_log_prob: + return log_prob + else: + return torch.exp(log_prob) + + +def _get_p_log_q(p: torch.Tensor, log_q: torch.Tensor) -> torch.Tensor: + return torch.sum(p * log_q).float() + + +def _get_prob_from_logits( + logits: torch.Tensor, return_log_prob: bool = False, lm_head: nn.Module = None +) -> torch.Tensor: + parallel_state: ParallelState | None = ( + getattr(lm_head, "parallel_state", None) if lm_head is not None else None + ) + if parallel_state is not None and parallel_state.tensor_parallel_group.is_initialized(): + return _get_softmax_dist( + logits, parallel_state.tensor_parallel_group.group, return_log_prob + ) + return _get_softmax(logits, return_log_prob) + + +def _get_kl_div_loss( + prob_unquant: torch.Tensor, logits_quant: torch.Tensor, lm_head: nn.Module = None +) -> torch.Tensor: + log_prob_quant = _get_prob_from_logits(logits_quant, return_log_prob=True, lm_head=lm_head) + # We dont need to calculate the full kl div loss here, just get - p*log_q + return -_get_p_log_q(prob_unquant, log_prob_quant) + + +def _get_lm_head(model: nn.Module) -> nn.Module: + # HF models do allgather of logits to at lm_head + # Hence lm_head outputs are not TP sharded - so we dont need to return the lm_head for TP KLDiv + # Loss + for name, module in model.named_modules(): + if name.endswith("output_layer"): # Megatron models + return module + return None + + +class AutoQuantizeKLDivSearcher(_AutoQuantizeBaseSearcher): + """A searcher for AutoQuantize algorithm that uses KL-Divergence loss based score estimation.""" + + @property + def default_search_config(self): + """Get the default config for the searcher.""" + config = super().default_search_config + config.update( + { + "forward_step": None, + } + ) + return config + + def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: + """Sanitize the search config 
dict.""" + config = config or {} + for ignored_key in ["score_func", "loss_func", "forward_backward_step"]: + if ignored_key in config: + if config[ignored_key] is not None: + warnings.warn( + f"`{ignored_key}` is ignored for KL-Divergence loss based `auto_quantize`." + ) + config.pop(ignored_key) + config = super().sanitize_search_config(config) + assert config["forward_step"] is not None, ( + "`forward_step` must be provided for KL-Divergence loss based `auto_quantize`. " + "`forward_step(model, data)` should return model logits." + ) + return config + + @torch.inference_mode() + def estimate_sensitivity_scores(self): + """Estimate the sensitivity scores for the model. + + Higher score means more sensitive to quantization. + """ + + def set_to_unquantized(): + for name, hparam in named_hparams(self.model, unique=True): + if not isinstance(hparam, QuantRecipeHparam): + continue + if hparam.is_configurable: + hparam.active = QuantRecipe(quant_cfg=None) + + self.model.eval() + num_iters = self.config["num_score_steps"] + for _, data in tqdm( + zip(range(num_iters), self.config["data_loader"]), + desc="Estimating KLDivergence loss", + total=num_iters, + ): + set_to_unquantized() + logits_unquant = self.config["forward_step"](self.model, data) + prob_unquant = _get_prob_from_logits( + logits_unquant, + return_log_prob=False, + lm_head=_get_lm_head(self.model), ) - best_recipe[name] = best_recipe_for_name - get_hparam(self.model, name).active = best_recipe_for_name - best_constraints += self.candidate_stats[name]["costs"][selected_idx] - best_scores += self.candidate_stats[name]["scores"][selected_idx] + for name, hparam in tqdm( + list(named_hparams(self.model, configurable=True)), desc="Evaluating hparams" + ): + if not isinstance(hparam, QuantRecipeHparam): + continue + for recipe in hparam.choices: + if recipe == QuantRecipe(quant_cfg=None): + continue + hparam.active = recipe + logits_quant = self.config["forward_step"](self.model, data) + score = 
_get_kl_div_loss(prob_unquant, logits_quant, _get_lm_head(self.model)) + if hparam._importance_dict[recipe][hparam.score_modules[0]] is None: + hparam._importance_dict[recipe][hparam.score_modules[0]] = score + else: + hparam._importance_dict[recipe][hparam.score_modules[0]] += score + hparam.active = QuantRecipe(quant_cfg=None) + + def run_search_with_stats(self, max_weight_size, verbose=False): + """Run threshold-based binary search for KLDivergence loss based auto_quantize. + + We use binary search to minimize the max(per-layer score) while meeting the constraint. + """ + # Collect all sensitivity scores to determine initial threshold bounds + all_scores = [ + score for name in self.candidate_stats for score in self.candidate_stats[name]["scores"] + ] + + if not all_scores: + warnings.warn("No scores available for threshold-based search!") + is_satisfied = False + return {}, is_satisfied + + # Initialize binary search bounds + min_score = min(all_scores) + max_score = max(all_scores) + threshold = (min_score + max_score) / 2.0 + lower_bound = min_score + upper_bound = max_score + + # Run for fixed number of iterations + max_iterations = 100 + + if verbose: + print_rank_0("AutoQuantize: Starting threshold-based binary search") + print_rank_0(f" Score range: [{min_score:.6e}, {max_score:.6e}]") + print_rank_0(f" Target weight size: {max_weight_size:.2f}") + + for iteration in range(max_iterations): + # Select recipes based on current threshold + best_recipes = {} + total_weight_size = 0.0 + + for name in self.candidate_stats: + formats = self.candidate_stats[name]["formats"] + scores = self.candidate_stats[name]["scores"] + costs = self.candidate_stats[name]["costs"] + + selected_idx = 0 + for idx in range(len(formats)): + if scores[idx] <= threshold: + selected_idx = idx + break + + best_recipes[name] = { + "format": formats[selected_idx], + "costs": costs[selected_idx], + "scores": scores[selected_idx], + } + total_weight_size += costs[selected_idx] + + # Check 
if we meet the constraint + meets_constraint = total_weight_size <= max_weight_size + if verbose: print_rank_0( - f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}" + f" Iteration {iteration + 1}: threshold={threshold:.6e}, " + f"weight_size={total_weight_size:.2f}, " + f"meets_constraint={meets_constraint}" ) - effective_bits_from_search = (best_constraints / total_weight_size) * 16 + # Update binary search bounds + if meets_constraint: + upper_bound = threshold # Threshold was too aggressive, relax it + else: + lower_bound = threshold # Threshold was too lax, tighten it + + # Update threshold for next iteration + threshold = (lower_bound + upper_bound) / 2.0 + + # Final check if constraint is satisfied + is_satisfied = total_weight_size <= max_weight_size + if verbose: print_rank_0( - f"AutoQuantize effective bits from search: {effective_bits_from_search: .2f}" + f"AutoQuantize: Search complete. " + f"Final weight size: {total_weight_size:.2f} " + f"(target: {max_weight_size:.2f}), " + f"constraint satisfied: {is_satisfied}" ) - self.best["recipe"] = best_recipe - self.best["constraints"] = {"effective_bits": effective_bits_from_search} - self.best["score"] = best_scores + return best_recipes, is_satisfied - QuantRecipe.fold_pqs_to_weights(self.model) + +# Backward compatibility alias (defaults to gradient-based searcher) +AutoQuantizeSearcher = AutoQuantizeGradientSearcher diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index deace8e0c..4a2b74a30 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -31,7 +31,7 @@ from modelopt.torch.quantization.config import QuantizeConfig from modelopt.torch.quantization.conversion import set_quantizer_by_cfg -from .algorithms import AutoQuantizeSearcher, QuantRecipe +from .algorithms import AutoQuantizeGradientSearcher, AutoQuantizeKLDivSearcher, QuantRecipe from .config import 
QuantizeAlgoCfgType from .conversion import set_quantizer_attribute from .mode import QuantizeModeRegistry, get_modelike_from_algo_cfg @@ -231,6 +231,12 @@ def forward_loop(model) -> None: return calibrate(model, config["algorithm"], forward_loop=forward_loop) +# TODO: create a config interface for auto_quantize and expose setting +# quant_grouping_rules and score_module_rules as part of the config. +# This will allow users to customize the grouping and scoring rules for their models. +# This way we can limit the granularity of quantization search. For example, +# - limit the quantization format search to decoder block level (instead of each linear layer level) +# - Same format for all self attention layers of a model etc. def auto_quantize( model: nn.Module, constraints: dict[str, float | str] = {"effective_bits": 4.8}, @@ -246,11 +252,23 @@ def auto_quantize( num_calib_steps: int = 512, num_score_steps: int = 128, verbose: bool = False, + method: str = "gradient", + checkpoint: str | None = None, ): r"""Perform optimal per-layer quantization by searching for the best quantization formats per-layer. - ``auto_quantize`` uses a gradient based sensitivity score to rank the per-layer quantization formats and search - for the best quantization formats per-layer. + ``auto_quantize`` uses sensitivity scores to rank the per-layer quantization formats and search + for the best quantization formats per-layer. The sensitivity score can be computed using gradient-based + methods (default) or KL divergence loss, controlled by the ``method`` parameter. + + Internally this API runs two main phases: + + #. Calibrate the quantized model exactly like :func:`quantize` would. + #. Estimate per-layer sensitivity scores to decide which format to keep. 
+ + The sensitivity scoring phase typically dominates the runtime of ``auto_quantize``, so decreasing the number of + samples used for scoring (see ``num_score_steps``) is the recommended way for improving overall auto_quantize time + with minimal accuracy impact. Args: model: A pytorch model with quantizer modules. @@ -369,10 +387,20 @@ def forward_backward_step(model, batch) -> None: disabled_layers = "*lm_head*" disabled_layers = ["*lm_head*", "*mlp*"] - num_calib_steps: Number of batches to use for calibrating the quantized model. Suggested value is 512. + num_calib_steps: Number of batches to use for calibrating each candidate quantization format. Suggested value + is 512. num_score_steps: Number of batches to use for estimating ``auto_quantize`` scores. Suggested value is 128. - A higher value could increase the time taken for performing ``auto_quantize``. + A higher value could increase the time taken for performing ``auto_quantize``; reducing it speeds up the + sensitivity score estimation phase and typically affects accuracy less than lowering ``num_calib_steps``. verbose: If True, prints the search progress/intermediate results. + method: Method to use for estimating sensitivity loss. Higher loss indicates greater sensitivity + to quantization. Options are ``"gradient"`` (default; uses gradient-based loss estimation, + linear programming search, and requires ``loss_func`` or ``forward_backward_step``) and + ``"kl_div"`` (uses KL divergence between unquantized and quantized outputs, relies on + threshold-based binary search, and only requires ``forward_step`` returning logits). + checkpoint: (Optional) Path to checkpoint file for saving/restoring auto_quantize search state. + If the checkpoint file exists, the search state will be restored from it, skipping the + expensive score estimation step. 
Returns: A tuple (model, state_dict) where ``model`` is the searched and quantized model and ``state_dict`` contains the history and detailed stats of the search procedure. @@ -384,23 +412,32 @@ def forward_backward_step(model, batch) -> None: This is to ensure compatibility with TensorRT-LLM which fuses these three linear layers into a single linear layer. - A list of regex pattern rules as defined in :attr:`rules <.algorithms.AutoQuantizeSearcher.rules>` - are used to specify the group of layers. The first captured group - in the regex pattern (i.e, ``pattern.match(name).group(1)``) is used to group the layers. All the layers - that share the same first captured group will have the same quantization format.. + Grouping rules are defined in :attr:`quant_grouping_rules + <.algorithms.AutoQuantizeSearcher.quant_grouping_rules>`. + Each rule can be either a regex pattern or a callable function. + + - **Regex patterns**: The first captured group (e.g., + ``pattern.match(name).group(1)``) determines the group key. + Layers with the same group key share the same quantization format. + - **Functions**: Should take a module name and return a group key + (or ``None`` if the rule doesn't apply). - For example, the rule ``r"^(.*?)\.(q_proj|k_proj|v_proj)$"`` - groups the `q_proj`, `k_proj`, `v_proj` linear layers belonging to the same transformer layer. + Example regex rule: ``r"^(.*?)\.(q_proj|k_proj|v_proj)$"`` groups the + `q_proj`, `k_proj`, `v_proj` layers belonging to the same transformer layer. - You may modify the rules to group the layers as per your requirement. + You can customize the rules as needed: .. 
code-block:: python from modelopt.torch.quantization.algorithms import AutoQuantizeSearcher - # To additionally group the layers belonging to same `mlp` layer, - # add the following rule - AutoQuantizeSearcher.rules.append(r"^(.*?)\.mlp") + # Add a regex rule to group layers in the same `mlp` module + AutoQuantizeSearcher.quant_grouping_rules.append(r"^(.*?)\.mlp") + + # Or add a function rule for custom logic + AutoQuantizeSearcher.quant_grouping_rules.append( + lambda name: name.rsplit(".", 1)[0] if "expert" in name else None + ) # Perform `auto_quantize` model, state_dict = auto_quantize(model, ...) @@ -426,12 +463,20 @@ def forward_backward_step(model, batch) -> None: processed_quantization_formats.append((quant_cfg, name)) assert len(processed_quantization_formats) > 0, "`quantization_formats` should not be empty" + + # Select the appropriate searcher based on method + if method == "gradient": + searcher = AutoQuantizeGradientSearcher() + elif method == "kl_div": + searcher = AutoQuantizeKLDivSearcher() + else: + raise ValueError(f"Invalid method: {method}. 
Valid options are 'gradient' or 'kl_div'.") + model = apply_mode( model, mode="auto_quantize", registry=QuantizeModeRegistry, ) - searcher = AutoQuantizeSearcher() search_config = { "quantization_formats": processed_quantization_formats, "data_loader": data_loader, @@ -442,6 +487,7 @@ def forward_backward_step(model, batch) -> None: "num_score_steps": num_score_steps, "disabled_layers": disabled_layers, "verbose": verbose, + "checkpoint": checkpoint, } # Disable all quantizers; AutoQuantize will enable the needed ones set_quantizer_by_cfg(model, {"*": {"enable": False}}) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index af46b6d26..31ac2bbbd 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -35,7 +35,7 @@ from modelopt.torch.opt.dynamic import DynamicModule from modelopt.torch.utils.distributed import ParallelState -from ..algorithms import AutoQuantizeSearcher +from ..algorithms import AutoQuantizeGradientSearcher from ..conversion import register from ..nn import QuantInputBase, QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import _QuantLinear @@ -745,7 +745,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model): return "embed" in pname -AutoQuantizeSearcher.register_custom_support( +AutoQuantizeGradientSearcher.register_custom_support( _is_supported_hf_model, setup_model_for_gradient_checkpointing, _is_param_grad_enabled_for_auto_quantize, diff --git a/tests/gpu/torch/export/test_export_weight.py b/tests/gpu/torch/export/test_export_weight_gpu.py similarity index 100% rename from tests/gpu/torch/export/test_export_weight.py rename to tests/gpu/torch/export/test_export_weight_gpu.py diff --git a/tests/unit/torch/quantization/plugins/test_huggingface.py b/tests/unit/torch/quantization/plugins/test_huggingface.py index ab59e663e..a3e72ffa7 100644 --- 
a/tests/unit/torch/quantization/plugins/test_huggingface.py +++ b/tests/unit/torch/quantization/plugins/test_huggingface.py @@ -15,6 +15,7 @@ import os import warnings +from contextlib import nullcontext import pytest import torch @@ -136,28 +137,43 @@ def test_dbrx(): assert torch.allclose(out_1[0], out_2[0]) -def test_autoquantize_huggingface(): +@pytest.mark.parametrize( + "method", + ["gradient", "kl_div"], +) +def test_autoquantize_huggingface(method): model = get_tiny_llama() input_ids = model.dummy_inputs["input_ids"] + def forward_step(model, batch): + return model(**batch) if method == "gradient" else model(**batch).logits + warnings.filterwarnings( "error", message="AutoQuantize: Error enabling gradient checkpointing for huggingface model" ) - with pytest.warns( - UserWarning, - match="AutoQuantize: Huggingface model detected - Enabling gradient checkpointing. ", - ): + # Gradient checkpointing warning should only appear for gradient-based method + context = ( + pytest.warns( + UserWarning, + match="AutoQuantize: Huggingface model detected - Enabling gradient checkpointing. 
", + ) + if method == "gradient" + else nullcontext() + ) + + with context: best_model, search_history = mtq.auto_quantize( model, constraints={"effective_bits": 11.0}, quantization_formats=[mtq.INT8_DEFAULT_CFG], data_loader=[{"input_ids": input_ids, "labels": input_ids} for _ in range(2)], - forward_step=lambda model, batch: model(**batch), + forward_step=forward_step, loss_func=lambda output, data: output.loss, num_calib_steps=2, num_score_steps=2, verbose=True, + method=method, ) diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index b08d7fa5e..1a5cfee32 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -92,7 +92,7 @@ def test_quant_recipe_hparam(): ] hparam = QuantRecipeHparam( search_recipes, - nn_modules=[model_test], + quant_modules=[model_test], ) model_test._register_hparam("quant_recipe", hparam) assert model_test.quant_recipe == QuantRecipe(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG) @@ -133,7 +133,11 @@ def test_quant_recipe_hparam(): ([None, mtq.INT8_SMOOTHQUANT_CFG], 8.0, 11.0), ], ) -def test_auto_quantize(model_cls, search_formats, min_bits, search_bits): +@pytest.mark.parametrize( + "method", + ["gradient", "kl_div"], +) +def test_auto_quantize(model_cls, search_formats, min_bits, search_bits, method): model = model_cls() def loss_func(output): @@ -149,6 +153,7 @@ def loss_func(output): num_calib_steps=2, num_score_steps=2, verbose=True, + method=method, ) assert isinstance(search_history, dict) assert search_history["best"]["is_satisfied"] @@ -178,7 +183,7 @@ def loss_func(output): assert torch.allclose(output_ref, output_test) -def test_auto_quantize_disable(): +def test_auto_quantize_disable_layers(): model = TransformerBlock() def loss_func(output): @@ -337,3 +342,58 @@ def test_estimate_quant_compression(): fp8_affine_kv_cfg = mtq.config.QuantizeConfig(**mtq.FP8_AFFINE_KV_CFG) assert 
estimate_quant_compression(fp8_affine_kv_cfg) == 0.5 + + +@pytest.mark.parametrize("method", ["gradient", "kl_div"]) +def test_auto_quantize_checkpoint_resume(method, tmp_path, capsys): + """Test that checkpoint can be used to resume an interrupted search.""" + model = SimpleLinear() + checkpoint_path = str(tmp_path / "autoquant_resume_checkpoint.pth") + + # First run: save checkpoint + model_1, state_dict_1 = mtq.auto_quantize( + model, + constraints={"effective_bits": 6.0}, + quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], + data_loader=[model.get_input() for _ in range(2)], + forward_step=lambda model, batch: model(batch), + loss_func=lambda output, data: output.sum(), + num_calib_steps=2, + num_score_steps=2, + verbose=True, + method=method, + checkpoint=checkpoint_path, + ) + + # Clear captured output from first run + capsys.readouterr() + + # Second run: resume with same constraint should produce same results + model_2 = SimpleLinear() + model_2, state_dict_2 = mtq.auto_quantize( + model_2, + constraints={"effective_bits": 6.0}, # Same constraint + quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], + data_loader=[model_2.get_input() for _ in range(2)], + forward_step=lambda model, batch: model(batch), + loss_func=lambda output, data: output.sum(), + num_calib_steps=2, + num_score_steps=2, + verbose=True, + method=method, + checkpoint=checkpoint_path, + ) + + # Verify the restore message was printed on second run + captured = capsys.readouterr() + assert "Restored from checkpoint, skipping scoring" in captured.out, ( + "Expected restore message when resuming from checkpoint" + ) + + # Results should be identical when using same constraint + assert state_dict_1["candidate_stats"] == state_dict_2["candidate_stats"] + assert state_dict_1["best"]["recipe"] == state_dict_2["best"]["recipe"] + assert ( + pytest.approx(state_dict_1["best"]["constraints"]["effective_bits"]) + == 
state_dict_2["best"]["constraints"]["effective_bits"] + ) From 671bbbb69be999e8e32d20eb2c68a9d5f04d94f1 Mon Sep 17 00:00:00 2001 From: inisis Date: Wed, 26 Nov 2025 14:24:37 +0800 Subject: [PATCH 2/8] feat: add onnxslim support Signed-off-by: inisis --- CHANGELOG.rst | 1 + modelopt/onnx/quantization/quantize.py | 13 +++---------- setup.py | 2 +- tests/gpu/onnx/test_simplify.py | 8 ++++---- 4 files changed, 9 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index beb01abf0..60c7ec5ca 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -28,6 +28,7 @@ Model Optimizer Changelog (Linux) **Misc** - Bump minimum recommended transformers version to 4.53. +- Replace ONNX simplification package from 'onnxsim' to 'onnxslim'. 0.39 (2025-11-11) ^^^^^^^^^^^^^^^^^ diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index 800124646..00b4f5d75 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -39,6 +39,7 @@ from typing import Any import onnx +import onnxslim import onnx.onnx_cpp2py_export.checker as C import onnx_graphsurgeon as gs @@ -133,16 +134,8 @@ def _preprocess_onnx( if simplify: logger.info("Attempting to simplify model") try: - import onnxsim - except ModuleNotFoundError as e: - logger.warning( - "onnxsim is not installed. Please install it with 'pip install onnxsim'." 
- ) - raise e - - try: - model_simp, check = onnxsim.simplify(onnx_model) - if check: + model_simp = onnxslim.slim(onnx_model, skip_fusion_patterns=["FusionGemm"]) + if model_simp: onnx_model = model_simp onnx_path = os.path.join(output_dir, f"{model_name}_simp.onnx") save_onnx(onnx_model, onnx_path, use_external_data_format) diff --git a/setup.py b/setup.py index 85b79e729..7befe9a47 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ "onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'", # noqa: E501 "onnxruntime-directml==1.20.0; platform_system == 'Windows'", "onnxscript", # For autocast opset conversion and test_onnx_dynamo_export unit test - "onnxsim ; python_version < '3.12' and platform_machine != 'aarch64'", + "onnxslim>=0.1.75, "polygraphy>=0.49.22", ], "hf": [ diff --git a/tests/gpu/onnx/test_simplify.py b/tests/gpu/onnx/test_simplify.py index 3b6acccb6..5ca8449b3 100644 --- a/tests/gpu/onnx/test_simplify.py +++ b/tests/gpu/onnx/test_simplify.py @@ -57,14 +57,14 @@ def test_onnx_simplification(tmp_path): assert os.path.isfile(output_onnx_path), "Quantized ONNX was not found!" # Load the simplified model and check that the model doesn't contain Identity nodes, - # only 3 layers (Conv->BN->Relu). + # only 2 layers (Conv->Relu). graph = gs.import_onnx(onnx.load(simplified_onnx_path)) identity_nodes = [n for n in graph.nodes if n.op == "Identity"] assert not identity_nodes, "Simplified ONNX model contains Identity nodes but it shouldn't." - assert len(graph.nodes) == 3, ( - f"Number of nodes doesn't match the expected: {len(graph.nodes)} vs 3." + assert len(graph.nodes) == 2, ( + f"Number of nodes doesn't match the expected: {len(graph.nodes)} vs 2." ) - assert all(n.op in ["Conv", "BatchNormalization", "Relu"] for n in graph.nodes), ( + assert all(n.op in ["Conv", "Relu"] for n in graph.nodes), ( "Graph contains more ops than expected." 
) From 0bf5b06759fdb3d65d568db16f446ca02e97198e Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:23:53 +0530 Subject: [PATCH 3/8] Update CHANGELOG.rst Signed-off-by: inisis --- CHANGELOG.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 60c7ec5ca..f6a323441 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -28,7 +28,7 @@ Model Optimizer Changelog (Linux) **Misc** - Bump minimum recommended transformers version to 4.53. -- Replace ONNX simplification package from 'onnxsim' to 'onnxslim'. +- Replace ONNX simplification package from ``onnxsim`` to ``onnxslim``. 0.39 (2025-11-11) ^^^^^^^^^^^^^^^^^ From ff3ffc83aca933a0d207d420e0f13bbebb7316b1 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:25:25 +0530 Subject: [PATCH 4/8] Update tox.ini Signed-off-by: inisis --- tox.ini | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tox.ini b/tox.ini index ff73b3b2b..1c1afbe3a 100644 --- a/tox.ini +++ b/tox.ini @@ -18,9 +18,6 @@ deps = torch28: torchvision~=0.23.0 torch29: torchvision~=0.24.0 - # Build onnxsim from sdists for Python 3.12 until http://github.com/daquexian/onnx-simplifier/pull/353 - py312: onnxsim - # Install megatron-core for special unit tests megatron-core @@ -42,9 +39,6 @@ deps = # Make sure torch 2.9 is used torchvision~=0.24.0 - # Build onnxsim from sdists for Python 3.12 until http://github.com/daquexian/onnx-simplifier/pull/353 - py312: onnxsim - # ONNX unit tests heavily rely on torch / torchvision onnx: .[onnx,dev-test] onnx: torchvision @@ -80,9 +74,6 @@ commands_pre = # Install Eagle-3 test dependencies pip install tiktoken blobfile sentencepiece - # Build onnxsim from sdists for Python 3.12 until http://github.com/daquexian/onnx-simplifier/pull/353 - py312: pip install onnxsim - # NOTE: User is expected to have correct torch-cuda version 
pre-installed if using --current-env # to avoid possible CUDA version mismatch pip install -e .[all,dev-test] From 0e07446fc3c29cf2e29e68d0bf413be6ad1f449b Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:26:10 +0530 Subject: [PATCH 5/8] Update tests.yml Signed-off-by: inisis --- .gitlab/tests.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitlab/tests.yml b/.gitlab/tests.yml index 1012bcb7f..1d699896e 100644 --- a/.gitlab/tests.yml +++ b/.gitlab/tests.yml @@ -19,8 +19,6 @@ unit: TRANSFORMERS: latest image: python:3.$PYTHON before_script: - # Install cmake to build onnxsim from sdists for Python 3.12 until http://github.com/daquexian/onnx-simplifier/pull/353 - - if [ "$PYTHON" = "12" ]; then apt-get update && apt-get install -y cmake; fi - pip install tox script: - tox -e py3$PYTHON-torch$TORCH-tf_$TRANSFORMERS-unit From ba5a125433cff5a2bbecf1042e6fbc54f7f17534 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:30:13 +0530 Subject: [PATCH 6/8] Update setup.py Signed-off-by: inisis --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7befe9a47..11bf5b261 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ "onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'", # noqa: E501 "onnxruntime-directml==1.20.0; platform_system == 'Windows'", "onnxscript", # For autocast opset conversion and test_onnx_dynamo_export unit test - "onnxslim>=0.1.75, + "onnxslim>=0.1.75", "polygraphy>=0.49.22", ], "hf": [ From f2d3412c8522e3dc03fcbdcc78f63e6b731977d4 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:46:33 +0530 Subject: [PATCH 7/8] Update quantize.py Signed-off-by: inisis --- modelopt/onnx/quantization/quantize.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index 00b4f5d75..96ee406c7 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -39,9 +39,9 @@ from typing import Any import onnx -import onnxslim import onnx.onnx_cpp2py_export.checker as C import onnx_graphsurgeon as gs +import onnxslim from modelopt.onnx.logging_config import configure_logging, logger from modelopt.onnx.op_types import is_data_dependent_shape_op From 18c27dd56fa8869322c3a95a8adc5c00306a5e25 Mon Sep 17 00:00:00 2001 From: inisis Date: Wed, 26 Nov 2025 21:04:55 +0800 Subject: [PATCH 8/8] pin minimal onnxslim version to 0.1.76 Signed-off-by: inisis --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 11bf5b261..158c91e40 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ "onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'", # noqa: E501 "onnxruntime-directml==1.20.0; platform_system == 'Windows'", "onnxscript", # For autocast opset conversion and test_onnx_dynamo_export unit test - "onnxslim>=0.1.75", + "onnxslim>=0.1.76", "polygraphy>=0.49.22", ], "hf": [