Skip to content

Commit 370fc65

Browse files
authored
add xpu support in test_modeling_janus.py::JanusIntegrationTest::test… (#41986)
* add xpu support in test_modeling_janus.py::JanusIntegrationTest::test_model_generate_images Signed-off-by: Wang, Yi A <[email protected]> * fix ci issue Signed-off-by: Wang, Yi A <[email protected]> --------- Signed-off-by: Wang, Yi A <[email protected]>
1 parent f065e40 commit 370fc65

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

src/transformers/models/janus/modeling_janus.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1419,8 +1419,8 @@ def generate(
14191419
model_inputs = self.prepare_inputs_for_generation(
14201420
inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs
14211421
)
1422-
1423-
model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
1422+
if "attention_mask" in model_inputs:
1423+
model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
14241424
model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device)
14251425

14261426
outputs = self.model.language_model(

src/transformers/models/janus/modular_janus.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,8 +1235,8 @@ def generate(
12351235
model_inputs = self.prepare_inputs_for_generation(
12361236
inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs
12371237
)
1238-
1239-
model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
1238+
if "attention_mask" in model_inputs:
1239+
model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
12401240
model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device)
12411241

12421242
outputs = self.model.language_model(

tests/models/janus/test_modeling_janus.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,12 @@ def test_model_generate_images(self):
514514
15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165,
515515
897, 4044, 1762, 4676
516516
],
517+
("xpu", None): [
518+
4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465,
519+
13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305,
520+
15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165,
521+
897, 4044, 1762, 4676
522+
],
517523
}
518524
)
519525
expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device)

0 commit comments

Comments
 (0)