 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.rotary_embedding.common import (
-    apply_rotary_emb_torch, dispatch_rotary_emb_function)
 from vllm.model_executor.models.qwen2_5_vl import (
     Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
     Qwen2_5_VisionPatchMerger, Qwen2_5_VisionTransformer,
@@ -69,36 +67,50 @@ def forward(
         x, _ = self.qkv(x)
         seq_len, batch_size, _ = x.shape

-        # Split q k v.
         qkv = einops.rearrange(
             x,
             "s b (three head head_dim) -> b s three head head_dim",
             three=3,
             head=self.num_attention_heads_per_partition,
         )
-        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
-        origin_shape = q.shape[-1]

         # Convert cumulative tensor to intervals and move it to cpu.
         cu_seqlens = torch.diff(cu_seqlens).to("cpu")

-        cos = torch.cat((rotary_pos_emb_cos, rotary_pos_emb_cos), dim=-1)
-        sin = torch.cat((rotary_pos_emb_sin, rotary_pos_emb_sin), dim=-1)
-        cos = cos.reshape(1, -1, 1, self.hidden_size_per_attention_head)
-        sin = sin.reshape(1, -1, 1, self.hidden_size_per_attention_head)
-        q = torch_npu.npu_rotary_mul(q, cos, sin)
-        k = torch_npu.npu_rotary_mul(k, cos, sin)
+        if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
+            qk, v = qkv[:, :, :2], qkv[:, :, 2]

-        q, k, v = [
-            einops.rearrange(x, "b s h d -> (b s) h d").contiguous()
-            for x in (q, k, v)
-        ]
+            qk_reshaped = einops.rearrange(
+                qk, "b s two head head_dim -> (two b) s head head_dim", two=2)
+            qk_rotated = self.apply_rotary_emb(
+                qk_reshaped,
+                rotary_pos_emb_cos,
+                rotary_pos_emb_sin,
+            )
+            qk_rotated = qk_rotated.view(
+                2,
+                batch_size,
+                seq_len,
+                self.num_attention_heads_per_partition,
+                self.hidden_size_per_attention_head,
+            )
+            q, k = qk_rotated.unbind(dim=0)
+        else:
+            q, k, v = qkv.unbind(dim=2)

+        # TODO(shen-shanshan): Move codes below to MMEncoderAttention CustomOp
+        # ----------------------------------------------------------------------
         enable_pad = (envs_ascend.USE_OPTIMIZED_MODEL
                       and self.hidden_size_per_attention_head > MIN_PAD_SIZE
                       and self.hidden_size_per_attention_head < MAX_PAD_SIZE)

+        q, k, v = [
+            einops.rearrange(x, "b s h d -> (b s) h d").contiguous()
+            for x in (q, k, v)
+        ]
+
         if enable_pad:
+            origin_shape = q.shape[-1]
             pad_len = MAX_PAD_SIZE - origin_shape
             # q/k/v: [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE]
             q = F.pad(q, (0, pad_len), mode="constant", value=0)
@@ -125,6 +137,7 @@ def forward(
         context_layer = einops.rearrange(context_layer,
                                          "(b s) h d -> s b (h d)",
                                          b=batch_size).contiguous()
+        # ----------------------------------------------------------------------

         output, _ = self.proj(context_layer)
         return output
@@ -650,14 +663,6 @@ def _process_video_input(
         return video_embeds.split(sizes)


-def _apply_rotary_pos_emb_vision(t: torch.Tensor, cos: torch.Tensor,
-                                 sin: torch.Tensor) -> torch.Tensor:
-    rotary_emb_function = dispatch_rotary_emb_function(
-        default=partial(apply_rotary_emb_torch, is_neox_style=True))
-    output = rotary_emb_function(t, cos, sin).type_as(t)
-    return output
-
-
 # NOTE: This will be removed after MMEncoderAttention has been extracted as a CustomOp in vllm.
 Qwen2VisionAttention.forward = AscendQwen2_5_VisionAttention.forward
 Qwen2_5_VisionAttention.forward = AscendQwen2_5_VisionAttention.forward
@@ -676,4 +681,3 @@ def _apply_rotary_pos_emb_vision(t: torch.Tensor, cos: torch.Tensor,
 Qwen2_5_VisionTransformer.rotary_pos_emb_thw = AscendQwen2_5_VisionTransformer.rotary_pos_emb_thw
 Qwen2_5_VisionTransformer.get_rope_by_thw = AscendQwen2_5_VisionTransformer.get_rope_by_thw
 Qwen2_5_VisionTransformer.forward = AscendQwen2_5_VisionTransformer.forward
-apply_rotary_pos_emb_vision = _apply_rotary_pos_emb_vision
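
For context, here is a minimal standalone sketch (not the vLLM or vllm-ascend implementation) of the fused rotation the new code path performs: q and k are stacked along the batch dimension so a single rotary call covers both, then split back apart with unbind. Both helpers below (apply_rotary_emb, rotate_qk_fused) and the assumed cos/sin shape of [seq, head_dim // 2] are illustrative assumptions, not the module's actual API.

import torch


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor,
                     sin: torch.Tensor) -> torch.Tensor:
    # Neox-style rotary embedding: rotate the two halves of the last dim.
    # x: [batch, seq, head, head_dim]; cos/sin: [seq, head_dim // 2].
    x1, x2 = x.chunk(2, dim=-1)
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)


def rotate_qk_fused(qk: torch.Tensor, cos: torch.Tensor,
                    sin: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # qk: [batch, seq, 2, head, head_dim]. Stack q and k along the batch
    # dimension (same layout as the "(two b)" einops pattern in the diff),
    # apply rotary once, then unbind back into separate q and k tensors.
    b, s, _, h, d = qk.shape
    stacked = qk.permute(2, 0, 1, 3, 4).reshape(2 * b, s, h, d)
    rotated = apply_rotary_emb(stacked, cos, sin).view(2, b, s, h, d)
    return rotated.unbind(dim=0)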