@@ -319,13 +319,11 @@ def forward(
319319 ) -> torch .Tensor :
320320 # patchify
321321 seq_len , _ = x .size ()
322- rotary_pos_emb_cos : list [torch .Tensor ] = []
323- rotary_pos_emb_sin : list [torch .Tensor ] = []
324- window_index : list [torch .Tensor ] = []
325- cu_window_seqlens : list [torch .Tensor ] = [
326- torch .tensor ([0 ], dtype = torch .int32 )
327- ]
328- cu_seqlens : list [torch .Tensor ] = []
322+ rotary_pos_emb_cos : list = []
323+ rotary_pos_emb_sin : list = []
324+ window_index : list = []
325+ cu_window_seqlens : list = [torch .tensor ([0 ], dtype = torch .int32 )]
326+ cu_seqlens : list = []
329327
330328 hidden_states = x .to (device = self .device , dtype = self .dtype )
331329 hidden_states = self .patch_embed (hidden_states )
@@ -375,15 +373,21 @@ def forward(
375373 max_seqlen_window , seqlens_window = self .compute_attn_mask_seqlen (
376374 cu_window_seqlens )
377375
378- cu_seqlens = cu_seqlens .to (device = self .device , non_blocking = True )
379- cu_window_seqlens = cu_window_seqlens .to (device = self .device ,
380- non_blocking = True )
381- rotary_pos_emb_cos = rotary_pos_emb_cos .to (device = self .device ,
382- non_blocking = True )
383- rotary_pos_emb_sin = rotary_pos_emb_sin .to (device = self .device ,
384- non_blocking = True )
385- window_index = window_index .to (device = hidden_states .device ,
386- non_blocking = True )
376+ cu_seqlens = cu_seqlens .to ( # type: ignore[attr-defined]
377+ device = self .device ,
378+ non_blocking = True )
379+ cu_window_seqlens = cu_window_seqlens .to ( # type: ignore[attr-defined]
380+ device = self .device ,
381+ non_blocking = True )
382+ rotary_pos_emb_cos = rotary_pos_emb_cos .to ( # type: ignore[attr-defined]
383+ device = self .device ,
384+ non_blocking = True )
385+ rotary_pos_emb_sin = rotary_pos_emb_sin .to ( # type: ignore[attr-defined]
386+ device = self .device ,
387+ non_blocking = True )
388+ window_index = window_index .to ( # type: ignore[attr-defined]
389+ device = hidden_states .device ,
390+ non_blocking = True )
387391 reverse_indices = reverse_indices .to (device = hidden_states .device ,
388392 non_blocking = True )
389393
0 commit comments