@@ -155,11 +155,11 @@ def wrapper(*args, **kwargs):
             CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
             num_tokens = hidden_states.size(0)
             M = min(num_tokens, CHUNK_SIZE)
-
+            max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
             shrink_config, expand_config = self._get_lora_moe_configs(
                 op_prefix="w13",
                 num_loras=self.max_loras,
-                rank=self.max_lora_rank,
+                rank=max_lora_rank,
                 num_slices=self.w13_slices,
                 M=M,
                 layer=layer,
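
Both call sites in this hunk now derive the rank from the allocated LoRA buffer rather than from the stored config value. A minimal sketch of the idea follows; the shapes and numbers are illustrative assumptions mirroring the layout in the `create_lora_weights` hunks further down, not vLLM's exact code:

```python
import torch

# Assumed layout of the stacked LoRA-A buffers: [max_loras, num_experts, rank,
# hidden_size], with the rank dim divided by tp_size under fully-sharded LoRA.
max_loras, num_experts, rank, hidden_size, tp_size = 4, 8, 16, 1024, 2
fully_sharded = True

alloc_rank = rank // tp_size if fully_sharded else rank
w13_lora_a_stacked = tuple(
    torch.zeros(max_loras, num_experts, alloc_rank, hidden_size)
    for _ in range(2)  # two slices: w1 and w3
)

# What the patch does: read the rank off the buffer instead of
# lora_config.max_lora_rank, so downstream kernels see the rank that was
# actually allocated (possibly sharded).
max_lora_rank = w13_lora_a_stacked[0].shape[-2]
assert max_lora_rank == alloc_rank  # 8 here, i.e. 16 // 2
```
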
@@ -190,6 +190,7 @@ def wrapper(*args, **kwargs):
 
             expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
             sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
+            #
 
             self.punica_wrapper.add_lora_fused_moe(
                 input.view(-1, top_k, input.shape[-1]),
@@ -200,7 +201,7 @@ def wrapper(*args, **kwargs):
                 sorted_token_ids_lora,
                 expert_ids_lora,
                 num_tokens_post_padded_lora,
-                self.max_lora_rank,
+                max_lora_rank,
                 top_k,
                 shrink_config,  ## pass the shrink config
                 expand_config,  ## pass the expand config
@@ -229,11 +230,11 @@ def wrapper(*args, **kwargs):
             CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
             num_tokens = hidden_states.size(0)
             M = min(num_tokens, CHUNK_SIZE)
-
+            max_lora_rank = self.w2_lora_a_stacked.shape[-2]
             shrink_config, expand_config = self._get_lora_moe_configs(
                 op_prefix="w2",
                 num_loras=self.max_loras,
-                rank=self.max_lora_rank,
+                rank=max_lora_rank,
                 num_slices=1,
                 M=M,
                 layer=layer,
@@ -263,7 +264,7 @@ def wrapper(*args, **kwargs):
                 sorted_token_ids_lora,
                 expert_ids_lora,
                 num_tokens_post_padded_lora,
-                self.max_lora_rank,
+                max_lora_rank,
                 top_k,
                 shrink_config,  ## pass the shrink config
                 expand_config,  ## pass the expand config
@@ -300,7 +301,6 @@ def create_lora_weights(
         """Initializes lora matrices."""
         assert self.w13_slices == 2
         self.max_loras = lora_config.max_loras
-        self.max_lora_rank = lora_config.max_lora_rank
         self.fully_sharded = lora_config.fully_sharded_loras
 
         self.adapter_enabled = torch.tensor(
@@ -312,9 +312,9 @@ def create_lora_weights(
             (
                 max_loras,
                 self.base_layer.local_num_experts,
-                self.max_lora_rank
+                lora_config.max_lora_rank
                 if not self.fully_sharded
-                else divide(self.max_lora_rank, self.tp_size),
+                else divide(lora_config.max_lora_rank, self.tp_size),
                 self.base_layer.hidden_size,
             ),
             dtype=lora_config.lora_dtype,
@@ -329,7 +329,7 @@ def create_lora_weights(
                 max_loras,
                 self.base_layer.local_num_experts,
                 self.base_layer.intermediate_size_per_partition,
-                self.max_lora_rank,
+                lora_config.max_lora_rank,
             ),
             dtype=lora_config.lora_dtype,
             device=self.device,
@@ -341,7 +341,7 @@ def create_lora_weights(
             (
                 max_loras,
                 self.base_layer.local_num_experts,
-                self.max_lora_rank,
+                lora_config.max_lora_rank,
                 self.base_layer.intermediate_size_per_partition,
             ),
             dtype=lora_config.lora_dtype,
@@ -354,7 +354,7 @@ def create_lora_weights(
                 self.base_layer.hidden_size
                 if not self.fully_sharded
                 else divide(self.base_layer.hidden_size, self.tp_size),
-                self.max_lora_rank,
+                lora_config.max_lora_rank,
             ),
             dtype=lora_config.lora_dtype,
             device=self.device,
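
Taken together, the four allocation hunks above determine the shapes of the stacked LoRA buffers. The sketch below summarizes them; the buffer names for the last three hunks are not visible in this diff, so the dictionary keys are assumptions, and `divide` here is a plain stand-in for the helper of the same name:

```python
def divide(numerator: int, denominator: int) -> int:
    # Stand-in: exact integer division.
    assert numerator % denominator == 0
    return numerator // denominator

def lora_moe_buffer_shapes(max_loras, num_experts, max_lora_rank, hidden_size,
                           intermediate_size_per_partition, tp_size,
                           fully_sharded):
    # Per the diff, only the w13 A rank dim and the w2 B output dim are
    # divided by tp_size when fully_sharded_loras is enabled.
    a_rank = divide(max_lora_rank, tp_size) if fully_sharded else max_lora_rank
    b_out = divide(hidden_size, tp_size) if fully_sharded else hidden_size
    return {
        # per slice (w1, w3): hidden_size -> rank
        "w13_lora_a": (max_loras, num_experts, a_rank, hidden_size),
        # per slice (w1, w3): rank -> intermediate_size_per_partition
        "w13_lora_b": (max_loras, num_experts,
                       intermediate_size_per_partition, max_lora_rank),
        # w2: intermediate_size_per_partition -> rank
        "w2_lora_a": (max_loras, num_experts,
                      max_lora_rank, intermediate_size_per_partition),
        # w2: rank -> hidden_size (sharded along hidden_size when fully sharded)
        "w2_lora_b": (max_loras, num_experts, b_out, max_lora_rank),
    }
```
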
@@ -426,7 +426,7 @@ def set_lora(
         if self.fully_sharded:
             # Based on S-LoRA, we slice W1 and W3 A along the rank dim,
             # and W2 B along the hidden_size dim.
-            w13_shard_size = self.w1_lora_a_stacked[index, eid].shape[0]
+            w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0]
             w13_start_idx = self.tp_rank * w13_shard_size
             w13_end_idx = (self.tp_rank + 1) * w13_shard_size
             w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
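
The final hunk also corrects which buffer supplies the shard size: the earlier hunks allocate the A weights as the tuple `self.w13_lora_a_stacked`, so the old reference to `self.w1_lora_a_stacked` becomes `self.w13_lora_a_stacked[0]`. A standalone sketch of the S-LoRA-style slicing, under assumed shapes and illustrative names:

```python
import torch

rank, hidden_size, tp_size, tp_rank = 16, 1024, 4, 2

# Full-rank adapter weight for w1 as loaded: (rank, hidden_size).
w1_lora_a = torch.randn(rank, hidden_size)

# Stand-in for self.w13_lora_a_stacked[0][index, eid]: its first dim is the
# already-sharded rank (rank // tp_size under fully-sharded LoRA), so its
# shape gives the per-rank shard size directly.
stacked_slice = torch.zeros(rank // tp_size, hidden_size)

w13_shard_size = stacked_slice.shape[0]        # 4 == 16 // 4
w13_start_idx = tp_rank * w13_shard_size       # 8
w13_end_idx = (tp_rank + 1) * w13_shard_size   # 12
w1_lora_a_shard = w1_lora_a[w13_start_idx:w13_end_idx, :]
assert w1_lora_a_shard.shape == (w13_shard_size, hidden_size)
```
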