@@ -42,6 +42,7 @@ def __init__(self, base_layer: FusedMoE) -> None:
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
         self.device = base_layer.w2_weight.device
+        self.w13_slices = 2
         self._inject_lora_into_fused_moe()
 
     def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
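Context for the new `self.w13_slices = 2` attribute (a sketch, not part of the diff): the fused w13 weight packs the gate projection (w1) together with the up projection (w3), so the LoRA path always treats it as two slices. The sizes below are hypothetical and only illustrate the relationship this constant encodes.

```python
# Minimal sketch (hypothetical sizes): the fused w13 weight stacks
# gate_proj (w1) and up_proj (w3), hence exactly two LoRA "slices".
import torch

intermediate_size, hidden_size = 4, 8
w1 = torch.randn(intermediate_size, hidden_size)  # gate_proj
w3 = torch.randn(intermediate_size, hidden_size)  # up_proj
w13 = torch.cat([w1, w3], dim=0)                  # fused w13 weight

w13_slices = 2
assert w13.shape[0] == w13_slices * intermediate_size
```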
@@ -60,32 +61,34 @@ def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None
     def _get_lora_moe_configs(
         self,
         op_prefix: str,
-        lora_a_stacked: torch.Tensor,
-        lora_b_stacked: torch.Tensor,
+        num_loras: int,
+        rank: int,
         num_slices: int,
         M: int,
         layer: FusedMoE,
         top_k: int,
         config_dtype: str,
     ):
         if envs.VLLM_TUNED_CONFIG_FOLDER:
+            hidden_size = layer.hidden_size
+            intermediate_size = layer.intermediate_size_per_partition
             shrink_config = get_lora_op_configs(
                 op_type=f"fused_moe_lora_{op_prefix}_shrink",
-                max_loras=lora_a_stacked.shape[0],
+                max_loras=num_loras,
                 batch=M,
-                hidden_size=lora_a_stacked.shape[-1],
-                rank=lora_a_stacked.shape[-2],
+                hidden_size=hidden_size,
+                rank=rank,
                 num_slices=num_slices,
-                moe_intermediate_size=lora_b_stacked.shape[-2],
+                moe_intermediate_size=intermediate_size,
             )
             expand_config = get_lora_op_configs(
                 op_type=f"fused_moe_lora_{op_prefix}_expand",
-                max_loras=lora_a_stacked.shape[0],
+                max_loras=num_loras,
                 batch=M,
-                hidden_size=lora_a_stacked.shape[-1],
-                rank=lora_a_stacked.shape[-2],
+                hidden_size=hidden_size,  # lora_a_stacked.shape[-1]
+                rank=rank,
                 num_slices=num_slices,
-                moe_intermediate_size=lora_b_stacked.shape[-2],
+                moe_intermediate_size=intermediate_size,  # lora_b_stacked.shape[-2]
             )
         else:  # fall back to the default config
             get_config_func = functools.partial(
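What this hunk changes, with a hedged sketch: the tuned-config lookup is now keyed by values taken from the LoRA config and the layer (`max_loras`, `max_lora_rank`, `hidden_size`, `intermediate_size_per_partition`) instead of being read back off the stacked LoRA tensors. For the w13 case in the non-fully-sharded setting the two are equivalent, because of how the stacked tensors are laid out; the shapes below are assumed for illustration only.

```python
# Sketch of the shape equivalence (hypothetical sizes; non-fully-sharded case).
# Stacked LoRA layout, matching the allocations in create_lora_weights:
#   lora_a_stacked: (max_loras, num_experts, rank, hidden_size)
#   lora_b_stacked: (max_loras, num_experts, intermediate_size, rank)
import torch

max_loras, num_experts, rank, hidden_size, intermediate_size = 2, 4, 8, 32, 16
lora_a_stacked = torch.zeros(max_loras, num_experts, rank, hidden_size)
lora_b_stacked = torch.zeros(max_loras, num_experts, intermediate_size, rank)

assert lora_a_stacked.shape[0] == max_loras           # old max_loras key
assert lora_a_stacked.shape[-1] == hidden_size        # old hidden_size key
assert lora_a_stacked.shape[-2] == rank               # old rank key
assert lora_b_stacked.shape[-2] == intermediate_size  # old moe_intermediate_size key
```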
@@ -155,17 +158,16 @@ def wrapper(*args, **kwargs):
 
                 shrink_config, expand_config = self._get_lora_moe_configs(
                     op_prefix="w13",
-                    lora_a_stacked=self.w1_lora_a_stacked,
-                    lora_b_stacked=self.w1_lora_b_stacked,
-                    num_slices=2,
+                    num_loras=self.max_loras,
+                    rank=self.max_lora_rank,
+                    num_slices=self.w13_slices,
                     M=M,
                     layer=layer,
                     top_k=top_k,
                     config_dtype=config_dtype,
                 )
 
                 # get the block size of m from customized config or default config
-                max_loras = self.w1_lora_a_stacked.shape[0]
                 (
                     sorted_token_ids_lora,
                     expert_ids_lora,
@@ -175,7 +177,7 @@ def wrapper(*args, **kwargs):
                     num_tokens,
                     shrink_config["BLOCK_SIZE_M"],
                     self.base_layer.local_num_experts,
-                    max_loras,
+                    self.max_loras,
                     self.adapter_enabled,
                     expert_map,
                 )
@@ -186,22 +188,19 @@ def wrapper(*args, **kwargs):
                     num_tokens_post_padded_lora
                 )
 
-                w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
-                w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
-                max_lora_rank = self.w1_lora_a_stacked.shape[-2]
-                expert_ids_lora = expert_ids_lora.view(max_loras, -1)
-                sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+                expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+                sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
 
                 self.punica_wrapper.add_lora_fused_moe(
                     input.view(-1, top_k, input.shape[-1]),
                     hidden_states,
-                    w13_lora_a_stacked,
-                    w13_lora_b_stacked,
+                    self.w13_lora_a_stacked,
+                    self.w13_lora_b_stacked,
                     topk_weights,
                     sorted_token_ids_lora,
                     expert_ids_lora,
                     num_tokens_post_padded_lora,
-                    max_lora_rank,
+                    self.max_lora_rank,
                     top_k,
                     shrink_config,  ## pass the shrink config
                     expand_config,  ## pass the expand config
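The call above now passes `self.w13_lora_a_stacked` / `self.w13_lora_b_stacked` directly: the per-projection attributes (`w1_*`, `w3_*`) are replaced by a tuple of per-slice tensors built once at weight-creation time, so nothing is re-packed into a list on every forward pass. A minimal sketch of that data-structure change, with hypothetical tensors:

```python
# Before: separate attributes, packed into a list at the call site each time.
# After:  one tuple, slice 0 <-> w1 (gate_proj), slice 1 <-> w3 (up_proj).
import torch

w1_lora_a = torch.zeros(2, 4, 8, 32)  # (max_loras, experts, rank, hidden)
w3_lora_a = torch.zeros(2, 4, 8, 32)

w13_lora_a_stacked_old = [w1_lora_a, w3_lora_a]  # rebuilt every call
w13_lora_a_stacked_new = (w1_lora_a, w3_lora_a)  # built once, reused

assert w13_lora_a_stacked_new[0] is w1_lora_a
assert w13_lora_a_stacked_new[1] is w3_lora_a
```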
@@ -233,8 +232,8 @@ def wrapper(*args, **kwargs):
 
                 shrink_config, expand_config = self._get_lora_moe_configs(
                     op_prefix="w2",
-                    lora_a_stacked=self.w2_lora_a_stacked,
-                    lora_b_stacked=self.w2_lora_b_stacked,
+                    num_loras=self.max_loras,
+                    rank=self.max_lora_rank,
                     num_slices=1,
                     M=M,
                     layer=layer,
@@ -247,25 +246,24 @@ def wrapper(*args, **kwargs):
                 num_tokens_post_padded_lora = moe_state_dict[
                     "num_tokens_post_padded_lora"
                 ]
-                max_loras = self.w1_lora_a_stacked.shape[0]
-                expert_ids_lora = expert_ids_lora.view(max_loras, -1)
-                sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+
+                expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+                sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
                 intermediate_cache2 = moe_state_dict["intermediate_cache2"]
                 intermediate_cache3 = args[0]
-                max_lora_rank = self.w2_lora_a_stacked.shape[-2]
 
                 shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
 
                 self.punica_wrapper.add_lora_fused_moe(
                     intermediate_cache3,
                     intermediate_cache2,
-                    [self.w2_lora_a_stacked],
-                    [self.w2_lora_b_stacked],
+                    (self.w2_lora_a_stacked,),
+                    (self.w2_lora_b_stacked,),
                     topk_weights,
                     sorted_token_ids_lora,
                     expert_ids_lora,
                     num_tokens_post_padded_lora,
-                    max_lora_rank,
+                    self.max_lora_rank,
                     top_k,
                     shrink_config,  ## pass the shrink config
                     expand_config,  ## pass the expand config
@@ -289,7 +287,6 @@ def wrapper(*args, **kwargs):
         fused_experts.moe_sum = moe_sum_decorator(
             self.base_layer, fused_experts.moe_sum
         )
-
         self.base_layer.quant_method = FusedMoEModularMethod(
             self.base_layer.quant_method, m_fused_moe_fn
         )
@@ -301,40 +298,50 @@ def create_lora_weights(
         model_config: PretrainedConfig | None = None,
     ) -> None:
         """Initializes lora matrices."""
+        assert self.w13_slices == 2
+        self.max_loras = lora_config.max_loras
+        self.max_lora_rank = lora_config.max_lora_rank
         self.fully_sharded = lora_config.fully_sharded_loras
 
         self.adapter_enabled = torch.tensor(
             [0] * (max_loras + 1), dtype=torch.int, device=self.device
         )
 
-        self.w1_lora_a_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                lora_config.max_lora_rank
-                if not self.fully_sharded
-                else divide(lora_config.max_lora_rank, self.tp_size),
-                self.base_layer.hidden_size,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
+        self.w13_lora_a_stacked = tuple(
+            torch.zeros(
+                (
+                    max_loras,
+                    self.base_layer.local_num_experts,
+                    self.max_lora_rank
+                    if not self.fully_sharded
+                    else divide(self.max_lora_rank, self.tp_size),
+                    self.base_layer.hidden_size,
+                ),
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+            for _ in range(self.w13_slices)
         )
-        self.w1_lora_b_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                self.base_layer.intermediate_size_per_partition,
-                lora_config.max_lora_rank,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
+
+        self.w13_lora_b_stacked = tuple(
+            torch.zeros(
+                (
+                    max_loras,
+                    self.base_layer.local_num_experts,
+                    self.base_layer.intermediate_size_per_partition,
+                    self.max_lora_rank,
+                ),
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+            for _ in range(self.w13_slices)
         )
 
         self.w2_lora_a_stacked = torch.zeros(
             (
                 max_loras,
                 self.base_layer.local_num_experts,
-                lora_config.max_lora_rank,
+                self.max_lora_rank,
                 self.base_layer.intermediate_size_per_partition,
             ),
             dtype=lora_config.lora_dtype,
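For readability, here is an explicit spelling of the tuple-comprehension allocation introduced above. It is a sketch with made-up sizes, assuming `w13_slices == 2` and no fully-sharded split of the rank dimension:

```python
# Equivalent explicit form of the w13_lora_a_stacked allocation
# (hypothetical sizes; non-fully-sharded rank).
import torch

max_loras, num_experts, max_lora_rank = 2, 4, 8
hidden_size = 32
w13_slices = 2

w13_lora_a_stacked = tuple(
    torch.zeros(max_loras, num_experts, max_lora_rank, hidden_size)
    for _ in range(w13_slices)
)

assert len(w13_lora_a_stacked) == w13_slices
assert all(
    t.shape == (max_loras, num_experts, max_lora_rank, hidden_size)
    for t in w13_lora_a_stacked
)
```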
@@ -347,30 +354,7 @@ def create_lora_weights(
                 self.base_layer.hidden_size
                 if not self.fully_sharded
                 else divide(self.base_layer.hidden_size, self.tp_size),
-                lora_config.max_lora_rank,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-
-        self.w3_lora_a_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                lora_config.max_lora_rank
-                if not self.fully_sharded
-                else divide(lora_config.max_lora_rank, self.tp_size),
-                self.base_layer.hidden_size,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-        self.w3_lora_b_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                self.base_layer.intermediate_size_per_partition,
-                lora_config.max_lora_rank,
+                self.max_lora_rank,
             ),
             dtype=lora_config.lora_dtype,
             device=self.device,
@@ -383,20 +367,28 @@ def create_lora_weights(
         for lora_id in range(max_loras):
             for experts_id in range(self.base_layer.local_num_experts):
                 # gate_proj,down_proj,up_proj
-                self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
+                self.lora_a_stacked.append(
+                    self.w13_lora_a_stacked[0][lora_id][experts_id]
+                )
                 self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
-                self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id])
+                self.lora_a_stacked.append(
+                    self.w13_lora_a_stacked[1][lora_id][experts_id]
+                )
 
-                self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id])
+                self.lora_b_stacked.append(
+                    self.w13_lora_b_stacked[0][lora_id][experts_id]
+                )
                 self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
-                self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id])
+                self.lora_b_stacked.append(
+                    self.w13_lora_b_stacked[1][lora_id][experts_id]
+                )
 
     def reset_lora(self, index: int):
         """Resets the lora weights at index back to 0."""
-        self.w1_lora_a_stacked[index] = 0
-        self.w1_lora_b_stacked[index] = 0
-        self.w3_lora_a_stacked[index] = 0
-        self.w3_lora_b_stacked[index] = 0
+        for pos in range(self.w13_slices):
+            self.w13_lora_a_stacked[pos][index] = 0
+            self.w13_lora_b_stacked[pos][index] = 0
+
         self.w2_lora_a_stacked[index] = 0
         self.w2_lora_b_stacked[index] = 0
         self.adapter_enabled[index] = 0
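The flattening loop above appends, for every `(lora_id, experts_id)` pair, three entries in gate/down/up order (w13 slice 0, then w2, then w13 slice 1). A small sketch of the resulting flat index layout, using a hypothetical helper:

```python
# Sketch of the flat lora_a_stacked / lora_b_stacked layout produced above.
# proj: 0 = gate_proj (w13 slice 0), 1 = down_proj (w2), 2 = up_proj (w13 slice 1).
max_loras, num_experts = 2, 4

def flat_index(lora_id: int, expert_id: int, proj: int) -> int:
    return (lora_id * num_experts + expert_id) * 3 + proj

assert flat_index(0, 0, 0) == 0                                # first append
assert flat_index(1, 3, 2) == 3 * max_loras * num_experts - 1  # last append
```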
@@ -444,29 +436,32 @@ def set_lora(
             w2_start_idx = self.tp_rank * w2_shard_size
             w2_end_idx = (self.tp_rank + 1) * w2_shard_size
             w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
-
-            self.w1_lora_a_stacked[
+            # w1 lora_a
+            self.w13_lora_a_stacked[0][
                 index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
             ].copy_(w1_lora_a, non_blocking=True)
-
-            self.w3_lora_a_stacked[
+            # w3 lora_a
+            self.w13_lora_a_stacked[1][
                 index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
             ].copy_(w3_lora_a, non_blocking=True)
 
-            self.w2_lora_b_stacked[
-                index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
-            ].copy_(w2_lora_b, non_blocking=True)
-
-            self.w1_lora_b_stacked[
+            # w1 lora_b
+            self.w13_lora_b_stacked[0][
                 index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
             ].copy_(w1_lora_b, non_blocking=True)
-            self.w3_lora_b_stacked[
+            # w3 lora_b
+            self.w13_lora_b_stacked[1][
                 index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
             ].copy_(w3_lora_b, non_blocking=True)
+
             self.w2_lora_a_stacked[
                 index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
             ].copy_(w2_lora_a, non_blocking=True)
 
+            self.w2_lora_b_stacked[
+                index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
+            ].copy_(w2_lora_b, non_blocking=True)
+
     @classmethod
     def can_replace_layer(
         cls,
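The context lines at the top of the last hunk shard `w2_lora_b` along its output rows across tensor-parallel ranks, `hidden_size // tp_size` rows per rank (the `shard_size_w2` computed earlier in the diff). A hedged sketch of that slicing, with hypothetical sizes:

```python
# Sketch of the w2 lora_b row sharding across TP ranks (hypothetical sizes).
import torch

hidden_size, rank, tp_size = 16, 4, 2
w2_lora_b = torch.randn(hidden_size, rank)

w2_shard_size = hidden_size // tp_size  # divide(hidden_size, tp_size)
shards = [
    w2_lora_b[tp_rank * w2_shard_size : (tp_rank + 1) * w2_shard_size, :]
    for tp_rank in range(tp_size)
]

# Concatenating the per-rank shards recovers the full tensor.
assert torch.equal(torch.cat(shards, dim=0), w2_lora_b)
```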