@keshavvinayak01 (Contributor) commented on Nov 4, 2025

Description

  • Added support for PyTorch's flex_attention Higher-Order Operator in torch-mlir.
  • Implemented Torch_AtenFlexAttentionOp with 6 operands (query, key, value, scale, return_lse, return_max_score) and 2 optional attributes (score_mod_fn, mask_mod_fn) for function references.
  • The FX importer (_import_hop_flex_attention) correctly extracts score/mask modification functions from get_attr nodes using module IDs, following the while_loop HOP pattern.
  • Includes TODO markers for kernel_options performance tuning parameters.

Internally, torch.nn.attention.flex_attention passes the return_lse and return_max_score options to flex_attention_hop through the kernel_options dict (under the OUTPUT_LOGSUMEXP and OUTPUT_MAX key names, respectively). I've implemented support for those, but other fine-grained controls (including blocking heuristics) are not supported yet.

  • Imports flex_attention from PyTorch FX graphs into valid MLIR (a minimal import sketch follows below).
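
For concreteness, here is a minimal usage sketch of the kind of program this importer change targets. It is not taken from this PR: the fx.export_and_import entry point, the rel_bias helper, and the expected torch.aten.flex_attention op in the output are assumptions based on torch-mlir's Python API and the op name added here.

```python
# Hypothetical usage sketch (not part of this PR). Assumes a recent PyTorch with
# torch.nn.attention.flex_attention and torch-mlir's fx.export_and_import helper.
import torch
from torch.nn.attention.flex_attention import flex_attention

from torch_mlir import fx


def rel_bias(score, b, h, q_idx, kv_idx):
    # Example score_mod: add a relative-position bias to every attention score.
    return score + (q_idx - kv_idx)


class FlexAttn(torch.nn.Module):
    def forward(self, q, k, v):
        return flex_attention(q, k, v, score_mod=rel_bias)


q = k = v = torch.randn(2, 4, 128, 64)
m = fx.export_and_import(FlexAttn(), q, k, v)
# Expected (assumption): the printed module contains a torch.aten.flex_attention op
# whose score_mod_fn attribute references the imported rel_bias function.
print(m)
```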

keshavvinayak01 and others added 17 commits October 22, 2025 09:41
Signed-off-by: Keshav Vinayak Jha <[email protected]>
Change 1: Converts builtin tensors → Torch tensors when entering the loop body
Change 2: Ensures Torch tensors → builtin tensors when yielding back to the loop condition
Without these fixes, the conversion would fail when while loops carry tensor values

Also modified basic_test.py FILECHECK statements.

Signed-off-by: Keshav Vinayak Jha <[email protected]>
1. Better documentation for AtenFlexAttentionOp
2. Function references added as attributes to aten.flex_attention
3. Updates to _import_hop_flex_attention reflecting the latest changes to module import.
4. Removed discardable attributes; score_mod_fn and mask_mod_fn added as OptionalAttr

Signed-off-by: Keshav Vinayak Jha <[email protected]>
Remove note about method usage for HOPs.
keshavvinayak01 changed the title from "Keshavvinayak01/torch aten flex attention" to "[TORCH] Added flex_attention hop function" on Nov 4, 2025
Removed TODO note for grouped query attention support in the docstring and comments.
keshavvinayak01 force-pushed the keshavvinayak01/torch-aten-flex_attention branch from 095cb61 to 5e024f6 on November 6, 2025 09:36
keshavvinayak01 marked this pull request as ready for review on November 6, 2025 09:37
Signed-off-by: Keshav Vinayak Jha <[email protected]>
@zjgarvey (Collaborator) left a comment

This does enable importing to MLIR.

However, the changes don't actually provide "support" for this op, since the torch op can neither be decomposed nor lowered to any other dialect.

Although we could review/merge this and subsequently add a lowering path for the op in MLIR, I would personally prefer that the e2e support be added in the same PR as the import support.

This is a rather unique operator, so having passing e2e tests would give me a lot more confidence in the choices made here. Otherwise I'm basically just hoping that what you did generally makes sense (or doing a significant amount of work myself to check it out), because there really isn't much precedent for these kinds of choices in our codebase.

@Groverkss (Member) commented on Nov 11, 2025

> This does enable importing to MLIR.
>
> However, the changes don't actually provide "support" for this op, since the torch op can neither be decomposed nor lowered to any other dialect.
>
> Although we could review/merge this and subsequently add a lowering path for the op in MLIR, I would personally prefer that the e2e support be added in the same PR as the import support.
>
> This is a rather unique operator, so having passing e2e tests would give me a lot more confidence in the choices made here. Otherwise I'm basically just hoping that what you did generally makes sense (or doing a significant amount of work myself to check it out), because there really isn't much precedent for these kinds of choices in our codebase.

The only thing needed to have this pass e2e tests is implementing TilingInterface for this operation:

LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,

With that said, it's an unreasonable bar to set that every operation must compile e2e through torch-mlir. Torch-MLIR is not a compiler, even though it has tests for e2e paths. The project docs explicitly call this out:

> Torch-MLIR is primarily a project that is integrated into compilers to bridge them to PyTorch and ONNX. If contemplating a new integration, it may be helpful to refer to existing downstreams:
>
>   • IREE
>   • Blade
>
> While most of the project is exercised via testing paths, there are some ways that an end user can directly use the APIs without further integration:

It should be okay to land support for ops through the importer without them being covered by e2e tests in torch-mlir. I've looked at the implementation of e2e tests for more complex ops like attention, and they are not good implementations; they don't add much value.

We should, as a project, allow landing PRs that add support to the importer separately from e2e tests (at least for HOPs). I don't think having a dummy implementation for an op should be the bar to land an operation.

@zjgarvey (Collaborator) commented

@Groverkss So this torch op lowers to a tm tensor op? Because I don't see where that is happening.

My blocking is primarily predicated on the fact that this op is imported to something completely unhandled. Even then, I'm happy to unblock and review as is, but it warranted discussion at least. If you would like to add a review yourself, your context on attention ops would be very helpful.

It's simply my preference that we have an e2e test, and I'm not going to block based on that alone.

@Groverkss (Member) commented

> @Groverkss So this torch op lowers to a tm tensor op? Because I don't see where that is happening.

I think there is a PR running around in IREE that lowers this op to IREE's attention op (found it: iree-org/iree#22441).

I don't think TMTensor is really a requirement anymore, since you can directly lower a torch op in your own project. I think TMTensor is more of a thing of the past, when we really wanted torch-mlir to lower everything for us and we didn't hook patterns into it. For historical context on how TMTensor was used and how it was replaced in IREE (and generally how it should be used now): iree-org/iree#14917

> My blocking is primarily predicated on the fact that this op is imported to something completely unhandled. Even then, I'm happy to unblock and review as is, but it warranted discussion at least. If you would like to add a review yourself, your context on attention ops would be very helpful.

I refrained from adding a review on this because I was guiding @keshavvinayak01 through the implementation and didn't want to land this without getting an extra pair of eyes 😅 I think your review on this is invaluable and I'll still let you decide if we should land this as is or not.

> It's simply my preference that we have an e2e test, and I'm not going to block based on that alone.

My main worry is that we are tying the fx_importer to the e2e tests. I personally believe that the e2e lowering test suite and the fx_importer are separate pieces of utility, and one should be able to use one without the other. I do think the e2e tests are useful though, so I'll recommend that @keshavvinayak01 send a patch implementing TilingInterface for this operation, just like we have for the TMTensor op. But that should be separate from this patch.

@zjgarvey (Collaborator) commented

> I think there is a PR running around in IREE that lowers this op to IREE's attention op (found it: iree-org/iree#22441).
>
> I don't think TMTensor is really a requirement anymore, since you can directly lower a torch op in your own project. I think TMTensor is more of a thing of the past, when we really wanted torch-mlir to lower everything for us and we didn't hook patterns into it. For historical context on how TMTensor was used and how it was replaced in IREE (and generally how it should be used now): iree-org/iree#14917

Ah, these are both useful context. Thanks. Yeah, if we don't care about having some implementation here, I'm totally fine with that.

> I refrained from adding a review on this because I was guiding @keshavvinayak01 through the implementation and didn't want to land this without getting an extra pair of eyes 😅 I think your review on this is invaluable and I'll still let you decide if we should land this as is or not.

That makes sense. I'll review now.

> My main worry is that we are tying the fx_importer to the e2e tests. I personally believe that the e2e lowering test suite and the fx_importer are separate pieces of utility, and one should be able to use one without the other. I do think the e2e tests are useful though, so I'll recommend that @keshavvinayak01 send a patch implementing TilingInterface for this operation, just like we have for the TMTensor op. But that should be separate from this patch.

Yeah, that sounds good. I just wasn't aware that it was common practice to in-house certain torch lowerings in downstream projects like IREE.

@zjgarvey (Collaborator) left a comment

I think two things would be nice before merging, but since none of the changes here really affect anything else in torch-mlir, I'm not going to block anything.

  1. An importer test would be incredibly valuable in my opinion. I'm half-inclined to write one myself just to debug-print the fx graph and mlir so I can review this PR a bit better.

  2. Some explanation of what the enable_gqa arg is doing/not doing. As you can see from my comments, I'm a bit confused by this arg, since it doesn't seem to do anything in pytorch or in the torch-mlir op (where it is hardcoded to False).

Signed-off-by: Keshav Vinayak Jha <[email protected]>
@zjgarvey (Collaborator) left a comment

I think removing the unused arg makes sense, thanks for doing that.

Based on the comments, this PR definitely needs to have at least one importer test, but I would highly recommend adding tests for both default and non-default mod functions.
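
For reference, a rough sketch of what an importer test for the non-default score_mod case could look like, loosely following the conventions of the existing fx_importer tests; the CHECK lines and the causal_score_mod helper are placeholders, not copied from this PR:

```python
# Hypothetical FileCheck-style importer test for a non-default score_mod.
# The CHECK lines are guesses at what the imported op prints as.
import torch
from torch.nn.attention.flex_attention import flex_attention

from torch_mlir import fx


def causal_score_mod(score, b, h, q_idx, kv_idx):
    # Non-default score_mod: mask out future positions with -inf.
    return torch.where(q_idx >= kv_idx, score, torch.full_like(score, float("-inf")))


class FlexAttentionScoreMod(torch.nn.Module):
    def forward(self, q, k, v):
        return flex_attention(q, k, v, score_mod=causal_score_mod)


# CHECK-LABEL: TEST: test_flex_attention_score_mod
# CHECK: torch.aten.flex_attention
def test_flex_attention_score_mod():
    print("TEST: test_flex_attention_score_mod")
    q = k = v = torch.randn(2, 4, 128, 64)
    m = fx.export_and_import(FlexAttentionScoreMod(), q, k, v)
    print(m)


test_flex_attention_score_mod()
```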

Signed-off-by: Keshav Vinayak Jha <[email protected]>
@zjgarvey (Collaborator) commented

This seems fine to me, but I would consider

  1. renaming the op if it isn't in the "aten" operator namespace for pytorch (this can be misleading otherwise).

  2. Adding at least one more import test to cover some other cases. This op has many options so it would be great to have some more robust testing.
