95 changes: 95 additions & 0 deletions backends/qualcomm/quantizer/custom_annotation.py
@@ -11,6 +11,7 @@
_is_float_tensor,
Q_ANNOTATION_KEY,
)
from enum import Enum, unique
from executorch.backends.qualcomm.quantizer.quantizer import (
get_16a8w_qnn_ptq_config,
get_16a8w_qnn_qat_config,
@@ -125,6 +126,100 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
_annotated=True,
)

def annotate_output_16a8w(gm: torch.fx.GraphModule, is_qat: bool = False) -> None:
"""
    This function is for static LLM models.
    It annotates the last conv (linear) node, i.e., the lm_head, as 16a8w.
"""

def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
input_qspec_map = {}
input_act = node.args[0]
input_spec = quantization_config.input_activation
input_qspec_map[input_act] = input_spec

weight = node.args[1]
input_qspec_map[weight] = quantization_config.weight

if len(node.args) > 2 and isinstance(node.args[2], Node):
input_qspec_map[node.args[2]] = quantization_config.bias(node)

node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=quantization_config.output_activation,
_annotated=True,
)

if is_qat:
quantization_config_16a8w_per_channel = get_qat_per_channel_quant_config(
torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
)
else:
quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
)
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
if "nn_module_stack" in node.meta:
module_values_list = list(node.meta["nn_module_stack"].values())
full_qualified_name = module_values_list[-1][0]
if full_qualified_name == "output.conv":
annotate_conv2d(
node, quantization_config=quantization_config_16a8w_per_channel
)

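# A minimal usage sketch (not part of this diff): registering annotate_output_16a8w
# as a custom annotation. The registration call below assumes QnnQuantizer exposes
# add_custom_quant_annotations taking a tuple of callables, as other custom
# annotations in this backend are wired up; treat that API and the partial binding
# as assumptions, not as something introduced by this change.
#
#     from functools import partial
#     from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
#
#     quantizer = QnnQuantizer()
#     quantizer.add_custom_quant_annotations(
#         (partial(annotate_output_16a8w, is_qat=False),)
#     )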

@unique
class StaticLLMQuantConfig(Enum):
"""
Layer namespace configuration for Qualcomm's static LLaMA quantization.
"""

wq_sha = "wq_sha" # Query weight (single head)
wk_sha = "wk_sha" # Key weight (single head)
wv_sha = "wv_sha" # Value weight (single head)

def annotate_qkv_proj_sha(
gm: torch.fx.GraphModule,
quantization_config: QuantizationConfig,
qkv_tags: set[StaticLLMQuantConfig],
):
"""
Annotates QKV projection layers in a GraphModule for quantization,
specifically layers defined in StaticLLMQuantConfig.

    Args:
        gm (torch.fx.GraphModule): Graph module whose QKV projection convs are annotated.
        quantization_config (QuantizationConfig): Quantization config applied to matched nodes.
        qkv_tags (set[StaticLLMQuantConfig]): A set of enum tags indicating which QKV layers
            (e.g., wq, wk, wv) should be annotated for quantization. Only tags defined in
            StaticLLMQuantConfig are allowed.

Raises:
ValueError: If any tag in `qkv_tags` is not among the allowed enum members.
"""

# Get all valid tags from the StaticLLMQuantConfig enum
allowed_tags = set(StaticLLMQuantConfig)
invalid_tags = qkv_tags - allowed_tags
if invalid_tags:
raise ValueError(
f"Invalid qkv tags: {invalid_tags}. Allowed tags are: {allowed_tags}"
)

for node in gm.graph.nodes:
if node.target == torch.ops.aten.conv2d.default and any(
            # .get() guards against conv2d nodes that carry no stack_trace metadata
            tag.value in node.meta.get("stack_trace", "") for tag in qkv_tags
):
input_qspec_map = {}
input_qspec_map[node.args[0]] = quantization_config.input_activation
input_qspec_map[node.args[1]] = quantization_config.weight
if len(node.args) > 2 and isinstance(node.args[2], Node):
input_qspec_map[node.args[2]] = quantization_config.bias(node)
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=quantization_config.output_activation,
_annotated=True,
)

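# A minimal usage sketch (not part of this diff): binding annotate_qkv_proj_sha to a
# concrete per-channel config and tag set before handing it to a quantizer as a custom
# annotation. The config arguments and the chosen tags below are illustrative
# assumptions based on the imports used elsewhere in this module.
#
#     from functools import partial
#
#     qkv_annotation = partial(
#         annotate_qkv_proj_sha,
#         quantization_config=get_ptq_per_channel_quant_config(
#             torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
#         ),
#         qkv_tags={StaticLLMQuantConfig.wq_sha, StaticLLMQuantConfig.wk_sha},
#     )
#     qkv_annotation(gm)  # gm: an exported torch.fx.GraphModule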

def annotate_kv_8bit( # noqa: C901
gm: torch.fx.GraphModule,