
Conversation

@yao-matrix
Contributor

We expect all model cases to run on XPU, except those that are CUDAGraph-specific, CUDA compute-capability-specific, or FA3-specific. For FA3, support is still under development.
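Concretely, the CUDA-only cases are guarded so that everything else can run on XPU. A hedged sketch of that kind of guard (the decorator names here are illustrative, not necessarily the ones used in this PR):

```python
# Hedged sketch of CUDA-only test guards; names are illustrative.
import unittest

import torch


def require_cuda_graphs(test_case):
    """CUDA graph capture is a CUDA-only feature, so skip on XPU and other backends."""
    return unittest.skipUnless(torch.cuda.is_available(), "test requires CUDA graphs")(test_case)


def require_cuda_compute_capability(major: int):
    """Skip unless the CUDA device reports at least the given major compute capability."""
    def decorator(test_case):
        ok = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= major
        return unittest.skipUnless(ok, f"test requires compute capability >= {major}")(test_case)
    return decorator
```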

@ydshieh, please help review, thanks very much.

…UDAGraph

specific, CUDA compute capability specific and FA3 specific can run on XPU.
For FA3, we are developing.

Signed-off-by: Yao, Matrix <[email protected]>

MIMI_ATTENTION_CLASSES = {
    "eager": MimiAttention,
    "kernels-community/flash-attn2": MimiFlashAttention2,
Collaborator

Could you explain this part?

Contributor Author

@ydshieh, sure. In the latest design, when users set attn_implementation == "flash_attention_2", there are two branches:

  1. if the flash_attn package is available, it is used directly;
  2. otherwise, instead of failing as before, kernels is used as a fallback, and attn_implementation is updated to "kernels-community/flash-attn2", as in the code here (see the sketch below).

For XPU, we go through the kernels path in transformers for FA support, so we need this key.
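A minimal sketch of that fallback; the function name and the import check are illustrative, not the actual transformers internals:

```python
# Hedged sketch, not the real transformers code: resolve the requested
# attention implementation, falling back to the kernels backend when the
# flash_attn package is missing instead of raising.
import importlib.util


def resolve_attn_implementation(requested: str) -> str:
    if requested == "flash_attention_2":
        if importlib.util.find_spec("flash_attn") is not None:
            return "flash_attention_2"  # branch 1: use flash_attn directly
        # branch 2: fall back to the community kernel; this is why
        # MIMI_ATTENTION_CLASSES needs the "kernels-community/flash-attn2" key
        return "kernels-community/flash-attn2"
    return requested


# On XPU without flash_attn installed:
# resolve_attn_implementation("flash_attention_2")
# -> "kernels-community/flash-attn2"
```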

Thanks very much.

Contributor

I'd rather not do this, even though you are correct here. We should instead refactor mimi to use the attention interface rather than keep these manual registrations. Otherwise we could keep extending these edge cases indefinitely (FA3, etc.), which does not scale without using/refactoring to the interface.
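For reference, a rough sketch of the interface-based dispatch being suggested, following the pattern used by already-refactored transformers models (the exact shape for mimi is an assumption):

```python
# Hedged fragment (lives inside an attention module's forward), mirroring the
# ALL_ATTENTION_FUNCTIONS pattern from refactored transformers models:
attention_interface = eager_attention_forward  # default "eager" path
if self.config._attn_implementation != "eager":
    # sdpa, flash_attention_2, kernels-community/... backends are all looked
    # up in one registry, so no per-model *_ATTENTION_CLASSES dict is needed
    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    dropout=0.0 if not self.training else self.attention_dropout,
    scaling=self.scaling,
)
```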

Collaborator

@yao-matrix

Let's revert this line 🙏. We can skip the relevant FA tests if necessary.

("xpu", None): {
"req_1": " 3.5 bolts.\n\nLet's break it down step by step:\n\n- Blue fiber: 2 bolts\n- White fiber: half of 2 bolts = 1 bolt\n\nTotal = ",
},
}).get_expectation() # fmt: skip
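For context, a hedged sketch of how a device-keyed table like this could resolve; the real `Expectations` helper lives in transformers.testing_utils, and the lookup below is only an illustrative approximation in which `None` acts as a version wildcard:

```python
# Illustrative approximation of device-keyed expectation lookup; not the
# actual transformers.testing_utils.Expectations implementation.
import torch


def get_expectation(expectations: dict):
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        device, major = "xpu", None
    else:
        device = "cuda"
        major = torch.cuda.get_device_capability()[0] if torch.cuda.is_available() else None
    # prefer an exact (device, version) match, fall back to the (device, None) wildcard
    return expectations.get((device, major), expectations.get((device, None)))
```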
Collaborator

I need to check why this was {} before, but thank you.

Collaborator

@ydshieh left a comment

Thank you! LGTM, but I have one question.

@ydshieh
Collaborator

ydshieh commented Dec 2, 2025

@yao-matrix don't forget

#42536 (comment)

🙏

@yao-matrix
Contributor Author

> @yao-matrix don't forget
>
> #42536 (comment)
>
> 🙏

Yes, done, thanks very much for your constant support :).

@github-actions
Contributor

github-actions bot commented Dec 2, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma2, gemma3, glm4v, glm4v_moe, granitemoehybrid, idefics2, kosmos2_5, longcat_flash, mimi, modernbert, musicgen, musicgen_melody, pixtral, qwen2_5_omni, qwen2_5_vl, qwen2_moe
