Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/ut/ops/test_moe_comm_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ def test_fused_experts_method(self, mock_unified_apply_mlp,
w2 = w2.contiguous()

result = comm_impl.fused_experts(hidden_states=hidden_states,
w1=w1,
w2=w2,
w1=[w1],
w2=[w2],
topk_weights=topk_weights,
topk_ids=topk_ids,
activation="silu")
Expand Down
47 changes: 37 additions & 10 deletions vllm_ascend/eplb/adaptor/vllm_adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,22 @@ def __init__(self, model, **args):
self.init_redundancy_expert = get_ascend_config(
).init_redundancy_expert

for i in range(self.num_dense_layers,
self.model.config.num_hidden_layers):
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_list"] = \
self.model.model.layers[i].mlp.experts.w13_weight_list
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_list"] = \
self.model.model.layers[i].mlp.experts.w2_weight_list
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w13_weight_scale_fp32_list"] = \
self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \
self.model.model.layers[i].mlp.experts.w2_weight_scale_list
# TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
if self.model.quant_config is not None:
self.expert_weight_names = [
"w13_weight", "w2_weight", "w13_weight_scale",
"w13_weight_offset", "w2_weight_scale", "w2_weight_offset"
"w13_weight_list", "w2_weight_list",
"w13_weight_scale_fp32_list", "w13_weight_offset",
"w2_weight_scale_list", "w2_weight_offset"
]
else:
self.expert_weight_names = ["w13_weight", "w2_weight"]
Expand Down Expand Up @@ -84,9 +95,14 @@ def init_buffer_tensor(self, num_buffer_tensor):
for name in self.expert_weight_names:
complete_name = "model.layers." + str(
self.num_dense_layers) + ".mlp.experts." + name
expert_tensor = self.param_dict[complete_name].data[0]
if name in ["w13_weight", "w2_weight"]:
if name in [
"w13_weight_list", "w2_weight_list",
"w13_weight_scale_fp32_list", "w2_weight_scale_list"
]:
expert_tensor = self.param_dict[complete_name][0]
expert_tensor = expert_tensor.clone()
else:
expert_tensor = self.param_dict[complete_name][0].data[0]
buffer_tensor = torch.empty_like(expert_tensor)
self.buffer_tensor_list[buffer_id].append(buffer_tensor)

Expand All @@ -97,12 +113,23 @@ def init_expert_param_per_layer(self):
layer_idx = self.num_dense_layers + moe_layer_id
self.expert_param_per_layer[layer_idx] = list()
for local_expert_id in range(num_local_expert):
self.expert_param_per_layer[layer_idx].append([
self.param_dict["model.layers." + str(layer_idx) +
".mlp.experts." +
name].data[local_expert_id]
for name in self.expert_weight_names
])
per_expert_param = list()
for name in self.expert_weight_names:
if name in [
"w13_weight_list", "w2_weight_list",
"w13_weight_scale_fp32_list",
"w2_weight_scale_list"
]:
per_expert_param.append(
self.param_dict["model.layers." + str(layer_idx) +
".mlp.experts." +
name][local_expert_id])
else:
per_expert_param.append(
self.param_dict["model.layers." + str(layer_idx) +
".mlp.experts." +
name][0].data[local_expert_id])
self.expert_param_per_layer[layer_idx].append(per_expert_param)

def get_rank_expert_workload(self) -> torch.Tensor:
self.moe_load = self.model.get_all_moe_loads()
Expand Down
8 changes: 4 additions & 4 deletions vllm_ascend/ops/fused_moe/moe_comm_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def finalize(self,
def fused_experts(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
Expand All @@ -93,8 +93,8 @@ def fused_experts(
use_int4_w4a8: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
# For TorchAir graph
Expand Down
81 changes: 55 additions & 26 deletions vllm_ascend/ops/fused_moe/moe_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@

from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
get_ascend_device_type)
enable_custom_op, get_ascend_device_type)


def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
return fusion and dynamic_eplb and enable_custom_op()


def cumsum_group_list(group_list: torch.Tensor,
Expand Down Expand Up @@ -55,10 +59,10 @@ def cumsum_group_list(group_list: torch.Tensor,


def quant_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
w1: list[torch.Tensor],
w1_scale: list[torch.Tensor],
w2: list[torch.Tensor],
w2_scale: list[torch.Tensor],
group_list: torch.Tensor,
group_list_type: int = 1,
dynamic_scale: torch.Tensor = None,
Expand All @@ -79,31 +83,42 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
quantized_hidden_states = hidden_states

bias1, bias2 = None, None
_output_dtype = w2_scale.dtype
_output_dtype = w2_scale[0].dtype

weight_prefetch_method = get_forward_context().weight_prefetch_method
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
hidden_states)
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
if w1_scale_bias is None and is_mc2:
if fusion and not dynamic_eplb:
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = (
torch.ops._C_ascend.
grouped_matmul_swiglu_quant_weight_nz_tensor_list(
x=hidden_states,
weight=w1,
weight_scale=w1_scale,
x_scale=pertoken_scale,
group_list=cumsum_group_list(group_list, group_list_type),
))
elif fusion and not dynamic_eplb:
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
x=hidden_states,
weight=w1,
weight=w1[0],
group_list=cumsum_group_list(group_list, group_list_type),
weight_scale=w1_scale,
weight_scale=w1_scale[0],
x_scale=pertoken_scale)
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
else:
if w1_scale.dtype != torch.float32:
w1_scale = w1_scale.to(torch.float32)
if w1_scale[0].dtype != torch.float32:
w1_scale[0] = w1_scale[0].to(torch.float32)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
weight=w1,
split_item=3,
group_list_type=group_list_type,
group_type=0,
Expand All @@ -126,14 +141,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
weight=w2,
scale=w2_scale,
per_token_scale=[swiglu_out_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]
output_dtype=w2_scale[0].dtype)[0]
else:
if w1_scale_bias is not None:
if group_list_type == 0:
Expand All @@ -146,23 +161,36 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
# TODO w4a8 scene: dynamic acquisition of dtype in the future
_output_dtype = torch.bfloat16

if fusion and not dynamic_eplb:
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = (
torch.ops._C_ascend.
grouped_matmul_swiglu_quant_weight_nz_tensor_list(
x=hidden_states,
weight=w1,
weight_scale=w1_scale,
x_scale=pertoken_scale,
group_list=cumsum_group_list(group_list, group_list_type),
bias=bias1,
))
elif fusion and not dynamic_eplb:
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
x=hidden_states,
weight=w1,
weight=w1[0],
bias=bias1,
group_list=cumsum_group_list(group_list, group_list_type),
weight_scale=w1_scale,
weight_scale=w1_scale[0],
x_scale=pertoken_scale)
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
else:
w1_scale[0] = w1_scale[0].to(w2_scale[0].dtype)
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
scale=[w1_scale.to(w2_scale.dtype)],
weight=w1,
scale=w1_scale,
bias=bias1,
per_token_scale=[pertoken_scale],
split_item=2,
Expand All @@ -179,8 +207,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale],
weight=w2,
scale=w2_scale,
bias=bias2,
per_token_scale=[swiglu_out_scale],
split_item=2,
Expand Down Expand Up @@ -232,11 +260,11 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,


def unified_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
group_list: torch.Tensor,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
Expand All @@ -247,6 +275,7 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
need_trans: bool = True,
dynamic_eplb: bool = False) -> torch.Tensor:
if with_quant:
assert w1_scale is not None and w2_scale is not None
return quant_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
Expand Down
8 changes: 4 additions & 4 deletions vllm_ascend/quantization/w4a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,10 @@ def apply(
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1=[layer.w13_weight],
w2=[layer.w2_weight],
w1_scale=[layer.w13_weight_scale],
w2_scale=[layer.w2_weight_scale],
w1_scale_bias=layer.w13_scale_bias,
w2_scale_bias=layer.w2_scale_bias,
topk_weights=topk_weights,
Expand Down
41 changes: 37 additions & 4 deletions vllm_ascend/quantization/w8a8_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,24 @@ def apply(
topk_weights = topk_weights.to(self.in_dtype)

moe_comm_method = get_forward_context().moe_comm_method
if self.dynamic_eplb:
w1 = layer.w13_weight_list
w1_scale = layer.w13_weight_scale_fp32_list
w2 = layer.w2_weight_list
w2_scale = layer.w2_weight_scale_list
else:
w1 = [layer.w13_weight]
w1_scale = [layer.w13_weight_scale_fp32]
w2 = [layer.w2_weight]
w2_scale = [layer.w2_weight_scale]

return moe_comm_method.fused_experts(
hidden_states=x,
pertoken_scale=pertoken_scale,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale_fp32,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int8_w8a8=True,
Expand Down Expand Up @@ -272,3 +283,25 @@ def process_weights_after_loading(self, layer):
layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
layer.w2_weight_offset.data.shape[0], -1)
if self.dynamic_eplb:
layer.w13_weight_list = [
weight.clone()
for weight in layer.w13_weight.data.unbind(dim=0)
]
layer.w2_weight_list = [
weight.clone() for weight in layer.w2_weight.data.unbind(dim=0)
]
layer.w13_weight_scale_fp32_list = [
weight.clone()
for weight in layer.w13_weight_scale.data.unbind(dim=0)
]
layer.w2_weight_scale_list = [
weight.clone()
for weight in layer.w2_weight_scale.data.unbind(dim=0)
]
del layer.w13_weight
del layer.w2_weight
del layer.w13_weight_scale
del layer.w13_weight_scale_fp32
del layer.w2_weight_scale
torch.npu.empty_cache()
Loading