File tree Expand file tree Collapse file tree 1 file changed +11
-3
lines changed
Expand file tree Collapse file tree 1 file changed +11
-3
lines changed Original file line number Diff line number Diff line change 1919from functools import partial
2020from typing import Callable , Optional
2121
22+ import numpy as np
2223import torch
2324import torch .nn as nn
2425import torch .nn .functional as F
@@ -143,14 +144,21 @@ def cal_cos_sin(self, rotary_pos_emb):
143144 def forward (
144145 self ,
145146 x : torch .Tensor ,
146- grid_thw : list [list [int ]],
147+ grid_thw : torch . Tensor | list [list [int ]],
147148 ) -> torch .Tensor :
148149 hidden_states = x .to (device = self .device , dtype = self .dtype )
149150 hidden_states = self .patch_embed (hidden_states )
150151
151- pos_embeds = self .fast_pos_embed_interpolate (grid_thw )
152+ if isinstance (grid_thw , list ):
153+ grid_thw_list = grid_thw
154+ grid_thw = np .array (grid_thw , dtype = np .int32 )
155+ else :
156+ grid_thw_list = grid_thw .tolist ()
157+ grid_thw = grid_thw .numpy ()
158+
159+ pos_embeds = self .fast_pos_embed_interpolate (grid_thw_list )
152160 hidden_states = hidden_states + pos_embeds
153- rotary_pos_emb = self .rot_pos_emb (grid_thw )
161+ rotary_pos_emb = self .rot_pos_emb (grid_thw_list )
154162 grid_thw_tensor = torch .tensor (grid_thw ,
155163 device = self .device ,
156164 dtype = torch .int32 )
You can’t perform that action at this time.
0 commit comments