[TORCH] Added flex_attention hop function #4366
Conversation
Change 1: Converts builtin tensors → Torch tensors when entering the loop body. Change 2: Ensures Torch tensors → builtin tensors when yielding back to the loop condition. Without these fixes, the conversion would fail when while loops carry tensor values. Also modified basic_test.py FILECHECK statements. Signed-off-by: Keshav Vinayak Jha <[email protected]>
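For context, a minimal hypothetical reproducer for the case this commit describes: a `while_loop` HOP whose carried values include tensors, which forces the importer to convert between builtin and Torch tensor types at the loop boundary. The module below is illustrative only (not taken from the PR or its tests) and assumes the `torch._higher_order_ops.while_loop` entry point.

```python
import torch
from torch._higher_order_ops.while_loop import while_loop


class DoubleUntilThree(torch.nn.Module):
    def forward(self, x):
        def cond_fn(i, acc):
            # Loop while the 0-d counter tensor is below 3.
            return i < 3

        def body_fn(i, acc):
            # Carried tensors must keep their shape/dtype across iterations.
            return i + 1, acc * 2.0

        i0 = torch.zeros((), dtype=torch.int64)
        _, out = while_loop(cond_fn, body_fn, (i0, x))
        return out
```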
1. Better documentation for AtenFlexAttentionOp. 2. Function references added as attributes to aten.flex_attention. 3. Updated _import_hop_flex_attention to reflect the latest changes to module import. 4. Removed discardable attributes; score_mod_fn and mask_mod_fn added as OptionalAttr. Signed-off-by: Keshav Vinayak Jha <[email protected]>
Remove note about method usage for HOPs.
Removed TODO note for grouped query attention support in the docstring and comments.
Force-pushed from 095cb61 to 5e024f6
zjgarvey left a comment
This does enable importing to mlir.
However, the changes don't actually provide "support" for this op, since the torch op can neither be decomposed nor lowered to any other dialects.
Although we could review/merge this and subsequently add a lowering path for the op in MLIR, I would personally prefer that the e2e support be added in the same PR as the import support.
This is a rather unique operator, so having passing e2e tests would give me a lot more confidence in the choices made here. Otherwise I'm basically just hoping that what you did generally makes sense (or doing a significant amount of work myself to check it out), because there really isn't much precedent for these kinds of choices in our codebase.
The only thing needed to have this pass e2e tests is implementing TilingInterface for this operation:
With that said, it's an unreasonable bar to set that every operation must compile e2e through torch-mlir. Torch-MLIR is not a compiler, even though it has tests for e2e paths. The project docs explicitly call this out: "Torch-MLIR is primarily a project that is integrated into compilers to bridge them to PyTorch and ONNX. If contemplating a new integration, it may be helpful to refer to existing downstreams."

It should be okay to land support for ops through the importer without them running e2e tests in torch-mlir. I've looked at the implementation of e2e tests for more complex ops like attention, and they are not good implementations; they don't add much value. As a project, we should allow landing PRs that add importer support separately from e2e tests (at least for HOPs). I don't think having a dummy implementation for an op should be the bar to land an operation.
@Groverkss So this torch op lowers to a tm tensor op? Because I don't see where that is happening. My blocking is primarily predicated on the fact that this op is imported to something completely unhandled. Even then, I'm happy to unblock and review as is, but it warranted discussion at least. If you would like to add a review yourself, your context on attention ops would be very helpful. It's simply my preference that we have an e2e test, and I'm not going to block based on that alone.
I think there is a PR running around in IREE that lowers this op to IREE's attention op (found it: iree-org/iree#22441). I don't think TMTensor is really a requirement anymore, since you can directly lower a torch op in your own project. I think TMTensor is more of a thing of the past, when we really wanted torch-mlir to lower everything for us and we didn't hook patterns into it. For historical context on how TMTensor was used and how it was replaced in IREE (and generally how it should be used now): iree-org/iree#14917
I refrained from adding a review on this because I was guiding @keshavvinayak01 through the implementation and didn't want to land this without getting an extra pair of eyes 😅 I think your review on this is invaluable and I'll still let you decide if we should land this as is or not.
My main worry is that we are tying the fx_importer to the e2e tests. I personally believe that the e2e lowering test suite and the fx_importer are separate pieces of utility, and one should be able to use one without the other. I do think the e2e tests are useful though, so I'll recommend @keshavvinayak01 send a patch implementing TilingInterface for this operation, just like we have for the TMTensor op. But that should be separate from this patch.
Ah, these are both useful context. Thanks. Yeah, if we don't care about having some implementation here, I'm totally fine with that.
That makes sense. I'll review now.
Yeah, that sounds good. I just wasn't aware that it was common practice to in-house certain torch lowerings in downstream projects like IREE.
zjgarvey left a comment
I think two things would be nice before merging, but since none of the changes here really affect anything else in torch-mlir, I'm not going to block anything.
- An importer test would be incredibly valuable in my opinion. I'm half-inclined to write one myself just to debug-print the fx graph and mlir so I can review this PR a bit better (see the sketch after this list).
- Some explanation of what the `enable_gqa` arg is doing/not doing. As you can see from my comments, I'm a bit confused by this arg, since it doesn't seem to do anything in pytorch or in the torch-mlir op (where it is hardcoded to `False`).
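For reference, here is a rough sketch of the kind of importer test being asked for, in the style of the existing fx_importer tests: it exports a module that calls `flex_attention` with a non-default `score_mod` and runs it through torch-mlir's `fx.export_and_import`. The module, shapes, and `score_mod` are made up for illustration, and it just prints the imported module rather than asserting FileCheck patterns.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

from torch_mlir import fx


class FlexAttentionScoreMod(torch.nn.Module):
    def forward(self, q, k, v):
        def score_mod(score, b, h, q_idx, kv_idx):
            # Non-default score modification: a relative-position bias.
            return score + (q_idx - kv_idx)

        return flex_attention(q, k, v, score_mod=score_mod)


q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
# Exporting traces flex_attention as a higher-order op; the importer should
# turn it into the new torch op with a score_mod function reference.
m = fx.export_and_import(FlexAttentionScoreMod(), q, k, v)
print(m)
```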
…ment Signed-off-by: Keshav Vinayak Jha <[email protected]>
zjgarvey left a comment
I think removing the unused arg makes sense, thanks for doing that.
Based on the comments, this PR definitely needs to have at least one importer test, but I would highly recommend adding tests for both default and non-default mod functions.
…working lit test Signed-off-by: Keshav Vinayak Jha <[email protected]>
This seems fine to me, but I would consider
Description
- `Torch_AtenFlexAttentionOp` with 6 operands (query, key, value, scale, return_lse, return_max_score) and 2 optional attributes (score_mod_fn, mask_mod_fn) for function references.
- The importer hook (`_import_hop_flex_attention`) correctly extracts score/mask modification functions from `get_attr` nodes using module IDs, following the while_loop HOP pattern.
- `kernel_options` performance tuning parameters.
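To make the second bullet concrete, one hypothetical way (not from the PR) to see the FX graph that `_import_hop_flex_attention` consumes is to export a `flex_attention` call and print the graph; the score/mask modification functions appear as `get_attr` references to generated submodules.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention


class M(torch.nn.Module):
    def forward(self, q, k, v):
        def score_mod(score, b, h, q_idx, kv_idx):
            return score + (q_idx - kv_idx)

        return flex_attention(q, k, v, score_mod=score_mod)


q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
ep = torch.export.export(M(), (q, k, v))
# The printed graph contains a flex_attention higher-order-op call whose
# score_mod argument is a get_attr node referencing a generated submodule;
# that is what the importer resolves into a function reference attribute.
print(ep.graph_module.graph)
```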