Skip to content

Commit b401b6b

Browse files
committed
fix qwen3vl ci
Signed-off-by: 李少鹏 <[email protected]>
1 parent 366d2d9 commit b401b6b

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

vllm_ascend/models/qwen3_vl.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from functools import partial
2020
from typing import Callable, Optional
2121

22+
import numpy as np
2223
import torch
2324
import torch.nn as nn
2425
import torch.nn.functional as F
@@ -143,14 +144,21 @@ def cal_cos_sin(self, rotary_pos_emb):
143144
def forward(
144145
self,
145146
x: torch.Tensor,
146-
grid_thw: list[list[int]],
147+
grid_thw: torch.Tensor | list[list[int]],
147148
) -> torch.Tensor:
148149
hidden_states = x.to(device=self.device, dtype=self.dtype)
149150
hidden_states = self.patch_embed(hidden_states)
150151

151-
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
152+
if isinstance(grid_thw, list):
153+
grid_thw_list = grid_thw
154+
grid_thw = np.array(grid_thw, dtype=np.int32)
155+
else:
156+
grid_thw_list = grid_thw.tolist()
157+
grid_thw = grid_thw.numpy()
158+
159+
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
152160
hidden_states = hidden_states + pos_embeds
153-
rotary_pos_emb = self.rot_pos_emb(grid_thw)
161+
rotary_pos_emb = self.rot_pos_emb(grid_thw_list)
154162
grid_thw_tensor = torch.tensor(grid_thw,
155163
device=self.device,
156164
dtype=torch.int32)

0 commit comments

Comments
 (0)