Skip to content

Commit 68c8800

Browse files
authored
Merge branch 'main' into main
2 parents 4efd413 + 57eeb9c commit 68c8800

File tree

10 files changed

+142
-175
lines changed

10 files changed

+142
-175
lines changed

src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py

Lines changed: 7 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1958,11 +1958,8 @@ def forward(
19581958
audio_feature_lengths = None
19591959

19601960
if attention_mask is not None and position_ids is None:
1961-
if (
1962-
cache_position is None
1963-
or (cache_position is not None and cache_position[0] == 0)
1964-
or self.rope_deltas is None
1965-
):
1961+
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
1962+
if past_key_values_length == 0 or self.rope_deltas is None:
19661963
delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
19671964
position_ids, rope_deltas = self.get_rope_index(
19681965
input_ids,
@@ -1977,7 +1974,7 @@ def forward(
19771974
self.rope_deltas = rope_deltas
19781975
else:
19791976
batch_size, seq_length = input_ids.shape
1980-
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
1977+
delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
19811978
position_ids = torch.arange(seq_length, device=input_ids.device)
19821979
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
19831980
position_ids = position_ids.add(delta)
@@ -2366,11 +2363,8 @@ def forward(
23662363
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
23672364

23682365
if attention_mask is not None and position_ids is None:
2369-
if (
2370-
cache_position is None
2371-
or (cache_position is not None and cache_position[0] == 0)
2372-
or self.rope_deltas is None
2373-
):
2366+
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
2367+
if past_key_values_length == 0 or self.rope_deltas is None:
23742368
position_ids, rope_deltas = self.get_rope_index(
23752369
input_text_ids,
23762370
image_grid_thw,
@@ -2390,8 +2384,8 @@ def forward(
23902384
self.rope_deltas = rope_deltas
23912385

23922386
else:
2393-
batch_size, seq_length = input_ids.shape
2394-
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
2387+
batch_size, seq_length, _ = inputs_embeds.shape
2388+
delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
23952389
position_ids = torch.arange(seq_length, device=input_ids.device)
23962390
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
23972391
position_ids = position_ids.add(delta)

src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2306,11 +2306,8 @@ def forward(
23062306
audio_feature_lengths = None
23072307

23082308
if attention_mask is not None and position_ids is None:
2309-
if (
2310-
cache_position is None
2311-
or (cache_position is not None and cache_position[0] == 0)
2312-
or self.rope_deltas is None
2313-
):
2309+
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
2310+
if past_key_values_length == 0 or self.rope_deltas is None:
23142311
delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
23152312
position_ids, rope_deltas = self.get_rope_index(
23162313
input_ids,
@@ -2325,7 +2322,7 @@ def forward(
23252322
self.rope_deltas = rope_deltas
23262323
else:
23272324
batch_size, seq_length = input_ids.shape
2328-
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
2325+
delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
23292326
position_ids = torch.arange(seq_length, device=input_ids.device)
23302327
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
23312328
position_ids = position_ids.add(delta)
@@ -2567,11 +2564,8 @@ def forward(
25672564
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
25682565

25692566
if attention_mask is not None and position_ids is None:
2570-
if (
2571-
cache_position is None
2572-
or (cache_position is not None and cache_position[0] == 0)
2573-
or self.rope_deltas is None
2574-
):
2567+
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
2568+
if past_key_values_length == 0 or self.rope_deltas is None:
25752569
position_ids, rope_deltas = self.get_rope_index(
25762570
input_text_ids,
25772571
image_grid_thw,
@@ -2591,8 +2585,8 @@ def forward(
25912585
self.rope_deltas = rope_deltas
25922586

25932587
else:
2594-
batch_size, seq_length = input_ids.shape
2595-
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
2588+
batch_size, seq_length, _ = inputs_embeds.shape
2589+
delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
25962590
position_ids = torch.arange(seq_length, device=input_ids.device)
25972591
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
25982592
position_ids = position_ids.add(delta)

src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,7 +1290,8 @@ def forward(
12901290
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
12911291

12921292
if position_ids is None:
1293-
if self.rope_deltas is None or cache_position is None or cache_position[0] == 0:
1293+
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
1294+
if self.rope_deltas is None or past_key_values_length == 0:
12941295
position_ids, rope_deltas = self.get_rope_index(
12951296
input_ids,
12961297
image_grid_thw,
@@ -1303,10 +1304,7 @@ def forward(
13031304
batch_size, seq_length, _ = inputs_embeds.shape
13041305
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
13051306
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
1306-
if cache_position is not None:
1307-
delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
1308-
else:
1309-
delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
1307+
delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
13101308
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
13111309
position_ids = position_ids + delta.to(position_ids.device)
13121310

src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,8 @@ def forward(
595595
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
596596

597597
if position_ids is None:
598-
if self.rope_deltas is None or cache_position is None or cache_position[0] == 0:
598+
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
599+
if self.rope_deltas is None or past_key_values_length == 0:
599600
position_ids, rope_deltas = self.get_rope_index(
600601
input_ids,
601602
image_grid_thw,
@@ -608,10 +609,7 @@ def forward(
608609
batch_size, seq_length, _ = inputs_embeds.shape
609610
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
610611
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
611-
if cache_position is not None:
612-
delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
613-
else:
614-
delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
612+
delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
615613
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
616614
position_ids = position_ids + delta.to(position_ids.device)
617615

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
TransformersKwargs,
4343
auto_docstring,
4444
can_return_tuple,
45-
is_torchdynamo_compiling,
4645
logging,
4746
)
4847
from ..qwen2.modeling_qwen2 import (
@@ -1222,7 +1221,8 @@ def forward(
12221221
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
12231222

12241223
if position_ids is None:
1225-
if self.rope_deltas is None or cache_position is None or cache_position[0] == 0:
1224+
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
1225+
if self.rope_deltas is None or past_key_values_length == 0:
12261226
position_ids, rope_deltas = self.get_rope_index(
12271227
input_ids, image_grid_thw, video_grid_thw, attention_mask
12281228
)
@@ -1232,10 +1232,7 @@ def forward(
12321232
batch_size, seq_length, _ = inputs_embeds.shape
12331233
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
12341234
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
1235-
if cache_position is not None:
1236-
delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
1237-
else:
1238-
delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
1235+
delta = (past_key_values_length + self.rope_deltas).to(inputs_embeds.device)
12391236
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
12401237
position_ids = position_ids + delta.to(position_ids.device)
12411238

@@ -1443,15 +1440,7 @@ def prepare_inputs_for_generation(
14431440
# When compiling, we can't check tensor values thus we check only input length
14441441
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
14451442
# models currently cannot do assisted decoding
1446-
prefill_compiled_stage = is_torchdynamo_compiling() and (
1447-
(input_ids is not None and input_ids.shape[1] != 1)
1448-
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
1449-
)
1450-
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
1451-
(cache_position is not None and cache_position[0] == 0)
1452-
or (past_key_values is None or past_key_values.get_seq_length() == 0)
1453-
)
1454-
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.model.rope_deltas is None:
1443+
if model_inputs["cache_position"][0] == 0 or self.model.rope_deltas is None:
14551444
vision_positions, rope_deltas = self.model.get_rope_index(
14561445
model_inputs.get("input_ids", None),
14571446
image_grid_thw=image_grid_thw,

src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2165,11 +2165,8 @@ def forward(
21652165
audio_feature_lengths = None
21662166

21672167
if attention_mask is not None and position_ids is None:
2168-
if (
2169-
cache_position is None
2170-
or (cache_position is not None and cache_position[0] == 0)
2171-
or self.rope_deltas is None
2172-
):
2168+
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
2169+
if past_key_values_length == 0 or self.rope_deltas is None:
21732170
delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
21742171
position_ids, rope_deltas = self.get_rope_index(
21752172
input_ids,
@@ -2184,7 +2181,7 @@ def forward(
21842181
self.rope_deltas = rope_deltas
21852182
else:
21862183
batch_size, seq_length = input_ids.shape
2187-
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
2184+
delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
21882185
position_ids = torch.arange(seq_length, device=input_ids.device)
21892186
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
21902187
position_ids = position_ids.add(delta)
@@ -3103,12 +3100,9 @@ def forward(
31033100
if inputs_embeds is not None and inputs_embeds.shape[1] > 1:
31043101
generation_step = -1
31053102
residual_codes = None
3106-
if attention_mask is not None:
3107-
if (
3108-
cache_position is None
3109-
or (cache_position is not None and cache_position[0] == 0)
3110-
or self.rope_deltas is None
3111-
):
3103+
if position_ids is None:
3104+
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
3105+
if past_key_values_length == 0 or self.rope_deltas is None:
31123106
delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
31133107
position_ids, rope_deltas = self.get_rope_index(
31143108
talker_input_ids,
@@ -3123,7 +3117,7 @@ def forward(
31233117
self.rope_deltas = rope_deltas
31243118
else:
31253119
batch_size, seq_length = input_ids.shape
3126-
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
3120+
delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
31273121
position_ids = torch.arange(seq_length, device=input_ids.device)
31283122
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
31293123
position_ids = position_ids.add(delta)
@@ -3224,7 +3218,10 @@ def prepare_inputs_for_generation(
32243218
inputs = super().prepare_inputs_for_generation(
32253219
input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
32263220
)
3227-
# Decode stage
3221+
3222+
# Qwen3-Omni will prepare position ids in forward with deltas
3223+
inputs["position_ids"] = None
3224+
32283225
# TODO(raushan, gante): Refactor this part to a utility function
32293226
if cache_position[0] != 0:
32303227
input_ids = input_ids[:, -1:]

src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,11 +1521,8 @@ def forward(
15211521
audio_feature_lengths = None
15221522

15231523
if attention_mask is not None and position_ids is None:
1524-
if (
1525-
cache_position is None
1526-
or (cache_position is not None and cache_position[0] == 0)
1527-
or self.rope_deltas is None
1528-
):
1524+
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
1525+
if past_key_values_length == 0 or self.rope_deltas is None:
15291526
delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
15301527
position_ids, rope_deltas = self.get_rope_index(
15311528
input_ids,
@@ -1540,7 +1537,7 @@ def forward(
15401537
self.rope_deltas = rope_deltas
15411538
else:
15421539
batch_size, seq_length = input_ids.shape
1543-
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
1540+
delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
15441541
position_ids = torch.arange(seq_length, device=input_ids.device)
15451542
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
15461543
position_ids = position_ids.add(delta)
@@ -1961,12 +1958,9 @@ def forward(
19611958
if inputs_embeds is not None and inputs_embeds.shape[1] > 1:
19621959
generation_step = -1
19631960
residual_codes = None
1964-
if attention_mask is not None:
1965-
if (
1966-
cache_position is None
1967-
or (cache_position is not None and cache_position[0] == 0)
1968-
or self.rope_deltas is None
1969-
):
1961+
if position_ids is None:
1962+
past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
1963+
if past_key_values_length == 0 or self.rope_deltas is None:
19701964
delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
19711965
position_ids, rope_deltas = self.get_rope_index(
19721966
talker_input_ids,
@@ -1981,7 +1975,7 @@ def forward(
19811975
self.rope_deltas = rope_deltas
19821976
else:
19831977
batch_size, seq_length = input_ids.shape
1984-
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
1978+
delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
19851979
position_ids = torch.arange(seq_length, device=input_ids.device)
19861980
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
19871981
position_ids = position_ids.add(delta)
@@ -2044,7 +2038,10 @@ def prepare_inputs_for_generation(
20442038
inputs = super().prepare_inputs_for_generation(
20452039
input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
20462040
)
2047-
# Decode stage
2041+
2042+
# Qwen3-Omni will prepare position ids in forward with deltas
2043+
inputs["position_ids"] = None
2044+
20482045
# TODO(raushan, gante): Refactor this part to a utility function
20492046
if cache_position[0] != 0:
20502047
input_ids = input_ids[:, -1:]

0 commit comments

Comments
 (0)