Fixed paged|FA2 kernel loading logic and UT. #42547
Conversation
@vasqu, please help review, thanks!
vasqu left a comment
I would like to extend this a bit to allow general continuous batching (CB) support for kernels without loading the kernel twice.
The idea is to modify the fallback to include the paged| prefix directly and load the kernel properly via lazy_import_paged_flash_attention. This should allow us to use the prefix version directly as well, e.g. `attn_implementation="paged|kernels-community/flash-attn2"`.
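For illustration, a minimal sketch of what that prefix form would look like from the user side, assuming the fallback change described above lands; the checkpoint id is a hypothetical placeholder, only the `attn_implementation` string comes from this thread:

```python
# Minimal usage sketch: passing the "paged|"-prefixed kernel repo directly,
# assuming the fallback resolves the prefix and loads the kernel only once
# via lazy_import_paged_flash_attention.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",  # hypothetical checkpoint, for illustration only
    attn_implementation="paged|kernels-community/flash-attn2",
)
```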
vasqu left a comment
Let's simplify a bit; on second thought, introducing another variable will be even more confusing. We can avoid that.
Strange, that failing CI example.
Yup, not sure what's going on. It's an annoying flaky test. I added a new commit to simplify something; let's see if CI passes this time, otherwise I'll try to get a core maintainer to merge if necessary.
vasqu left a comment
Forgot to approve, but LGTM now.
What does this PR do?
1. [General] When running the test case `tests/generation/test_continuous_batching.py::ContinuousBatchingTest::test_continuous_batching_parity_qwen_flash`, it fails; the root cause is a kernel loading failure for `paged|flash_attention_2`. This PR fixes the issue.
2. [XPU] The above test case reports the error `E AssertionError: Test request_id = 'req_1' failed, no expected output was provided ...` after successfully running on both XPU and CUDA. However, on the XPU side, the test case can pass stably by using `require_deterministic_for_xpu` (see the sketch below). For enabling the related tests on XPU, please refer to PR #42536.
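As a hedged illustration of point 2, here is a sketch of how the decorator could be applied to the test; only the decorator name `require_deterministic_for_xpu` appears in this PR, so the import path and the surrounding class body are assumptions:

```python
# Sketch of applying the decorator named in this PR; the import path
# (transformers.testing_utils) is an assumption, not confirmed by this page.
from transformers.testing_utils import require_deterministic_for_xpu


class ContinuousBatchingTest:
    @require_deterministic_for_xpu
    def test_continuous_batching_parity_qwen_flash(self):
        # On XPU, the decorator enforces deterministic execution so the
        # parity check against the expected output passes stably.
        ...
```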