diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 77b5251d6d1..131c194cfc9 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -58,7 +58,7 @@ class AscendConfig: def __init__(self, vllm_config): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} - + self.mix_placement = additional_config.get("mix_placement", False) xlite_graph_config = additional_config.get("xlite_graph_config", {}) self.xlite_graph_config = XliteGraphConfig(xlite_graph_config, vllm_config) @@ -233,13 +233,12 @@ def __init__(self, **kwargs): """ Initialize the configuration. - + Args: fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization. When set to True, the system will optimize norm and quant operations. Default: True - fuse_qknorm_rope (bool): Whether to enable qknorm and rope fusion optimization. - Default: False + **kwargs: Additional optional parameters for forward compatibility and configuration extension. """ self.fuse_norm_quant = fuse_norm_quant diff --git a/vllm_ascend/ops/fused_moe/experts_selector.py b/vllm_ascend/ops/fused_moe/experts_selector.py index 05ec0e38491..93f1f78e8ca 100644 --- a/vllm_ascend/ops/fused_moe/experts_selector.py +++ b/vllm_ascend/ops/fused_moe/experts_selector.py @@ -33,6 +33,8 @@ def select_experts(hidden_states: torch.Tensor, routed_scaling_factor=1.0, e_score_correction_bias: Optional[torch.Tensor] = None, indices_type: Optional[torch.dtype] = None, + mix_placement: Optional[bool] = False, + num_logical_experts: int = -1, global_num_experts: int = -1): """ Fused experts with select experts. @@ -95,6 +97,20 @@ def select_experts(hidden_states: torch.Tensor, e_score_correction_bias=e_score_correction_bias, global_num_experts=global_num_experts, ) + if mix_placement: + shared_expert_routing_factor = 0.4 + pad_shared_expert_ids = torch.full((topk_ids.shape[0], 1), + num_logical_experts, + dtype=topk_ids.dtype, + device=topk_ids.device) + + pad_shared_expert_weights = torch.full((topk_weights.shape[0], 1), + shared_expert_routing_factor, + dtype=topk_weights.dtype, + device=topk_weights.device) + topk_ids = torch.cat([topk_ids, pad_shared_expert_ids], dim=1) + topk_weights = torch.cat([topk_weights, pad_shared_expert_weights], + dim=1) return topk_weights, topk_ids diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 2d0e7afca06..38de0b086c8 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -169,13 +169,13 @@ def __init__(self, *args, **kwargs): self.moe_config.dp_group = get_dp_group() self.moe_config.ep_group = get_ep_group() self.moe_config.mc2_group = get_mc2_group() - ascend_config = get_ascend_config() - self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path - self.expert_map_path = ascend_config.expert_map_path - self.global_redundant_expert_num = ascend_config.init_redundancy_expert + self.ascend_config = get_ascend_config() + self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path + self.expert_map_path = self.ascend_config.expert_map_path + self.global_redundant_expert_num = self.ascend_config.init_redundancy_expert self.global_num_experts = num_experts + self.global_redundant_expert_num # flashcommon3 gate stream - self.multistream_overlap_gate = ascend_config.multistream_overlap_gate + self.multistream_overlap_gate = self.ascend_config.multistream_overlap_gate if
self.multistream_overlap_gate and AscendFusedMoE.gate_stream is None: AscendFusedMoE.gate_stream = torch.npu.Stream() if self.custom_routing_function is None and self.e_score_correction_bias is not None: @@ -189,6 +189,8 @@ def __init__(self, *args, **kwargs): # TODO: Temporary flag to indicate if static EPLB is enabled. This is a # workaround to bypass a quantization check that fails with float weights. init_eplb_enable = False + num_experts += 1 if getattr(self.ascend_config, "mix_placement", + False) else 0 # static eplb initializing with expert_map_path if self.expert_map_path and os.path.exists( self.expert_map_path) and os.access(self.expert_map_path, @@ -255,7 +257,7 @@ def __init__(self, *args, **kwargs): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp setup_moe_comm_method(self.moe_config) self.quant_type = self._get_quant_type() @@ -464,9 +466,9 @@ def __init__( self._shared_experts = shared_experts self.use_overlapped = use_overlapped self.shared_expert_stream = None - ascend_config = get_ascend_config() - self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert - self.multistream_overlap_gate = ascend_config.multistream_overlap_gate + self.ascend_config = get_ascend_config() + self.multistream_overlap_shared_expert = self.ascend_config.multistream_overlap_shared_expert + self.multistream_overlap_gate = self.ascend_config.multistream_overlap_gate if enable_sp(): logger.info_once( "Sequence parallelism is enabled, shared experts are replicated for best performance." @@ -494,11 +496,19 @@ def forward( hidden_states: torch.Tensor, router_logits: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - shared_out, fused_out = AscendFusedMoE.forward( - self, - hidden_states=hidden_states, - router_logits=router_logits, - ) + if self._shared_experts is None: + fused_out = AscendFusedMoE.forward( + self, + hidden_states=hidden_states, + router_logits=router_logits, + ) + shared_out = None + else: + shared_out, fused_out = AscendFusedMoE.forward( + self, + hidden_states=hidden_states, + router_logits=router_logits, + ) return shared_out, fused_out def forward_impl(self, hidden_states: torch.Tensor, @@ -514,7 +524,12 @@ def forward_impl(self, hidden_states: torch.Tensor, shared_experts_calculation_stream(), enabled=self.multistream_overlap_shared_expert): # Use a separate stream to run shared experts. - shared_out = self._shared_experts(hidden_states) + # Note that currently we only support calculations in separate streams with aclgraph. + # Communication operations in another stream might cause unknown errors. + if self._shared_experts is None: + shared_out = None + else: + shared_out = self._shared_experts(hidden_states) else: set_flash_common3_context(shared_experts=self._shared_experts) @@ -523,7 +538,6 @@ def forward_impl(self, hidden_states: torch.Tensor, hidden_states=hidden_states, router_logits=router_logits, ) - if not self.multistream_overlap_gate: # Make sure the default stream waits for the shared experts stream to finish. 
if self.multistream_overlap_shared_expert: @@ -534,11 +548,13 @@ def forward_impl(self, hidden_states: torch.Tensor, forward_context = get_forward_context() moe_comm_type = forward_context.moe_comm_type if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \ - and not shared_expert_dp_enabled(): + and not shared_expert_dp_enabled() and shared_out is not None: shared_out = tensor_model_parallel_all_reduce(shared_out) else: fc3_context = get_flash_common3_context() assert fc3_context is not None shared_out = fc3_context.shared_out - - return shared_out, fused_output + if shared_out is None: + return fused_output + else: + return shared_out, fused_output diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 4d3a9daf939..400e8771599 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -285,3 +285,14 @@ # Future Plan: # Remove this patch when vLLM support these operators. # +# ** File: worker/patch_deepseekv3.py ** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. vllm.model_executor.models.deepseek_v2 (DeepseekV2ForCausalLM & DeepseekV2MoE related logic) +# Why: +# The mix placement feature requires modifying how DeepseekV3 shared expert weights are loaded and adjusting the forward path of DeepseekV2MoE. +# How: +# Patch the weight loading logic of DeepseekV2ForCausalLM and DeepSeekMTP to match the mix placement storage format, and replace the forward path of DeepseekV2MoE to support the mix placement feature. +# Related PR (if no, explain why): +# https://github.com/vllm-project/vllm/pull/4881 +# Future Plan: +# Remove this patch after the mix placement feature is natively implemented in the official vllm codebase. diff --git a/vllm_ascend/patch/worker/patch_deepseekv3.py b/vllm_ascend/patch/worker/patch_deepseekv3.py new file mode 100644 index 00000000000..c01b7146bb7 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_deepseekv3.py @@ -0,0 +1,482 @@ +import typing +from collections.abc import Callable, Iterable + +import torch +import vllm +from torch import nn +from transformers import DeepseekV2Config, DeepseekV3Config +from vllm.config import ParallelConfig +from vllm.distributed import (get_ep_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) +from vllm.model_executor.layers.fused_moe.shared_fused_moe import \ + SharedFusedMoE +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP +from vllm.model_executor.models.deepseek_v2 import ( + DeepseekV2ForCausalLM, DeepseekV2MLP, DeepseekV2MoE, + get_spec_layer_idx_from_weight_name) +from vllm.model_executor.models.utils import (is_pp_missing_parameter, + sequence_parallel_chunk) + +from vllm_ascend.ascend_config import get_ascend_config + + +class AscendDeepseekV2MoE(DeepseekV2MoE, nn.Module): + + def __init__( + self, + config: DeepseekV2Config | DeepseekV3Config, + parallel_config: ParallelConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + nn.Module.__init__(self) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() +
self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + self.gate = ReplicatedLinear( + config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts, dtype=torch.float32)) + else: + self.gate.e_score_correction_bias = None + + # Load balancing settings. + eplb_config = parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb + + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + + ascend_config = get_ascend_config() + mix_placement = getattr(ascend_config, "mix_placement", False) + if (config.n_shared_experts is None or mix_placement): + self.shared_experts = None + else: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + is_sequence_parallel=self.is_sequence_parallel, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + # aiter applies routed_scaling_factor internally + routed_scaling_factor=1.0 + if not mix_placement else self.routed_scaling_factor, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # Chunk the hidden states so they aren't replicated across TP ranks. + # This avoids duplicate computation in self.experts. + # TODO: We can replace the all_reduce at the end of attn with a + # reduce_scatter instead of chunking here. + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + router_logits, _ = self.gate(hidden_states) + fused_moe_out = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + shared_output, final_hidden_states = fused_moe_out + if self.shared_experts is None: + assert shared_output is None + + # Fix FP16 overflow + # See DeepseekV2DecoderLayer for more details. 
+ if hidden_states.dtype != torch.float16: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= 1.0 / self.routed_scaling_factor + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_dim) + + +class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + ascend_config = get_ascend_config() + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + ] + + mix_placement = getattr(ascend_config, "mix_placement", False) + expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + + (self.config.n_shared_experts if mix_placement else 0), + num_redundant_experts=self.num_redundant_experts, + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue + + is_fuse_shared_experts_layer = (mix_placement + and ("mlp.shared_experts" in name)) + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if ("mlp.experts." 
in name) and name not in params_dict: + continue + if is_fuse_shared_experts_layer: + continue + name_mapped = name.replace(weight_name, param_name) + + if (param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: + continue + else: + name = name_mapped + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict.keys(): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + num_chunks = 1 + if is_fuse_shared_experts_layer: + num_chunks = getattr(self.config, "n_shared_experts", + 1) or 1 + split_dim = 1 if "down_proj.weight" in name else 0 + total = loaded_weight.shape[split_dim] + assert total % num_chunks == 0, ( + f"Shared expert weight dim {total} not divisible by num_chunks {num_chunks}" + ) + chunk_size = total // num_chunks + + for j in range(num_chunks): + chunk_name = name + weight_to_load = loaded_weight + + if is_fuse_shared_experts_layer: + if split_dim == 0: + weight_to_load = loaded_weight[j * + chunk_size:(j + 1) * + chunk_size, :] + else: + weight_to_load = loaded_weight[:, j * + chunk_size:(j + 1) * + chunk_size] + chunk_name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts + j}", + ) + + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in chunk_name: + continue + + is_expert_weight = True + name_mapped = chunk_name.replace( + weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + if name_mapped not in params_dict.keys(): + continue + param = params_dict[name_mapped] + weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader( + param, + weight_to_load, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + if not is_fuse_shared_experts_layer: + name = name_mapped + else: + loaded_params.add(name_mapped) + break + else: + if is_expert_weight: + continue + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict.keys(): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + if not is_fuse_shared_experts_layer: + loaded_params.add(name) + return loaded_params + + +class CustomDeepSeekMTP(DeepSeekMTP): + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + ascend_config = get_ascend_config() + mix_placement = getattr(ascend_config, "mix_placement", False) + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + ] + + expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + + (self.config.n_shared_experts if mix_placement else 0), + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + spec_layer = 
get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is None: + continue + is_fusion_moe_shared_experts_layer = (mix_placement + and ("mlp.shared_experts" + in name)) + name = self._rewrite_spec_layer_name(spec_layer, name) + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + if is_fusion_moe_shared_experts_layer: + continue + name_mapped = name.replace(weight_name, param_name) + + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + if (param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: + continue + else: + name = name_mapped + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Special handling: when AITER fusion_shared_experts is enabled, + # checkpoints may provide a single widened shared_experts tensor + # without explicit expert indices + # (e.g. ...mlp.shared_experts.gate_proj.weight). + # For models with multiple shared experts, split that tensor + # evenly into per-shared-expert slices and load them into + # appended expert slots mlp.experts.{n_routed_experts + j}.* + # accordingly. + num_chunks = 1 + if is_fusion_moe_shared_experts_layer: + num_chunks = getattr(self.config, "n_shared_experts", + 1) or 1 + # Determine split axis based on op type + # gate/up: ColumnParallel → split along dim 0 + # down: RowParallel → split along dim 1 + split_dim = 1 if "down_proj.weight" in name else 0 + total = loaded_weight.shape[split_dim] + assert total % num_chunks == 0, ( + f"Shared expert weight dim {total} " + f"not divisible by num_chunks {num_chunks}") + chunk_size = total // num_chunks + + for j in range(num_chunks): + chunk_name = name + weight_to_load = loaded_weight + + if is_fusion_moe_shared_experts_layer: + if split_dim == 0: + weight_to_load = loaded_weight[j * + chunk_size:(j + 1) * + chunk_size, :] + else: + weight_to_load = loaded_weight[:, j * + chunk_size:(j + 1) * + chunk_size] + # Synthesize an expert-style name so expert mapping + # can route it + chunk_name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts + j}", + ) + + # Use expert_params_mapping to locate the destination + # param and delegate to its expert-aware weight_loader + # with expert_id. 
+ is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in chunk_name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = chunk_name.replace( + weight_name, param_name) + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or + # not here since otherwise we may skip experts with + # other available replicas. + weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader( + param, + weight_to_load, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + if not is_fusion_moe_shared_experts_layer: + name = name_mapped + else: + loaded_params.add(name_mapped) + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + # According to DeepSeek-V3 Technical Report, MTP modules + # shares embedding layer. We only load the first weights. + if (spec_layer != self.model.mtp_start_layer_idx + and ".layers" not in name): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + if not is_fusion_moe_shared_experts_layer: + loaded_params.add(name) + return loaded_params + + +vllm.model_executor.models.deepseek_v2.DeepseekV2MoE = AscendDeepseekV2MoE +DeepseekV2ForCausalLM.load_weights = CustomDeepseekV2ForCausalLM.load_weights +DeepSeekMTP.load_weights = CustomDeepSeekMTP.load_weights diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 8952d3cf5fe..9cd4b35dfa7 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -198,8 +198,10 @@ def apply( pertoken_scale: Optional[Any] = None, **kwargs, ) -> torch.Tensor: + mix_placement = getattr(layer.ascend_config, "mix_placement", False) + n_shared_experts = 1 if mix_placement else 0 assert router_logits.shape[ - 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" + 1] == global_num_experts - global_redundant_expert_num - n_shared_experts, "Number of global experts mismatch (excluding redundancy)" if self.multistream_overlap_gate: fc3_context = get_flash_common3_context() @@ -218,6 +220,9 @@ def apply( custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, + mix_placement=getattr(layer.ascend_config, "mix_placement", + False), + num_logical_experts=router_logits.shape[1], global_num_experts=global_num_experts) assert topk_ids is not None assert topk_weights is not None
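For context, the core mechanism of mix placement is the routing pad added to select_experts() in experts_selector.py: the shared expert is stored as one extra logical expert appended after the routed experts, and every token is additionally routed to it with a fixed weight of 0.4 (0.4 equals 1 / 2.5, DeepSeek-V3's routed_scaling_factor, which may be why it was chosen, though the patch does not say so). The standalone sketch below reproduces only that padding step with a toy top-k router; the helper name pad_shared_expert and the shapes are illustrative assumptions, not code from this PR.

import torch

# Minimal sketch of the mix_placement routing pad (assumption: a plain
# top-k over softmaxed router logits stands in for the real grouped-topk
# path used by select_experts).
def pad_shared_expert(topk_ids: torch.Tensor,
                      topk_weights: torch.Tensor,
                      num_logical_experts: int,
                      shared_expert_routing_factor: float = 0.4):
    """Append one always-selected column routing every token to the shared
    expert, which mix placement stores as logical expert id
    `num_logical_experts` (i.e. right after the routed experts)."""
    num_tokens = topk_ids.shape[0]
    pad_ids = torch.full((num_tokens, 1), num_logical_experts,
                         dtype=topk_ids.dtype, device=topk_ids.device)
    pad_weights = torch.full((num_tokens, 1), shared_expert_routing_factor,
                             dtype=topk_weights.dtype,
                             device=topk_weights.device)
    return (torch.cat([topk_ids, pad_ids], dim=1),
            torch.cat([topk_weights, pad_weights], dim=1))

if __name__ == "__main__":
    num_routed, top_k, num_tokens = 8, 2, 4
    router_logits = torch.randn(num_tokens, num_routed)
    topk_weights, topk_ids = torch.topk(router_logits.softmax(dim=-1), top_k)
    topk_ids, topk_weights = pad_shared_expert(topk_ids, topk_weights,
                                               num_logical_experts=num_routed)
    # Each token now carries its top_k routed experts plus the shared expert.
    assert topk_ids.shape == (num_tokens, top_k + 1)
    assert (topk_ids[:, -1] == num_routed).all()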