Implement SDPA via MHA #2683
base: main
Conversation
Signed-off-by: Ganesan Ramalingam <[email protected]>
Codecov Report

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #2683      +/-   ##
==========================================
+ Coverage   70.11%   70.12%   +0.01%
==========================================
  Files         224      225       +1
  Lines       26982    27063      +81
  Branches     2705     2719      +14
==========================================
+ Hits        18919    18979      +60
- Misses       7129     7146      +17
- Partials      934      938       +4
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Ganesan Ramalingam <[email protected]>
```diff
 class SDPAImplementation(pattern.RewriteRuleClassBase):
-    def pattern(self, op, query, key, value):
+    def pattern(self, op, query, key, value, key_format):
```
Create a docstring for these params?
[Ok, I picked that from ChatGPT :-)] I suspect what you are asking for is a "spec" for SDPA: the internal op used by fusion, which does not correspond to any ONNX standard op or ORT contrib op, though it is a close approximation of ONNX's Attention. Is that right? Since that op is used across multiple fusions and files, sdpa.py might be the right place for it. A one-liner docstring for these params would only say that they correspond to the inputs and attribute of the SDPA op, which tells the reader no more than the next few lines already show.
Ok, added documentation as discussed above.
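For context, here is a minimal sketch of what such parameter documentation could look like, assuming onnxscript's rewriter pattern API and layout conventions close to ONNX's Attention op; the wording and the description of key_format are illustrative, not the PR's actual docstring:

```python
from onnxscript.rewriter import pattern


class SDPAImplementation(pattern.RewriteRuleClassBase):
    def pattern(self, op, query, key, value, key_format):
        """Match the fusion-internal SDPA op.

        Args:
            query: 4D query tensor, e.g. (batch, num_heads, seq_len, head_dim).
            key: Key tensor; its exact layout is described by key_format.
            value: 4D value tensor, laid out like query.
            key_format: Attribute naming the key layout, e.g. whether the
                key is already transposed for the Q @ K^T matmul.
        """
        ...
```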
Signed-off-by: Ganesan Ramalingam <[email protected]>
Implement SDPA via MHA. This handles the case where earlier fusion rules do not map larger patterns containing SDPA into MHA, GQA, or Attention (from ORT contrib ops); as a fallback, the remaining SDPA op is implemented directly via MHA.
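In rough outline, such a fallback rule matches the internal SDPA op and rebuilds it as a com.microsoft MultiHeadAttention node, adjusting the input and output layouts. The sketch below is illustrative only: the class name, the internal domain string, and the fixed NUM_HEADS are placeholders, and the PR's actual rule also handles details such as key_format that this sketch omits.

```python
from onnxscript.rewriter import pattern

NUM_HEADS = 8  # placeholder; a real rule derives this from the matched shapes


class SDPAToMHA(pattern.RewriteRuleClassBase):
    """Rewrite a leftover internal SDPA node into com.microsoft MultiHeadAttention."""

    def pattern(self, op, query, key, value):
        # Match the fusion-internal SDPA op (the domain name is a placeholder).
        return op.SDPA(query, key, value, _domain="ai.onnxruntime._fusion")

    def rewrite(self, op, query, key, value):
        # SDPA works on 4D (batch, num_heads, seq, head_dim) tensors, while
        # MHA takes 3D (batch, seq, num_heads * head_dim) inputs: transpose
        # each input to (batch, seq, num_heads, head_dim), then flatten.
        def to_3d(x):
            x = op.Transpose(x, perm=[0, 2, 1, 3])
            return op.Reshape(x, op.Constant(value_ints=[0, 0, -1]))

        mha = op.MultiHeadAttention(
            to_3d(query),
            to_3d(key),
            to_3d(value),
            num_heads=NUM_HEADS,
            _domain="com.microsoft",
        )
        # Restore SDPA's 4D (batch, num_heads, seq, head_dim) output layout.
        mha_4d = op.Reshape(mha, op.Constant(value_ints=[0, 0, NUM_HEADS, -1]))
        return op.Transpose(mha_4d, perm=[0, 2, 1, 3])
```

The layout round-trip (transpose and reshape on the way in, the inverse on the way out) keeps the rewrite local: downstream consumers of the SDPA output still see the same 4D tensor as before the fusion.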