diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index 4cad10fc4216..8357d070886e 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -1419,8 +1419,8 @@ def generate( model_inputs = self.prepare_inputs_for_generation( inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs ) - - model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device) + if "attention_mask" in model_inputs: + model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device) model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device) outputs = self.model.language_model( diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 87cc11d73cda..019fe3c93be8 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -1235,8 +1235,8 @@ def generate( model_inputs = self.prepare_inputs_for_generation( inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs ) - - model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device) + if "attention_mask" in model_inputs: + model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device) model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device) outputs = self.model.language_model( diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index 76e49179fdb5..f56642a38e54 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -514,6 +514,12 @@ def test_model_generate_images(self): 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165, 897, 4044, 1762, 4676 ], + ("xpu", None): [ + 4484, 4015, 15750, 506, 3758, 11651, 8597, 5739, 4861, 971, 14985, 14834, 15438, 7548, 1820, 1465, + 13529, 12761, 10503, 12761, 14303, 6155, 4015, 11766, 705, 15736, 14146, 10417, 1951, 7713, 14305, + 15617, 6169, 2706, 8006, 14893, 3855, 10188, 15652, 6297, 1097, 12108, 15038, 311, 14998, 15165, + 897, 4044, 1762, 4676 + ], } ) expected_tokens = torch.tensor(expected_tokens.get_expectation()).to(model.device)