
Conversation

@YangKai0616 (Contributor)

What does this PR do?

1. [General] When running the test case tests/generation/test_continuous_batching.py::ContinuousBatchingTest::test_continuous_batching_parity_qwen_flash, it reports the following error:

                flash_attn_func = getattr(kernel, "flash_attn_func", None)
                flash_attn_varlen_func = getattr(kernel, "flash_attn_varlen_func", None)
                if flash_attn_varlen_func is None:
>                   raise ValueError(
                        f"Could not find the currently requested flash attention implementation at `{implementation}`."
                        "Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn2`."
                    )
E                   ValueError: Could not find the currently requested flash attention implementation at `flash_attention_2`.Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn2`.

src/transformers/modeling_flash_attention_utils.py:115: ValueError

The root cause is a failure to load the paged|flash_attention_2 kernel. This PR fixes the loading logic.

2. [XPU] The above test case still reports E AssertionError: Test request_id = 'req_1' failed, no expected output was provided..... after the run completes on both XPU and CUDA. On the XPU side, however, the test passes stably when require_deterministic_for_xpu is used (a minimal usage sketch is shown below).

To enable the related tests on XPU, please refer to PR #42536.
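
For illustration, a minimal sketch of how such a decorator is typically applied in the test suite; the import path and the decorator's exact behaviour on non-XPU devices are assumptions, not taken from this PR's diff:

    # Illustrative only: the import location and the decorator's semantics on
    # non-XPU devices are assumed, not part of this PR.
    import unittest

    from transformers.testing_utils import require_deterministic_for_xpu  # assumed location


    class ContinuousBatchingTest(unittest.TestCase):
        @require_deterministic_for_xpu
        def test_continuous_batching_parity_qwen_flash(self):
            # On XPU the decorator pins the run to deterministic kernels so the
            # generated outputs match the expected ones; elsewhere it should be a no-op.
            ...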

@YangKai0616 (Contributor, Author)

@vasqu, please help review, thanks!

@YangKai0616 changed the title from "Fixed UT and kernel loading logic." to "Fixed paged|FA2 kernel loading logic and UT." on Dec 2, 2025
@vasqu (Contributor) left a comment

I would like to extend this a bit to allow general CB support for kernels without loading the kernel twice.

The idea is to modify the fallback to include the paged| prefix directly and to load the kernel properly via lazy_import_paged_flash_attention. This should also allow us to use the prefixed version directly, e.g. attn_implementation="paged|kernels-community/flash-attn2".
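
A minimal sketch of that fallback, assuming a hypothetical resolver; only the paged| prefix and lazy_import_paged_flash_attention come from the comment above, the rest is illustrative:

    # Sketch of the fallback described above -- not the PR's actual diff. The resolver
    # name and signature are hypothetical; `load_paged_kernel` stands in for the
    # `lazy_import_paged_flash_attention` helper mentioned above.
    def resolve_cb_attention(attn_implementation: str, load_paged_kernel):
        # Normalize so the `paged|` prefix appears exactly once; this avoids loading
        # the same kernel twice under two different names.
        if not attn_implementation.startswith("paged|"):
            attn_implementation = f"paged|{attn_implementation}"

        # Strip the prefix for the actual kernel lookup, e.g.
        # "paged|kernels-community/flash-attn2" -> "kernels-community/flash-attn2".
        kernel_ref = attn_implementation.removeprefix("paged|")

        return load_paged_kernel(kernel_ref)

With that normalization, the plain name used inside continuous batching and an explicitly prefixed hub kernel such as attn_implementation="paged|kernels-community/flash-attn2" would resolve through the same single load path.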

@vasqu (Contributor) left a comment

Let's simplify a bit; on second thought, introducing another variable would be even more confusing. We can avoid that.

@YangKai0616 (Contributor, Author)

Strange, the failing CI example tests/models/mvp/test_modeling_mvp.py::MvpHeadTests::test_generate_beam_search passes on both CUDA (A100) and XPU for me (with torch 2.9).

@vasqu (Contributor) commented on Dec 3, 2025

Yup, not sure what's going on. It's an annoying flaky test. I added a new commit to simplify something; let's see if CI passes this time, otherwise I'll try to get a core maintainer to merge if necessary.

@vasqu (Contributor) left a comment

Forgot to approve, but LGTM now.
