11 | 11 | _is_float_tensor, |
12 | 12 | Q_ANNOTATION_KEY, |
13 | 13 | ) |
| 14 | +from enum import Enum, unique |
14 | 15 | from executorch.backends.qualcomm.quantizer.quantizer import ( |
15 | 16 | get_16a8w_qnn_ptq_config, |
16 | 17 | get_16a8w_qnn_qat_config, |
@@ -125,6 +126,100 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict): |
125 | 126 | _annotated=True, |
126 | 127 | ) |
127 | 128 |
| 129 | +def annotate_output_16a8w(gm: torch.fx.GraphModule, is_qat: bool = False) -> None: |
| 130 | + """ |
| 131 | + Intended for static LLM models. |
| 132 | + Annotates the last conv (the linear lm_head) as 16a8w. |
| 133 | + """ |
| 134 | + |
| 135 | + def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: |
| 136 | + input_qspec_map = {} |
| 137 | + input_act = node.args[0] |
| 138 | + input_spec = quantization_config.input_activation |
| 139 | + input_qspec_map[input_act] = input_spec |
| 140 | + |
| 141 | + weight = node.args[1] |
| 142 | + input_qspec_map[weight] = quantization_config.weight |
| 143 | + |
| 144 | + if len(node.args) > 2 and isinstance(node.args[2], Node): |
| 145 | + input_qspec_map[node.args[2]] = quantization_config.bias(node) |
| 146 | + |
| 147 | + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( |
| 148 | + input_qspec_map=input_qspec_map, |
| 149 | + output_qspec=quantization_config.output_activation, |
| 150 | + _annotated=True, |
| 151 | + ) |
| 152 | + |
| 153 | + if is_qat: |
| 154 | + quantization_config_16a8w_per_channel = get_qat_per_channel_quant_config( |
| 155 | + torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver |
| 156 | + ) |
| 157 | + else: |
| 158 | + quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config( |
| 159 | + torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver |
| 160 | + ) |
| 161 | + for node in gm.graph.nodes: |
| 162 | + if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default: |
| 163 | + if "nn_module_stack" in node.meta: |
| 164 | + module_values_list = list(node.meta["nn_module_stack"].values()) |
| 165 | + full_qualified_name = module_values_list[-1][0] |
| 166 | + if full_qualified_name == "output.conv": |
| 167 | + annotate_conv2d( |
| 168 | + node, quantization_config=quantization_config_16a8w_per_channel |
| 169 | + ) |
| 170 | + |
| 171 | + |
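Like the other custom annotations in this file, `annotate_output_16a8w` takes the exported `torch.fx.GraphModule` as its first argument, so it can be composed into a custom-annotation tuple, with `functools.partial` pinning `is_qat`. A minimal sketch of that wiring follows; the import path and the `make_quantizer(custom_annotations=...)` hook mirror the Qualcomm example scripts and are assumptions here, not part of this diff.

```python
# Sketch only: the import path and quantizer wiring are assumed, not part of this diff.
from functools import partial

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_output_16a8w,  # assumed location of the helper added above
)

# Annotate the lm_head conv as 16a8w; pin is_qat=True instead for a QAT flow.
custom_annotations = (partial(annotate_output_16a8w, is_qat=False),)

# These callables are typically handed to the quantizer, e.g.
#   quantizer = make_quantizer(custom_annotations=custom_annotations, ...)
# so they run over the exported graph during annotation.
```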
| 172 | +@unique |
| 173 | +class StaticLLMQuantConfig(Enum): |
| 174 | + """ |
| 175 | + Layer namespace configuration for Qualcomm's static LLaMA quantization. |
| 176 | + """ |
| 177 | + |
| 178 | + wq_sha = "wq_sha" # Query weight (single head) |
| 179 | + wk_sha = "wk_sha" # Key weight (single head) |
| 180 | + wv_sha = "wv_sha" # Value weight (single head) |
| 181 | + |
| 182 | +def annotate_qkv_proj_sha( |
| 183 | + gm: torch.fx.GraphModule, |
| 184 | + quantization_config: QuantizationConfig, |
| 185 | + qkv_tags: set[StaticLLMQuantConfig], |
| 186 | +): |
| 187 | + """ |
| 188 | + Annotates QKV projection layers in a GraphModule for quantization, |
| 189 | + specifically layers defined in StaticLLMQuantConfig. |
| 190 | +
| 191 | + Args: |
| 192 | + qkv_tags (set[StaticLLMQuantConfig]): A set of enum tags indicating which QKV layers |
| 193 | + (e.g., wq, wk, wv) should be annotated for quantization. Only tags defined in |
| 194 | + StaticLLMQuantConfig are allowed. |
| 195 | +
| 196 | + Raises: |
| 197 | + ValueError: If any tag in `qkv_tags` is not among the allowed enum members. |
| 198 | + """ |
| 199 | + |
| 200 | + # Get all valid tags from the StaticLLMQuantConfig enum |
| 201 | + allowed_tags = set(StaticLLMQuantConfig) |
| 202 | + invalid_tags = qkv_tags - allowed_tags |
| 203 | + if invalid_tags: |
| 204 | + raise ValueError( |
| 205 | + f"Invalid qkv tags: {invalid_tags}. Allowed tags are: {allowed_tags}" |
| 206 | + ) |
| 207 | + |
| 208 | + for node in gm.graph.nodes: |
| 209 | + if node.target == torch.ops.aten.conv2d.default and any( |
| 210 | + tag.value in node.meta.get("stack_trace", "") for tag in qkv_tags |
| 211 | + ): |
| 212 | + input_qspec_map = {} |
| 213 | + input_qspec_map[node.args[0]] = quantization_config.input_activation |
| 214 | + input_qspec_map[node.args[1]] = quantization_config.weight |
| 215 | + if len(node.args) > 2 and isinstance(node.args[2], Node): |
| 216 | + input_qspec_map[node.args[2]] = quantization_config.bias(node) |
| 217 | + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( |
| 218 | + input_qspec_map=input_qspec_map, |
| 219 | + output_qspec=quantization_config.output_activation, |
| 220 | + _annotated=True, |
| 221 | + ) |
| 222 | + |
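`annotate_qkv_proj_sha` is driven by the `StaticLLMQuantConfig` tags: only projections whose tag value appears in a node's stack trace are annotated, and any tag outside the enum raises `ValueError`. A hedged sketch of building such a call is below; the `qconfig` import path is an assumption, while the config getter and its arguments mirror the ones already used in `annotate_output_16a8w` above.

```python
# Sketch only: import paths are assumptions; the config call mirrors the one above.
from functools import partial

import torch
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    StaticLLMQuantConfig,
    annotate_qkv_proj_sha,
)
from executorch.backends.qualcomm.quantizer.qconfig import (
    get_ptq_per_channel_quant_config,
)
from torch.ao.quantization.observer import MinMaxObserver

# Annotate only the query and key projections as 16a8w per-channel;
# wv_sha is left to the quantizer's default handling.
qkv_annotation = partial(
    annotate_qkv_proj_sha,
    quantization_config=get_ptq_per_channel_quant_config(
        torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
    ),
    qkv_tags={StaticLLMQuantConfig.wq_sha, StaticLLMQuantConfig.wk_sha},
)
# Anything outside StaticLLMQuantConfig (e.g. a raw string) is rejected with
# ValueError by the tag check above.
```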
128 | 223 |
129 | 224 | def annotate_kv_8bit( # noqa: C901 |
130 | 225 | gm: torch.fx.GraphModule, |