Commit 87e1f76
authored
[ONNX] Add Onnx->Torch lowering for GroupQueryAttention op (#4006)
This commit adds the lowering for Onnx's GroupQueryAttention op.
The lowering is adopted from here:
https://github.com/microsoft/onnxruntime/blob/65008cbb7393b121400a40dd8af4cc93d506918f/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L45
https://github.com/microsoft/onnxruntime/blob/65008cbb7393b121400a40dd8af4cc93d506918f/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h#L50
The reference implementation or pseudo-code can be viewed here:
https://gist.github.com/vivekkhandelwal1/f307b687fb133f36276f3d1a3c60ed7e.
The lowering supports the GQA with rotary_embedding.
---------
Signed-off-by: Vivek Khandelwal <[email protected]>1 parent 40b3469 commit 87e1f76
File tree
2 files changed
+513
-0
lines changed- lib/Conversion/TorchOnnxToTorch
- test/Conversion/TorchOnnxToTorch
2 files changed
+513
-0
lines changed
0 commit comments