1 change: 1 addition & 0 deletions .github/workflows/_e2e_test.yaml
@@ -205,6 +205,7 @@ jobs:
pytest -sv tests/e2e/multicard/test_pipeline_parallel.py
pytest -sv tests/e2e/multicard/test_prefix_caching.py
pytest -sv tests/e2e/multicard/test_qwen3_moe.py
pytest -sv tests/e2e/multicard/test_offline_weight_load.py

e2e-4-cards:
name: multicard-4
26 changes: 17 additions & 9 deletions examples/offline_weight_load.py
@@ -70,6 +70,9 @@
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_open_port

from vllm.model_executor.model_loader.utils import \
process_weights_after_loading

os.environ["VLLM_USE_MODELSCOPE"] = "True"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

@@ -219,15 +222,6 @@ def main(
gpu_memory_utilization = 0.95,
enable_sleep_mode=enable_sleep_mode,
)
model_path = model
runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
patch_vllm_moe_model_weight_loader(runmodel)
sd = load_and_merge_safetensors(model_path)
runmodel.load_weights(sd.items())
print('load state dict done')
tp_ranks = get_tp_group().ranks
print(f'TP RANKS: {tp_ranks}')

outputs = llm.generate(prompts, sampling_params)

if enable_sleep_mode:
@@ -242,6 +236,20 @@ def main(
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes

llm.wake_up()

model_path = model
runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
patch_vllm_moe_model_weight_loader(runmodel)
sd = load_and_merge_safetensors(model_path)
runmodel.load_weights(sd.items())
print('load state dict done')
tp_ranks = get_tp_group().ranks
print(f'TP RANKS: {tp_ranks}')

vllm_config = llm.llm_engine.vllm_config.model_config
device = next(runmodel.parameters()).device
process_weights_after_loading(runmodel, vllm_config, device)

outputs_after_wakeup = llm.generate(prompts, sampling_params)
if rank == 0:
# cmp output
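For context on the reload path above: a hedged, illustrative sketch of what a helper like `load_and_merge_safetensors` typically does, assuming `model_path` already points at a local directory of `*.safetensors` shards (the example script defines its own version; this is not the PR's implementation):

import glob
import os

import torch
from safetensors.torch import load_file


def load_and_merge_safetensors(model_path: str) -> dict[str, torch.Tensor]:
    # Merge each shard's {name: tensor} mapping into one flat state dict.
    merged: dict[str, torch.Tensor] = {}
    for shard in sorted(glob.glob(os.path.join(model_path, "*.safetensors"))):
        merged.update(load_file(shard))
    return merged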
74 changes: 74 additions & 0 deletions tests/e2e/multicard/test_offline_weight_load.py
@@ -0,0 +1,74 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Run `pytest tests/e2e/multicard/test_offline_weight_load.py`.
"""

import os
import subprocess
import sys
from pathlib import Path
from unittest.mock import patch

import pytest

MODELS = ["Qwen/Qwen3-30B-A3B"]


@pytest.mark.parametrize("model", MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
def test_offline_weight_load_and_sleepmode(model):
script = Path(
__file__
).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
env = os.environ.copy()
cmd = [
sys.executable,
str(script),
"--model",
model,
"--tp-size",
"2",
"--node-size",
"1",
"--node-rank",
"0",
"--proc-per-node",
"2",
"--trust-remote-code",
"--enable-sleep-mode",
"--temperature",
"0",
"--model-weight-gib",
"0.8",
]

print(f"Running subprocess: {' '.join(cmd)}")
proc = subprocess.run(
cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=600,
)
output = proc.stdout.decode(errors='ignore')

print(output)

assert "Generated text:" in output
assert "Sleep and wake up successfully!!" in output
assert proc.returncode == 0
39 changes: 1 addition & 38 deletions tests/ut/ops/test_fused_moe.py
@@ -25,8 +25,7 @@
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.fused_moe import (
AscendFusedMoE, AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
unified_apply_mlp)
from vllm_ascend.utils import AscendDeviceType, adapt_patch
@@ -595,39 +594,3 @@ def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
self.assertTrue(mock_forward_context.with_quant)
self.assertEqual(result.shape, hidden_states_shape)
self.assertEqual(result.dtype, torch.bfloat16)


class TestLoadWeight(TestBase):

def test_load_w13_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)

expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)

expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)

expert_data = torch.randn(128, 8)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)

expert_data = torch.randn(8, 128)
loaded_weight = torch.randn(128, 4)
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)

def test_load_w2_transpose(self):
with patch.object(AscendFusedMoE, "__init__",
lambda self, *args, **kwargs: None):
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
expert_data = torch.randn(128, 4)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)

expert_data = torch.randn(4, 128)
loaded_weight = torch.randn(128, 8)
moe._load_w2(expert_data, 1, loaded_weight, 0)
13 changes: 13 additions & 0 deletions tests/ut/worker/test_worker_v1.py
@@ -281,9 +281,22 @@ def test_wake_up_mode_enabled(self, mock_allocator_class,
mock_allocator = MagicMock()
mock_allocator_class.get_instance.return_value = mock_allocator

mock_hidden_size = MagicMock()
mock_hf_config = MagicMock()
mock_hf_config.hidden_size = mock_hidden_size
mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config

mock_model_runner = MagicMock()
mock_model_runner.model = MagicMock()

# Create worker mock
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
worker = NPUWorker()
worker.model_runner = mock_model_runner
worker.vllm_config = mock_vllm_config
worker._sleep_saved_buffers = {}
# Test wake_up method
worker.wake_up(tags=["test_tag"])
78 changes: 6 additions & 72 deletions vllm_ascend/ops/fused_moe/fused_moe.py
@@ -56,29 +56,18 @@ def __init__(self, moe: FusedMoEConfig = None):

super().__init__(moe=moe)
self.dynamic_eplb = get_ascend_config().dynamic_eplb
self.transpose = True

def process_weights_after_loading(self, layer):
super(UnquantizedFusedMoEMethod,
self).process_weights_after_loading(layer)
if self.transpose:
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
1, 2).contiguous()
layer.w13_weight = torch.nn.Parameter(w13_data,
requires_grad=False)

w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
1, 2).contiguous()
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
1, 2).contiguous()
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)

self.transpose = False
else:
w13_data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w13_weight = torch.nn.Parameter(w13_data,
requires_grad=False)

w2_data = self._maybe_pad_weight(layer.w2_weight.data)
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
1, 2).contiguous()
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

if get_ascend_device_type() != AscendDeviceType._310P and is_enable_nz(
):
@@ -389,61 +378,6 @@ def forward_impl(self, hidden_states: torch.Tensor,

return final_hidden_states

def transpose_weight(self, loaded_weight, expert_data, shard_dim):
# Ensure training and inference weight shapes match during RL weight updates
if (len(loaded_weight.shape) >= 2 and len(expert_data.shape) >= 2 and \
loaded_weight.shape[1] != expert_data.shape[1] and \
loaded_weight.shape[0] != expert_data.shape[0]
):
shard_dim = int(not shard_dim)
loaded_weight = loaded_weight.transpose(0, 1).contiguous()
return loaded_weight, shard_dim

def _load_w13(self,
expert_data: torch.Tensor,
shard_dim: int,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full: bool = False):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
loaded_weight, shard_dim = self.transpose_weight(
loaded_weight, expert_data, shard_dim)
shard_size = expert_data.shape[shard_dim] // 2
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# w3, up_proj: Load into second logical weight of w13.
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)

def _load_w2(self,
expert_data: torch.Tensor,
shard_dim: int,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full: bool = False):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
loaded_weight, shard_dim = self.transpose_weight(
loaded_weight, expert_data, shard_dim)
shard_size = expert_data.shape[shard_dim]
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)


class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):

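A hedged illustration of the layout bookkeeping behind this change (and the wake_up transpose in worker_v1.py below), assuming vLLM's usual FusedMoE parameter shapes; the concrete sizes are made up:

import torch

num_experts, hidden, intermediate = 4, 8, 16  # made-up sizes

# vLLM's loader materializes the fused expert weights as:
w13 = torch.randn(num_experts, 2 * intermediate, hidden)  # gate_up_proj
w2 = torch.randn(num_experts, hidden, intermediate)       # down_proj

# process_weights_after_loading swaps the last two dims once for the NPU matmuls.
w13_npu = w13.transpose(1, 2).contiguous()
w2_npu = w2.transpose(1, 2).contiguous()
assert w13_npu.shape == (num_experts, hidden, 2 * intermediate)
assert w2_npu.shape == (num_experts, intermediate, hidden)

# Reloading a training-format state dict after wake_up therefore requires
# transposing these parameters back to the loader layout first, which is what
# the wake_up hook in worker_v1.py below does.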
21 changes: 20 additions & 1 deletion vllm_ascend/worker/worker_v1.py
@@ -177,9 +177,28 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
allocator = CaMemAllocator.get_instance()
allocator.wake_up(tags=tags)

hidden_size = self.vllm_config.model_config.hf_config.hidden_size
model = self.model_runner.model
for name, param in model.named_parameters():
if 'w2_weight' in name and param.shape[2] == hidden_size:
parts = name.split('.')
param_name = parts[-1]
parent_module = model.get_submodule(".".join(parts[:-1]))

w2_data = param.transpose(1, 2)
w2_data = torch.nn.Parameter(w2_data, requires_grad=False)
setattr(parent_module, param_name, w2_data)
elif 'w13_weight' in name and param.shape[1] == hidden_size:
parts = name.split('.')
param_name = parts[-1]
parent_module = model.get_submodule(".".join(parts[:-1]))

w13_data = param.transpose(1, 2)
w13_data = torch.nn.Parameter(w13_data, requires_grad=False)
setattr(parent_module, param_name, w13_data)
Comment on lines +182 to +198
critical

There are a few issues in this block of code:

  1. [Critical] The condition to identify the w2_weight parameter is incorrect. The shape of w2_weight at this point is (num_experts, hidden_size, intermediate_size). The condition param.shape[2] == hidden_size compares intermediate_size with hidden_size, which is not always true and will cause this logic to fail for many models. It should be param.shape[1] == hidden_size to correctly identify the parameter by its hidden dimension size.

  2. [High] After transposing a tensor, it's good practice to call .contiguous() to ensure the memory layout is continuous. This can prevent potential errors and performance issues in subsequent operations that expect a contiguous tensor. The load_weights method which is called after this might rely on it.

  3. [Medium] The code for transposing w2_weight and w13_weight is very similar. This duplication can be avoided by refactoring it into a helper function to improve readability and maintainability.

I've provided a suggestion that fixes the critical bug, adds .contiguous(), and refactors the duplicated logic.

Suggested change (replacing the loop shown on lines +182 to +198 above):

for name, param in model.named_parameters():
    # The shape of w2_weight is (num_experts, hidden_size, intermediate_size)
    # The shape of w13_weight is (num_experts, hidden_size, 2 * intermediate_size)
    if ('w2_weight' in name or 'w13_weight' in name) and len(param.shape) == 3 and param.shape[1] == hidden_size:
        parts = name.split('.')
        param_name = parts[-1]
        parent_module = model.get_submodule(".".join(parts[:-1]))
        # Transpose back to training format and ensure contiguity
        new_data = param.transpose(1, 2).contiguous()
        new_param = torch.nn.Parameter(new_data, requires_grad=False)
        setattr(parent_module, param_name, new_param)
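
A minimal sketch (not part of the PR) of how the reviewer's third point could be taken further with a standalone helper; the name `_untranspose_moe_weight` and its call site are hypothetical:

import torch


def _untranspose_moe_weight(model: torch.nn.Module, name: str,
                            param: torch.nn.Parameter) -> None:
    # Swap the last two dims back to the loader layout and re-register the
    # result as a frozen Parameter on the owning submodule.
    parts = name.split('.')
    parent_module = model.get_submodule(".".join(parts[:-1]))
    new_data = param.transpose(1, 2).contiguous()
    setattr(parent_module, parts[-1],
            torch.nn.Parameter(new_data, requires_grad=False))

With such a helper, the wake_up loop would reduce to a shape check followed by a single call per matching parameter.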


# Restore the buffers after level 2 sleep
if len(self._sleep_saved_buffers):
model = self.model_runner.model
for name, buffer in model.named_buffers():
if name in self._sleep_saved_buffers:
buffer.data.copy_(self._sleep_saved_buffers[name].data)