#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Callable, Iterable, Optional, Set, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VisionAttention, Qwen2_5_VisionBlock, Qwen2_5_VisionPatchEmbed,
    Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
    Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
    Qwen2_5_VLProcessingInfo)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY

MIN_PAD_SIZE = 64  # head dims above this threshold get padded
MAX_PAD_SIZE = 128  # target head dim after padding
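# Head dims strictly between MIN_PAD_SIZE and MAX_PAD_SIZE are zero-padded up
# to MAX_PAD_SIZE (the qkv/proj weights are padded to match), presumably
# because the NPU flash-attention kernel used below favors 64/128-aligned head
# dims. Example (Qwen2.5-VL vision tower): 1280 hidden / 16 heads = 80, padded
# 80 -> 128, i.e. 24 zeros appended to each rotary half.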


class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(
            embed_dim,
            num_heads,
            projection_size,
            quant_config,
            prefix,
        )
        self.embed_dim = embed_dim
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
        if MIN_PAD_SIZE < self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

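    # Attention forward on NPU: qkv projection, rotary embedding applied with
    # torch_npu.npu_rotary_mul on the (possibly padded) head dim, then the
    # unpadded flash-attention kernel over the flattened (batch * seq) tokens.
    # The softmax scale uses the *original* head dim: the appended zeros do not
    # change the q.k dot products, so the original scale reproduces the
    # unpadded attention exactly.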
    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))
        q = torch_npu.npu_rotary_mul(q, cos, sin)
        k = torch_npu.npu_rotary_mul(k, cos, sin)

        q, k, v = [
            rearrange(x, "b s h d -> (b s) h d").contiguous()
            for x in (q, k, v)
        ]

        context_layer = torch.empty_like(q)

        # operator requires pta version >= 2.5.1
        torch_npu._npu_flash_attention_unpad(
            query=q,
            key=k,
            value=v,
            seq_len=cu_seqlens,
            scale_value=self.origin_hidden_size_per_attention_head**-0.5,
            num_heads=self.num_attention_heads_per_partition,
            num_kv_heads=self.num_attention_heads_per_partition,
            out=context_layer)

        context_layer = rearrange(context_layer,
                                  "(b s) h d -> s b (h d)",
                                  b=batch_size).contiguous()

        output, _ = self.proj(context_layer)
        return output


class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock):

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
                         quant_config, prefix)
        self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim,
                                                  num_heads=num_heads,
                                                  projection_size=dim,
                                                  quant_config=quant_config,
                                                  prefix=f"{prefix}.attn")

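    # Pre-norm residual block; the precomputed cos/sin tables are threaded
    # through to the Ascend attention in place of the raw rotary_pos_emb.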
    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(
            self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)

        x = x + self.mlp(self.norm2(x))
        return x


class AscendQwen2_5_VisionPatchEmbed(Qwen2_5_VisionPatchEmbed):

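    # Patch embedding as a single matmul: the parent's Conv3d kernel is viewed
    # as a [hidden_size, in_features] matrix and multiplied against the
    # already-flattened patch pixels, which is equivalent to the strided
    # convolution and (presumably) maps better onto the NPU.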
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.matmul(
            self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
        return x


class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):

    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        interleaved=False,
    ) -> None:
        super().__init__(vision_config, norm_eps, quant_config, prefix)
        norm_layer = partial(RMSNorm, eps=norm_eps)
        self.interleaved = interleaved
        self.enable_pad = False
        self.patch_embed = AscendQwen2_5_VisionPatchEmbed(
            patch_size=vision_config.patch_size,
            temporal_patch_size=vision_config.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )
        self.blocks = nn.ModuleList([
            AscendQwen2_5_VisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}")
            for layer_idx in range(vision_config.depth)
        ])
        self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            self.hidden_size, self.num_heads)

        if MIN_PAD_SIZE < self.hidden_size_per_attention_head < MAX_PAD_SIZE:
            self.enable_pad = True
            self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
            self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2
            self.half_pad_hidden_size_per_attention_head = (
                MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2
            self.hidden_size_per_attention_head = MAX_PAD_SIZE

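    # Precompute cos/sin lookup tables for npu_rotary_mul. The rotary table
    # covers half the head dim; when padding is enabled it is zero-padded to
    # half of MAX_PAD_SIZE, then either duplicated along the last dim
    # (non-interleaved, the default) or interleaved element-wise, and reshaped
    # to [1, seq, 1, head_dim] so it broadcasts over batch and heads.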
    def cal_cos_sin(self, rotary_pos_emb):
        cos = rotary_pos_emb.cos()  # [seqlen, rotary_dim / 2]
        sin = rotary_pos_emb.sin()
        if self.enable_pad:
            cos = torch.nn.functional.pad(
                cos, (0, self.half_pad_hidden_size_per_attention_head))
            sin = torch.nn.functional.pad(
                sin, (0, self.half_pad_hidden_size_per_attention_head))

        if not self.interleaved:
            cos_new = torch.cat((cos, cos), dim=-1)
            sin_new = torch.cat((sin, sin), dim=-1)
        else:
            cos_new = rearrange(torch.stack((cos, cos), dim=-1),
                                "... d two -> ... (d two)",
                                two=2)
            sin_new = rearrange(torch.stack((sin, sin), dim=-1),
                                "... d two -> ... (d two)",
                                two=2)
        cos_new = cos_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        sin_new = sin_new.reshape(1, -1, 1,
                                  self.hidden_size_per_attention_head)
        return cos_new, sin_new

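    # The loaded qkv bias is viewed as [-1, 3, head_dim] blocks; each head_dim
    # block is split into its two rotary halves and zeros are appended to each
    # half separately, so that after padding the halves still line up with the
    # padded cos/sin tables built in cal_cos_sin.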
    def pad_qkv_bias(self, bias):
        first_half = bias.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head
        )[:, :, :self.half_origin_hidden_size_per_attention_head]
        second_half = bias.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head
        )[:, :, self.half_origin_hidden_size_per_attention_head:]
        first_half_padded = torch.nn.functional.pad(
            first_half, (0, self.half_pad_hidden_size_per_attention_head))
        second_half_padded = torch.nn.functional.pad(
            second_half, (0, self.half_pad_hidden_size_per_attention_head))
        bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2)
        bias_final = bias_padded.reshape(-1)
        return bias_final

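    # Same half-splitting scheme applied to the output rows of the qkv weight;
    # the hidden_size input columns are left untouched.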
    def pad_qkv_weight(self, data):
        qkv_weight_first_half = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
        )[:, :, :self.half_origin_hidden_size_per_attention_head, :]
        qkv_weight_second_half = data.reshape(
            -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
        )[:, :, self.half_origin_hidden_size_per_attention_head:, :]

        qkv_weight_first_half_padded = torch.nn.functional.pad(
            qkv_weight_first_half,
            (0, 0, 0, self.half_pad_hidden_size_per_attention_head))
        qkv_weight_second_half_padded = torch.nn.functional.pad(
            qkv_weight_second_half,
            (0, 0, 0, self.half_pad_hidden_size_per_attention_head))
        qkv_weight_padded = torch.cat(
            [qkv_weight_first_half_padded, qkv_weight_second_half_padded],
            dim=2)
        qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)
        return qkv_weight_final

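    # The output projection consumes the padded attention output, so its input
    # dim is padded the same way, per half-head; the new zero columns contribute
    # nothing to the projection result.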
    def pad_proj_weight(self, data):
        out_weight = torch.nn.functional.pad(
            data.reshape(self.hidden_size, -1,
                         self.half_origin_hidden_size_per_attention_head),
            (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
                self.hidden_size, -1)
        return out_weight

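    # Weight loading: q_proj/k_proj/v_proj checkpoint shards are mapped onto a
    # fused qkv_proj parameter via stacked_params_mapping; everything else uses
    # its own (or the default) loader. When padding is enabled, the attention
    # qkv and proj tensors are padded in place right after loading.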
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                if ("attn.proj.weight" in name) and self.enable_pad:
                    param.data = self.pad_proj_weight(param.data)
                if ("attn.qkv.weight" in name) and self.enable_pad:
                    param.data = self.pad_qkv_weight(param.data)
                if ("attn.qkv.bias" in name) and self.enable_pad:
                    param.data = self.pad_qkv_bias(param.data)
            loaded_params.add(name)
        return loaded_params

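    # Full vision-tower forward: patchify, build the rotary tables once, compute
    # per-frame lengths for full attention and per-window lengths for window
    # attention (the NPU FA op takes lengths rather than cumulative offsets),
    # reorder tokens into window order, run the blocks, then merge patches and
    # restore the original token order.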
    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        # per-frame sequence lengths: h * w patches for each of the t frames
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                             grid_thw[:,
                                                      0]).cpu().to(torch.int32)

        # patchify
        x = self.patch_embed(x)

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # window attention
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=x.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
        # convert cumulative offsets into per-window lengths
        cu_window_seqlens = torch.diff(cu_window_seqlens).cpu().to(torch.int32)
        seq_len, _ = x.size()
        x = x.reshape(seq_len // self.spatial_merge_unit,
                      self.spatial_merge_unit, -1)
        x = x[window_index, :, :]
        x = x.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)

        cos, sin = self.cal_cos_sin(rotary_pos_emb)

        # transformer blocks
        x = x.unsqueeze(1)
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(x, cu_seqlens=cu_seqlens_now, cos=cos, sin=sin)

        # patch merger (adapter)
        x = self.merger(x)
        reverse_indices = torch.argsort(window_index)
        x = x[reverse_indices, :]
        return x


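# Registered variant of Qwen2.5-VL that swaps the default visual encoder for
# the Ascend vision transformer above; the multimodal processor, processing
# info and dummy-inputs builder are reused from vLLM unchanged.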
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
    info=Qwen2_5_VLProcessingInfo,
    dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
class AscendQwen2_5_VLForConditionalGeneration(
        Qwen2_5_VLForConditionalGeneration):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.visual = AscendQwen2_5_VisionTransformer(
            vision_config=config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=self._maybe_ignore_quant_config(quant_config),
            prefix=maybe_prefix(prefix, "visual"),
        )