Commit 5c88a1d

_supports_flash_attn = False on vision encoder
1 parent 98dfce2

3 files changed: +2, −30 lines

src/transformers/models/lightonocr/modeling_lightonocr.py

1 addition & 1 deletion

@@ -156,7 +156,7 @@ class LightOnOCRVisionPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
     _supports_attention_backend = True
-    _supports_flash_attn = True
+    _supports_flash_attn = False
     _supports_sdpa = True
     _supports_flex_attn = True
     _no_split_modules = ["LightOnOCRVisionAttentionLayer"]

src/transformers/models/lightonocr/modular_lightonocr.py

1 addition & 1 deletion

@@ -417,7 +417,7 @@ class LightOnOCRVisionPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
     _supports_attention_backend = True
-    _supports_flash_attn = True
+    _supports_flash_attn = False
    _supports_sdpa = True
     _supports_flex_attn = True
     _no_split_modules = ["LightOnOCRVisionAttentionLayer"]
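The same one-line change lands in modular_lightonocr.py because, in transformers' modular system, modeling_lightonocr.py is auto-generated from the modular file; editing only the generated file would be overwritten on the next regeneration (typically run via utils/modular_model_converter.py).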

tests/models/lightonocr/test_modeling_lightonocr.py

0 additions & 28 deletions

@@ -412,34 +412,6 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
     def test_flash_attn_2_fp32_ln(self):
         pass
 
-    @unittest.skip("Pixtral does not support attention interfaces.")
-    def test_eager_matches_fa2_generate(self):
-        pass
-
-    @unittest.skip("Pixtral does not support attention interfaces.")
-    def test_eager_matches_sdpa_generate(self):
-        pass
-
-    @unittest.skip("Pixtral does not support attention interfaces.")
-    def test_flash_attn_2_from_config(self):
-        pass
-
-    @unittest.skip("Pixtral does not support attention interfaces.")
-    def test_flash_attn_2_inference_equivalence(self):
-        pass
-
-    @unittest.skip("Pixtral does not support attention interfaces.")
-    def test_flash_attn_2_inference_equivalence_right_padding(self):
-        pass
-
-    @unittest.skip("Pixtral does not support attention interfaces.")
-    def test_sdpa_can_dispatch_on_flash(self):
-        pass
-
-    @unittest.skip("Pixtral does not support attention interfaces.")
-    def test_flex_attention_with_grads(self):
-        pass
-
     def test_initialization(self):
         """
         Test that model initializes correctly with proper weight initialization.