Skip to content

Commit 70022ff

Browse files
xiao-llmXiao YU
andauthored
Granite 4.0 quark quantization support (#26944)
Signed-off-by: Xiao YU <[email protected]> Signed-off-by: Xiao Yu <[email protected]> Co-authored-by: Xiao YU <[email protected]>
1 parent f417746 commit 70022ff

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

vllm/model_executor/models/granitemoehybrid.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
330330
lora_config = vllm_config.lora_config
331331

332332
self.config = config
333+
self.quant_config = quant_config
333334
lora_vocab = (
334335
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
335336
if lora_config
@@ -405,6 +406,33 @@ def forward(
405406
hidden_states = self.norm(hidden_states)
406407
return hidden_states
407408

409+
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
410+
# Params for weights, fp8 weight scales, fp8 activation scales
411+
# (param_name, weight_name, expert_id, shard_id)
412+
# layers.0.block_sparse_moe.expert_0.input_linear.input_scale
413+
ckpt_gate_proj_name = "gate_proj"
414+
ckpt_down_proj_name = "down_proj"
415+
ckpt_up_proj_name = "up_proj"
416+
num_experts = self.config.num_local_experts
417+
418+
return [
419+
# (param_name, weight_name, expert_id, shard_id)
420+
(
421+
"block_sparse_moe.experts.w13_"
422+
if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
423+
else "block_sparse_moe.experts.w2_",
424+
f"block_sparse_moe.experts.{expert_id}.{weight_name}.",
425+
expert_id,
426+
shard_id,
427+
)
428+
for expert_id in range(num_experts)
429+
for shard_id, weight_name in [
430+
("w1", ckpt_gate_proj_name),
431+
("w2", ckpt_down_proj_name),
432+
("w3", ckpt_up_proj_name),
433+
]
434+
]
435+
408436
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
409437
stacked_params_mapping = [
410438
# (param_name, shard_name, shard_id)
@@ -414,6 +442,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
414442
]
415443
params_dict = dict(self.named_parameters())
416444
loaded_params: set[str] = set()
445+
expert_params_mapping = self.get_expert_mapping()
417446

418447
def _load(n, p):
419448
param = params_dict[n]
@@ -435,10 +464,56 @@ def _load_expert(n, p, name, shard_id, expert_id):
435464
weight_loader(param, p, name, shard_id=shard_id, expert_id=expert_id)
436465
loaded_params.add(n)
437466

467+
def _load_quant_expert(name, loaded_weight):
468+
for mapping in expert_params_mapping:
469+
param_name, weight_name, expert_id, shard_id = mapping
470+
471+
if weight_name not in name:
472+
continue
473+
474+
name_mapped = name.replace(weight_name, param_name)
475+
476+
# Skip layers on other devices.
477+
if is_pp_missing_parameter(name_mapped, self):
478+
continue
479+
480+
param = params_dict[name_mapped]
481+
weight_loader = param.weight_loader
482+
success = False
483+
484+
if weight_loader is not None:
485+
success = weight_loader(
486+
param,
487+
loaded_weight,
488+
name_mapped,
489+
shard_id=shard_id,
490+
expert_id=expert_id,
491+
return_success=True,
492+
)
493+
494+
if success:
495+
return name_mapped
496+
return None
497+
438498
for n, p in weights:
439499
if "A_log" in n:
440500
n = n.replace("A_log", "A")
441501

502+
if self.quant_config is not None and (
503+
scale_name := self.quant_config.get_cache_scale(n)
504+
):
505+
# Loading kv cache quantization scales
506+
loaded_weight = p
507+
loaded_weight = (
508+
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
509+
)
510+
_load(scale_name, loaded_weight)
511+
loaded_params.add(scale_name)
512+
continue
513+
514+
if _load_quant_expert(n, p):
515+
continue
516+
442517
# Logic analogous to: https://github.com/vllm-project/vllm/blob/f49e5aff11c986ed4d45202b1716c5d74786efa9/vllm/model_executor/models/granitemoeshared.py#L215
443518
# Mapping different experts' layout:
444519
# from HF (input_linear, output_linear, router)

0 commit comments

Comments
 (0)