extend FA2 and other cases to XPU (#42536)
Changes from 7 commits: 8d14174, b7a366f, d54f907, 40d694a, ac8c6b8, ec05093, e5692c7, b37f8b3, ab3e92c
@@ -25,7 +25,6 @@
     require_kernels,
     require_read_token,
     require_torch_accelerator,
-    require_torch_gpu,
     slow,
     torch_device,
 )
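The only change in this hunk is dropping the now-unused `require_torch_gpu` import: the tests below switch from the CUDA-only `require_torch_gpu` gate to `require_torch_accelerator`, which also admits non-CUDA backends such as XPU. A minimal usage sketch, assuming nothing beyond what the diff's own imports show (`require_torch_accelerator` and `torch_device` coming from `transformers.testing_utils`):

```python
# Sketch of the gating behaviour behind the decorator swap; the skip semantics described
# in the comments are the assumption being illustrated, not copied from the library.
import unittest

from transformers.testing_utils import require_torch_accelerator, torch_device


class DeviceGatingSketch(unittest.TestCase):
    @require_torch_accelerator  # skipped unless some accelerator (CUDA, XPU, ...) is available
    def test_runs_on_any_accelerator(self) -> None:
        # torch_device resolves to the detected backend, e.g. "cuda" on NVIDIA or "xpu" on Intel.
        self.assertNotEqual(torch_device, "cpu")


if __name__ == "__main__":
    unittest.main()
```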
@@ -315,36 +314,47 @@ def test_continuous_batching_parity_gemma_sdpa(self) -> None:
         # GPT-OSS is not compatible with SDPA because it has an attention sink. TODO: is this fixable?

     # Flash attention test
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_kernels
     @slow
     def test_continuous_batching_parity_llama_flash(self) -> None:
         expected_outputs = Expectations({
             ("cuda", (9, 0)): {
                 "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5 bolts. The total number of bolts is 4.5 bolts.",
-            }
+            },
+            ("xpu", None): {
+                "req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5 bolts. The total number of bolts is 4.5 bolts.",
+            },
         }).get_expectation()  # fmt: skip
         self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "paged|flash_attention_2", expected_outputs)

-    @require_torch_gpu
+    @require_torch_accelerator
     @require_kernels
     @slow
     def test_continuous_batching_parity_gemma_flash(self) -> None:
         expected_outputs = Expectations({
             ("cuda", (9, 0)): {
                 "req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n ",
-            }
+            },
+            ("xpu", None): {
+                "req_0": "\n\n**$128**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 1",
+                "req_1": "\n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
+            },
         }).get_expectation()  # fmt: skip
         self._continuous_batching_parity("google/gemma-2-2b-it", "paged|flash_attention_2", expected_outputs)

-    @require_torch_gpu
+    @require_torch_accelerator
     @require_kernels
     @slow
     def test_continuous_batching_parity_qwen_flash(self) -> None:
-        expected_outputs = {}
Collaborator
@remi-or Could you add a code comment and/or docstring explaining what this is for?

Collaborator
Sure! Was actually thinking of doing a test-oriented PR.
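One possible shape for the requested comment/docstring, sketched here under the assumption that the request refers to `expected_outputs` in these parity tests; the actual wording is of course up to the author:

```python
# Hypothetical docstring sketch -- an assumption, not part of this PR.
class _DocstringSketch:
    def test_continuous_batching_parity_qwen_flash(self) -> None:
        """Check that continuous batching with the paged flash_attention_2 backend matches
        non-batched generation for Qwen/Qwen3-4B-Instruct-2507.

        `expected_outputs` (assumed meaning) maps request ids to per-device reference
        completions, selected via `Expectations`, for requests that need an explicit
        expected string rather than an exact parity check.
        """
```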
+        expected_outputs = Expectations({
+            ("xpu", None): {
+                "req_1": " 3.5 bolts.\n\nLet's break it down step by step:\n\n- Blue fiber: 2 bolts\n- White fiber: half of 2 bolts = 1 bolt\n\nTotal = ",
+            },
+        }).get_expectation()  # fmt: skip
Collaborator
i need to check why this was
         self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "paged|flash_attention_2", expected_outputs)

-    @require_torch_gpu
+    @require_torch_accelerator
     @require_kernels
     @slow
     def test_continuous_batching_parity_gpt_oss_flash(self) -> None:
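For readers unfamiliar with the `Expectations` helper used throughout this diff: entries are keyed by `(device_type, version)` tuples, and `("xpu", None)` serves as a version-agnostic XPU entry alongside the pinned `("cuda", (9, 0))` one. The snippet below is a simplified, self-contained sketch of that lookup idea, not the actual `transformers.testing_utils.Expectations` implementation (whose matching logic may be more involved):

```python
# Simplified, self-contained sketch of the device-keyed expectation lookup pattern.
# This is NOT the transformers implementation; it only illustrates why ("xpu", None)
# can act as a catch-all for any XPU while ("cuda", (9, 0)) pins a specific GPU class.
from typing import Any, Optional

DeviceKey = tuple[str, Optional[tuple[int, int]]]


class ExpectationsSketch:
    def __init__(self, table: dict[DeviceKey, Any]) -> None:
        self.table = table

    def get_expectation(self, device_type: str, version: Optional[tuple[int, int]]) -> Any:
        # Prefer an exact (device, version) match, then fall back to the
        # version-agnostic (device, None) entry.
        if (device_type, version) in self.table:
            return self.table[(device_type, version)]
        if (device_type, None) in self.table:
            return self.table[(device_type, None)]
        raise KeyError(f"no expectation registered for {device_type} {version}")


expectations = ExpectationsSketch({
    ("cuda", (9, 0)): {"req_1": "cuda-specific reference completion"},
    ("xpu", None): {"req_1": "xpu reference completion"},
})
# Any XPU version falls back to the ("xpu", None) entry.
print(expectations.get_expectation("xpu", (1, 3)))
```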