Commit 64c5414

cccclai authored and facebook-github-bot committed
forward fix unit test (pytorch#15926)
Summary: Forward fix for the test failure in pytorch#15807. The main reason is that this API is called internally. In this PR, I recovered some of the functions deleted in the previous PRs.

Differential Revision: D87566729
1 parent 92bf722 commit 64c5414

File tree

1 file changed: +95 -0 lines

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 95 additions & 0 deletions
@@ -11,6 +11,7 @@
     _is_float_tensor,
     Q_ANNOTATION_KEY,
 )
+from enum import Enum, unique
 from executorch.backends.qualcomm.quantizer.quantizer import (
     get_16a8w_qnn_ptq_config,
     get_16a8w_qnn_qat_config,
@@ -125,6 +126,100 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
                     _annotated=True,
                 )
 
+def annotate_output_16a8w(gm: torch.fx.GraphModule, is_qat: bool = False) -> None:
+    """
+    This function is for static LLM models.
+    This function will annotate the last conv(linear), which is the lm_head, as 16a8w.
+    """
+
+    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+
+        weight = node.args[1]
+        input_qspec_map[weight] = quantization_config.weight
+
+        if len(node.args) > 2 and isinstance(node.args[2], Node):
+            input_qspec_map[node.args[2]] = quantization_config.bias(node)
+
+        node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    if is_qat:
+        quantization_config_16a8w_per_channel = get_qat_per_channel_quant_config(
+            torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+        )
+    else:
+        quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
+            torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+        )
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
+            if "nn_module_stack" in node.meta:
+                module_values_list = list(node.meta["nn_module_stack"].values())
+                full_qualified_name = module_values_list[-1][0]
+                if full_qualified_name == "output.conv":
+                    annotate_conv2d(
+                        node, quantization_config=quantization_config_16a8w_per_channel
+                    )
+
+
+@unique
+class StaticLLMQuantConfig(Enum):
+    """
+    Layer namespace configuration for Qualcomm's static LLaMA quantization.
+    """
+
+    wq_sha = "wq_sha"  # Query weight (single head)
+    wk_sha = "wk_sha"  # Key weight (single head)
+    wv_sha = "wv_sha"  # Value weight (single head)
+
+def annotate_qkv_proj_sha(
+    gm: torch.fx.GraphModule,
+    quantization_config: QuantizationConfig,
+    qkv_tags: set[StaticLLMQuantConfig],
+):
+    """
+    Annotates QKV projection layers in a GraphModule for quantization,
+    specifically layers defined in StaticLLMQuantConfig.
+
+    Args:
+        qkv_tags (set[StaticLLMQuantConfig]): A set of enum tags indicating which QKV layers
+            (e.g., wq, wk, wv) should be annotated for quantization. Only tags defined in
+            StaticLLMQuantConfig are allowed.
+
+    Raises:
+        ValueError: If any tag in `qkv_tags` is not among the allowed enum members.
+    """
+
+    # Get all valid tags from the StaticLLMQuantConfig enum
+    allowed_tags = set(StaticLLMQuantConfig)
+    invalid_tags = qkv_tags - allowed_tags
+    if invalid_tags:
+        raise ValueError(
+            f"Invalid qkv tags: {invalid_tags}. Allowed tags are: {allowed_tags}"
+        )
+
+    for node in gm.graph.nodes:
+        if node.target == torch.ops.aten.conv2d.default and any(
+            tag.value in node.meta["stack_trace"] for tag in qkv_tags
+        ):
+            input_qspec_map = {}
+            input_qspec_map[node.args[0]] = quantization_config.input_activation
+            input_qspec_map[node.args[1]] = quantization_config.weight
+            if len(node.args) > 2 and isinstance(node.args[2], Node):
+                input_qspec_map[node.args[2]] = quantization_config.bias(node)
+            node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                output_qspec=quantization_config.output_activation,
+                _annotated=True,
+            )
+
 
 def annotate_kv_8bit(  # noqa: C901
     gm: torch.fx.GraphModule,
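
For reference, below is a minimal usage sketch (not part of this commit) of how a caller might register the recovered annotators with the Qualcomm quantizer. The build_quantizer helper and the qkv_config argument are illustrative names, and the add_custom_quant_annotations hook is assumed from the existing QnnQuantizer API; only annotate_output_16a8w, annotate_qkv_proj_sha, and StaticLLMQuantConfig come from the file changed here.

# Illustrative sketch only: names other than the three symbols from
# custom_annotation.py are assumptions, not part of this commit.
from functools import partial

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    StaticLLMQuantConfig,
    annotate_output_16a8w,
    annotate_qkv_proj_sha,
)
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer


def build_quantizer(qkv_config):
    # qkv_config: a QuantizationConfig for the QKV projections, e.g. a 16a8w
    # per-channel config built the same way annotate_output_16a8w does internally.
    quantizer = QnnQuantizer()
    # Each custom annotation is called with the exported GraphModule, so extra
    # arguments are bound up front with functools.partial.
    quantizer.add_custom_quant_annotations(
        (
            annotate_output_16a8w,  # tag the lm_head conv as 16a8w
            partial(
                annotate_qkv_proj_sha,
                quantization_config=qkv_config,
                qkv_tags={
                    StaticLLMQuantConfig.wq_sha,
                    StaticLLMQuantConfig.wk_sha,
                    StaticLLMQuantConfig.wv_sha,
                },
            ),
        )
    )
    return quantizer

If the quantizer is then handed to prepare_pt2e in the usual PT2E flow, these custom annotations run as part of quantizer.annotate on top of the default Qualcomm annotations.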
