@@ -330,6 +330,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         lora_config = vllm_config.lora_config

         self.config = config
+        self.quant_config = quant_config
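+        # Kept so load_weights() can look up kv-cache quantization scales
+        # via quant_config.get_cache_scale().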
         lora_vocab = (
             (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
             if lora_config
@@ -405,6 +406,33 @@ def forward(
         hidden_states = self.norm(hidden_states)
         return hidden_states

+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        # e.g. layers.0.block_sparse_moe.expert_0.input_linear.input_scale
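+        # For illustration (assuming num_local_experts >= 1), the first
+        # entries of the returned mapping look roughly like:
+        #   ("block_sparse_moe.experts.w13_",
+        #    "block_sparse_moe.experts.0.gate_proj.", 0, "w1"),
+        #   ("block_sparse_moe.experts.w2_",
+        #    "block_sparse_moe.experts.0.down_proj.", 0, "w2"),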
+        ckpt_gate_proj_name = "gate_proj"
+        ckpt_down_proj_name = "down_proj"
+        ckpt_up_proj_name = "up_proj"
+        num_experts = self.config.num_local_experts
+
+        return [
+            # (param_name, weight_name, expert_id, shard_id)
+            (
+                "block_sparse_moe.experts.w13_"
+                if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
+                else "block_sparse_moe.experts.w2_",
+                f"block_sparse_moe.experts.{expert_id}.{weight_name}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id, weight_name in [
+                ("w1", ckpt_gate_proj_name),
+                ("w2", ckpt_down_proj_name),
+                ("w3", ckpt_up_proj_name),
+            ]
+        ]
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -414,6 +442,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()

         def _load(n, p):
             param = params_dict[n]
@@ -435,10 +464,56 @@ def _load_expert(n, p, name, shard_id, expert_id):
             weight_loader(param, p, name, shard_id=shard_id, expert_id=expert_id)
             loaded_params.add(n)

+        def _load_quant_expert(name, loaded_weight):
+            for mapping in expert_params_mapping:
+                param_name, weight_name, expert_id, shard_id = mapping
+
+                if weight_name not in name:
+                    continue
+
+                name_mapped = name.replace(weight_name, param_name)
+
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name_mapped, self):
+                    continue
+
+                param = params_dict[name_mapped]
+                weight_loader = param.weight_loader
+                success = False
+
+                if weight_loader is not None:
+                    success = weight_loader(
+                        param,
+                        loaded_weight,
+                        name_mapped,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                        return_success=True,
+                    )
+
+                if success:
+                    return name_mapped
+            return None
+
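+        # Rough loading order for each checkpoint tensor: rename "A_log",
+        # then try kv-cache quantization scales, then quantized expert
+        # weights via expert_params_mapping, and finally fall through to
+        # the HF-layout expert remapping below.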
         for n, p in weights:
             if "A_log" in n:
                 n = n.replace("A_log", "A")

+            if self.quant_config is not None and (
+                scale_name := self.quant_config.get_cache_scale(n)
+            ):
+                # Loading kv cache quantization scales
+                loaded_weight = p
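+                # The scale may be stored as a 0-d tensor or (typically) a
+                # one-element tensor; reduce it to a scalar before loading.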
+                loaded_weight = (
+                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                )
+                _load(scale_name, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
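+            # Expert weights from quantized checkpoints (e.g. fp8 weight or
+            # activation scales) are matched against expert_params_mapping;
+            # anything handled there skips the remapping logic below.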
+            if _load_quant_expert(n, p):
+                continue
+
             # Logic analogous to: https://github.com/vllm-project/vllm/blob/f49e5aff11c986ed4d45202b1716c5d74786efa9/vllm/model_executor/models/granitemoeshared.py#L215
             # Mapping different experts' layout:
             # from HF (input_linear, output_linear, router)