Commit 0d61efc

Qualcomm AI Engine Direct - Add MHA2SHA pass (pytorch#15438)
### Background

We observed that quantizing and compiling the original SHA (single-head attention) model takes a significant amount of time, and switching to the MHA (multi-head attention) model speeds this up. We therefore investigated whether converting the MHA model to SHA after quantization is feasible. This conversion cannot be performed during the `to_edge` transformation, because splitting the convolution weights into SHA would require modifying the `state_dict`, which is not permitted at that stage. We therefore apply this pass during `qnn_preprocess`.

### Summary

- Integrated the MHA-to-SHA pass and implemented it in `qnn_preprocess`
- Refactored MHA in static llama
- Included SpinQuant R3 support and masked softmax for the MHA model in static llama
- Combined the `n_heads` key-value caches into a single cache per layer to decrease the number of inputs and outputs, which improves performance
- Deprecated the ShiftPointer kv-updater mode: since each layer now has its own kv cache, the v cache no longer benefits from ShiftPointer, which previously avoided copying the new v cache into the input v cache. To prevent user confusion, ShiftPointer mode has been deprecated.
- Applied the correct input template for smollm2 135m
- Corrected the quantization annotation for reshape
- Removed outdated code from `CanonicalizeConv`

### Results

Following the [README settings](https://github.com/pytorch/executorch/blob/main/examples/qualcomm/oss_scripts/llama/README.md) and testing on SM8750 with QNN 2.37, we compared the new `convert_mha_to_sha` pass with the original SHA structure:

(Results image: https://github.com/user-attachments/assets/2b1c2b66-77c0-4662-a035-900ad9091d67)
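To make the core idea concrete, here is a minimal sketch of what splitting a fused multi-head attention projection into per-head convolutions looks like. This is illustrative only, not the pass's actual implementation; all names and shapes below are assumptions.

```python
# Minimal sketch of the MHA -> SHA idea (illustrative only; not the real pass).
# A fused multi-head projection implemented as a 1x1 Conv2d is split into one
# small Conv2d per head by slicing the weight along the output-channel axis.
import torch
import torch.nn as nn

n_heads, head_dim, embed_dim = 8, 64, 512  # hypothetical sizes

# Fused MHA-style projection: all heads share one conv.
mha_proj = nn.Conv2d(embed_dim, n_heads * head_dim, kernel_size=1, bias=False)

# Split into per-head (SHA-style) convs by slicing the weight. Rewriting
# parameter tensors like this amounts to a state_dict edit, which is why the
# conversion cannot run during to_edge.
sha_projs = []
for h in range(n_heads):
    head = nn.Conv2d(embed_dim, head_dim, kernel_size=1, bias=False)
    with torch.no_grad():
        head.weight.copy_(mha_proj.weight[h * head_dim : (h + 1) * head_dim])
    sha_projs.append(head)

# Sanity check: concatenated per-head outputs match the fused projection.
x = torch.randn(1, embed_dim, 1, 16)
fused = mha_proj(x)
split = torch.cat([p(x) for p in sha_projs], dim=1)
assert torch.allclose(fused, split, atol=1e-6)
```

Because the slicing rewrites parameter tensors, it is only legal at a stage where `state_dict` edits are permitted, which motivates running the pass in `qnn_preprocess`.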

39 files changed: +1222 −1004 lines

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
 from .canonicalize_conv import CanonicalizeConv
 from .convert_bmm_to_matmul import ConvertBmmToMatmul
 from .convert_linear_to_conv2d import ConvertLinearToConv2d
+from .convert_mha_to_sha import ConvertMhaToSha
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
 from .decompose_binary_alpha import DecomposeBinaryAlpha
@@ -56,6 +57,7 @@
    CanonicalizeConv,
    ConvertBmmToMatmul,
    ConvertLinearToConv2d,
+   ConvertMhaToSha,
    ConvertSquareToPow,
    DecomposeAny,
    DecomposeBinaryAlpha,
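For readers unfamiliar with this registry: each entry is an `ExportPass` subclass whose `call` method receives a graph module and returns a `PassResult`, the convention visible in the diffs below. A toy sketch of that shape follows; it is illustrative only and unrelated to `ConvertMhaToSha`'s real logic.

```python
# Hedged sketch of the ExportPass shape that entries in this registry share.
# This toy pass is illustrative, not ConvertMhaToSha's real implementation.
import torch
from executorch.exir.pass_base import ExportPass, PassResult


class CountMatmuls(ExportPass):
    """Toy pass: counts matmul nodes and leaves the graph unchanged."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        n = sum(
            1
            for node in graph_module.graph.nodes
            if node.target == torch.ops.aten.matmul.default
        )
        print(f"found {n} matmul node(s)")
        # Second argument flags whether the graph was modified.
        return PassResult(graph_module, False)
```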

backends/qualcomm/_passes/canonicalize_conv.py

Lines changed: 0 additions & 9 deletions
@@ -9,7 +9,6 @@
 import torch
 
 from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
-from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._guards import detect_fake_mode
 
@@ -197,14 +196,6 @@ def call(self, graph_module: torch.fx.GraphModule):
            )
            squeeze_node.meta = copy_meta(node.meta)
 
-           if QCOM_REQUANTIZE in input_node.meta:
-               input_node.meta.pop(QCOM_REQUANTIZE)
-           if QCOM_REQUANTIZE in node.meta:
-               squeeze_node.meta[QCOM_REQUANTIZE] = node.meta[
-                   QCOM_REQUANTIZE
-               ]
-               conv2d_node.meta.pop(QCOM_REQUANTIZE, None)
-
            for user in node.users.copy():
                user.replace_input_with(node, squeeze_node)

backends/qualcomm/_passes/convert_bmm_to_matmul.py

Lines changed: 7 additions & 1 deletion
@@ -47,7 +47,13 @@ def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        partitions = get_source_partitions(
            graph,
-           [operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default],
+           [
+               "matmul",
+               operator.matmul,
+               torch.matmul,
+               torch.bmm,
+               torch.ops.aten.matmul.default,
+           ],
        )
        for _, src_partitions in partitions.items():
            for src_partition in src_partitions:
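For context, `get_source_partitions` (from `torch.fx.passes.utils.source_matcher_utils`) groups graph nodes by the source op recorded in their metadata during export; adding the plain string `"matmul"` presumably covers exports that record the op by name rather than by callable. A hedged, self-contained sketch of the lookup follows; the toy module is an assumption, and only the wanted-sources list mirrors the diff.

```python
# Hedged sketch: retrieving matmul partitions with the expanded match list.
import operator

import torch
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.matmul(x, y)


gm = torch.export.export(
    Toy(), (torch.randn(2, 3, 4), torch.randn(2, 4, 5))
).module()

# Nodes are grouped by the op they were traced from, so both callables and
# the "matmul" string can serve as keys.
partitions = get_source_partitions(
    gm.graph,
    ["matmul", operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default],
)
for source, parts in partitions.items():
    print(source, [len(p.nodes) for p in parts])
```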
