1 change: 1 addition & 0 deletions src/transformers/models/mimi/modeling_mimi.py
@@ -931,6 +931,7 @@ def forward(

MIMI_ATTENTION_CLASSES = {
"eager": MimiAttention,
"kernels-community/flash-attn2": MimiFlashAttention2,
Collaborator: could you explain this part?

Contributor Author: @ydshieh, sure. In the latest design, when a user sets attn_implementation == "flash_attention_2", there are two branches:

  1. If the flash_attn package is available, it is used directly.
  2. Otherwise, instead of failing as before, the kernels path is used; in that case attn_implementation is updated to "kernels-community/flash-attn2", as in the code here.

For XPU, we go through the kernels path in transformers for flash-attention support, so we need this key.

Thanks very much.
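
For reference, a minimal sketch of the selection logic described above (the helper name is hypothetical; this is not the actual transformers internals):

```python
from transformers.utils import is_flash_attn_2_available


def resolve_attn_implementation(requested: str) -> str:
    """Sketch: map the requested attention implementation to what will actually run."""
    if requested == "flash_attention_2":
        if is_flash_attn_2_available():
            # Branch 1: the flash_attn package is installed, use it directly.
            return "flash_attention_2"
        # Branch 2: instead of raising, fall back to the Hub kernel
        # (the path taken on XPU), so the key becomes
        # "kernels-community/flash-attn2".
        return "kernels-community/flash-attn2"
    return requested
```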

Contributor: I'd rather not do this, even though you are correct here. We should instead refactor Mimi to use the attention interface rather than keep these manual registrations. We could keep extending these edge cases indefinitely (FA3, etc.), which does not scale without using or refactoring to the interface.
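
For context, the suggested refactor would roughly follow the attention-interface pattern used in already-refactored models, along these lines (a sketch under that assumption, not Mimi's actual code; `pick_attention_function` is illustrative):

```python
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def pick_attention_function(config, eager_fn):
    """Sketch: resolve the attention callable from the shared registry.

    `eager_fn` is the model's own eager attention forward; every other
    implementation key ("sdpa", "flash_attention_2", Hub kernels such as
    "kernels-community/flash-attn2") is served by the registry, so no
    per-implementation attention classes are needed.
    """
    if config._attn_implementation == "eager":
        return eager_fn
    return ALL_ATTENTION_FUNCTIONS[config._attn_implementation]
```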

Collaborator: @yao-matrix Let's revert this line 🙏. We can skip the relevant FA tests if necessary.
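
If some FA tests do need to be skipped rather than adapted, one option is an explicit availability check, roughly like this (a sketch; the class and test names are illustrative):

```python
import unittest

from transformers.utils import is_flash_attn_2_available


class MimiFlashAttentionTest(unittest.TestCase):
    @unittest.skipUnless(is_flash_attn_2_available(), "flash_attn is not installed")
    def test_flash_attn_2_inference_equivalence(self):
        # Placeholder body: the real test lives in the shared model test mixins.
        ...
```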

"flash_attention_2": MimiFlashAttention2,
"sdpa": MimiSdpaAttention,
}
26 changes: 18 additions & 8 deletions tests/generation/test_continuous_batching.py
@@ -25,7 +25,6 @@
require_kernels,
require_read_token,
require_torch_accelerator,
require_torch_gpu,
slow,
torch_device,
)
@@ -315,36 +314,47 @@ def test_continuous_batching_parity_gemma_sdpa(self) -> None:
# GPT-OSS is not compatible with SDPA because it has an attention sink. TODO: is this fixable?

# Flash attention test
@require_torch_gpu
@require_torch_accelerator
@require_kernels
@slow
def test_continuous_batching_parity_llama_flash(self) -> None:
expected_outputs = Expectations({
("cuda", (9, 0)): {
"req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5 bolts. The total number of bolts is 4.5 bolts.",
}
},
("xpu", None): {
"req_1": " 3 bolts of blue fiber and 1.5 bolts of white fiber. The total number of bolts is 4.5 bolts. The total number of bolts is 4.5 bolts.",
},
}).get_expectation() # fmt: skip
self._continuous_batching_parity("meta-llama/Llama-3.1-8B", "paged|flash_attention_2", expected_outputs)

@require_torch_gpu
@require_torch_accelerator
@require_kernels
@slow
def test_continuous_batching_parity_gemma_flash(self) -> None:
expected_outputs = Expectations({
("cuda", (9, 0)): {
"req_1": " \n \n 2 + 1 = 3 bolts \n \n \n \n \n \n \n \n \n \n \n \n \n ",
}
},
("xpu", None): {
"req_0": "\n\n**$128**\n\n**Here's how to solve it:**\n\n* **Eggs eaten:** 3\n* **Eggs left:** 16 - 3 = 1",
"req_1": "\n\n**Answer:** 3 bolts\n\n**Solution:**\n\n* **White fiber:** The robe needs half as much white fiber as blue fiber, so it needs 2 bolts / 2 =",
},
}).get_expectation() # fmt: skip
self._continuous_batching_parity("google/gemma-2-2b-it", "paged|flash_attention_2", expected_outputs)

@require_torch_gpu
@require_torch_accelerator
@require_kernels
@slow
def test_continuous_batching_parity_qwen_flash(self) -> None:
expected_outputs = {}
expected_outputs = Expectations({
("xpu", None): {
"req_1": " 3.5 bolts.\n\nLet's break it down step by step:\n\n- Blue fiber: 2 bolts\n- White fiber: half of 2 bolts = 1 bolt\n\nTotal = ",
},
}).get_expectation() # fmt: skip
Collaborator: I need to check why this was {} before, but thank you.

self._continuous_batching_parity("Qwen/Qwen3-4B-Instruct-2507", "paged|flash_attention_2", expected_outputs)

@require_torch_gpu
@require_torch_accelerator
@require_kernels
@slow
def test_continuous_batching_parity_gpt_oss_flash(self) -> None:
4 changes: 2 additions & 2 deletions tests/generation/test_paged_attention.py
@@ -4,7 +4,7 @@
from parameterized import parameterized

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.testing_utils import require_flash_attn, require_torch_gpu, slow
from transformers.testing_utils import require_flash_attn, require_torch_accelerator, slow


_TEST_PROMPTS = [
@@ -26,7 +26,7 @@

@slow
@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
class TestBatchGeneration(unittest.TestCase):
@classmethod
def setUpClass(cls):
22 changes: 15 additions & 7 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -33,7 +33,6 @@
require_torch,
require_torch_accelerator,
require_torch_large_accelerator,
require_torch_large_gpu,
run_test_using_subprocess,
slow,
torch_device,
@@ -172,16 +171,25 @@ def test_model_2b_pipeline_bf16_flex_attention(self):

@require_read_token
@require_flash_attn
@require_torch_large_gpu
@require_torch_large_accelerator
@mark.flash_attn_test
@slow
def test_model_9b_flash_attn(self):
# See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
model_id = "google/gemma-2-9b"
EXPECTED_TEXTS = [
'<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
"<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic composed of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the",
] # fmt: skip
# fmt: off
EXPECTED_TEXTS = Expectations(
{
(None, None): ['<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
"<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic composed of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the",
],
("xpu", None): ['<bos>Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
"<pad><pad><bos>Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the",
],
}
)
# fmt: on
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()

model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation="flash_attention_2", dtype="float16"
Expand All @@ -192,7 +200,7 @@ def test_model_9b_flash_attn(self):
output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)

self.assertEqual(output_text, EXPECTED_TEXTS)
self.assertEqual(output_text, EXPECTED_TEXT)

@pytest.mark.torch_export_test
@slow
2 changes: 1 addition & 1 deletion tests/models/gemma3/test_modeling_gemma3.py
@@ -440,7 +440,7 @@ def test_automodelforcausallm(self):
self.assertIsInstance(for_causal_lm, Gemma3ForConditionalGeneration)

@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
@mark.flash_attn_test
@slow
def test_flash_attn_2_from_config(self):
6 changes: 3 additions & 3 deletions tests/models/glm4v/test_modeling_glm4v.py
@@ -29,7 +29,7 @@
require_deterministic_for_xpu,
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -512,7 +512,7 @@ def test_small_model_integration_test_batch_different_resolutions(self):

@slow
@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
def test_small_model_integration_test_batch_flashatt2(self):
model = Glm4vForConditionalGeneration.from_pretrained(
"THUDM/GLM-4.1V-9B-Thinking",
@@ -547,7 +547,7 @@ def test_small_model_integration_test_batch_flashatt2(self):

@slow
@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
def test_small_model_integration_test_batch_wo_image_flashatt2(self):
model = Glm4vForConditionalGeneration.from_pretrained(
"THUDM/GLM-4.1V-9B-Thinking",
4 changes: 2 additions & 2 deletions tests/models/glm4v_moe/test_modeling_glm4v_moe.py
@@ -27,7 +27,7 @@
cleanup,
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_accelerator,
run_first,
slow,
torch_device,
@@ -434,7 +434,7 @@ def test_small_model_integration_test_with_video(self):

@run_first
@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
def test_small_model_integration_test_batch_flashatt2(self):
model = Glm4vMoeForConditionalGeneration.from_pretrained(
"zai-org/GLM-4.5V",
@@ -30,7 +30,7 @@
from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -235,7 +235,7 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa
pass

@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
@mark.flash_attn_test
@slow
@unittest.skip(
@@ -356,7 +356,7 @@ def test_config_requires_mamba_or_attention_layers(self):

# TODO (@alex-jw-brooks) - update this once the model(s) are out
@unittest.skip(reason="GraniteMoeHybrid models are not yet released")
@require_torch_gpu
@require_torch_accelerator
class GraniteMoeHybridIntegrationTest(unittest.TestCase):
@slow
def test_model_logits(self):
4 changes: 2 additions & 2 deletions tests/models/idefics2/test_modeling_idefics2.py
@@ -36,7 +36,7 @@
require_bitsandbytes,
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_accelerator,
require_torch_multi_accelerator,
slow,
torch_device,
@@ -645,7 +645,7 @@ def test_integration_test_4bit_batch2(self):

@pytest.mark.flash_attn_test
@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
@require_bitsandbytes
def test_flash_attn_2_eager_equivalence(self):
# Create inputs
3 changes: 1 addition & 2 deletions tests/models/kosmos2_5/test_modeling_kosmos2_5.py
@@ -33,7 +33,6 @@
require_flash_attn,
require_torch,
require_torch_accelerator,
require_torch_gpu,
require_vision,
slow,
torch_device,
@@ -467,7 +466,7 @@ def test_model_parallelism(self):
pass

# TODO: ydshieh
@require_torch_gpu
@require_torch_accelerator
@slow
@unittest.skip(reason="_update_causal_mask is not implemented yet which fails this test")
def test_sdpa_can_dispatch_on_flash(self):
4 changes: 2 additions & 2 deletions tests/models/longcat_flash/test_modeling_longcat_flash.py
@@ -25,7 +25,7 @@
require_flash_attn,
require_large_cpu_ram,
require_torch,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -285,7 +285,7 @@ def _prepare_config_headdim(config, requested_dim):
return config

@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
@require_bitsandbytes
@mark.flash_attn_test
@slow
4 changes: 2 additions & 2 deletions tests/models/mimi/test_modeling_mimi.py
@@ -28,7 +28,7 @@
is_torch_available,
require_flash_attn,
require_torch,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -291,7 +291,7 @@ def test_identity_shortcut(self):
self.model_tester.create_and_check_model_forward(config, inputs_dict)

@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
@mark.flash_attn_test
@slow
@is_flaky()
5 changes: 2 additions & 3 deletions tests/models/modernbert/test_modeling_modernbert.py
@@ -27,7 +27,6 @@
require_flash_attn,
require_torch,
require_torch_accelerator,
require_torch_gpu,
slow,
torch_device,
)
@@ -344,14 +343,14 @@ def test_model_from_pretrained(self):
self.assertIsNotNone(model)

@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="ModernBert flash attention does not support right padding")

@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
@pytest.mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
7 changes: 3 additions & 4 deletions tests/models/musicgen/test_modeling_musicgen.py
@@ -40,7 +40,6 @@
require_torch,
require_torch_accelerator,
require_torch_fp16,
require_torch_gpu,
slow,
torch_device,
)
@@ -282,7 +281,7 @@ def test_greedy_generate_stereo_outputs(self):
self.model_tester.audio_channels = original_audio_channels

@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
@@ -362,7 +361,7 @@ def test_flash_attn_2_inference_equivalence(self):
_ = model_fa(dummy_input, **other_inputs)

@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
@mark.flash_attn_test
@slow
# Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding
@@ -899,7 +898,7 @@ def test_greedy_generate_stereo_outputs(self):
self.model_tester.audio_channels = original_audio_channels

@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
@mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
7 changes: 3 additions & 4 deletions tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -41,7 +41,6 @@
require_torch,
require_torch_accelerator,
require_torch_fp16,
require_torch_gpu,
require_torchaudio,
slow,
torch_device,
@@ -291,7 +290,7 @@ def test_greedy_generate_stereo_outputs(self):
self.model_tester.audio_channels = original_audio_channels

@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
@mark.flash_attn_test
@slow
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_inference_equivalence
@@ -373,7 +372,7 @@ def test_flash_attn_2_inference_equivalence(self):
_ = model_fa(dummy_input, **other_inputs)

@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
@mark.flash_attn_test
@slow
# Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_inference_equivalence_right_padding
@@ -902,7 +901,7 @@ def test_greedy_generate_stereo_outputs(self):
self.model_tester.audio_channels = original_audio_channels

@require_flash_attn
@require_torch_gpu
@require_torch_accelerator
@mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):