|
24 | 24 |
|
25 | 25 | from tests.ut.base import TestBase |
26 | 26 | from vllm_ascend.ascend_forward_context import MoECommType |
27 | | -from vllm_ascend.ops.fused_moe import (AscendFusedMoE, |
28 | | - AscendUnquantizedFusedMoEMethod) |
| 27 | +from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod |
29 | 28 | from vllm_ascend.ops.moe.experts_selector import select_experts |
30 | 29 | from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp |
31 | 30 | from vllm_ascend.utils import AscendSocVersion, adapt_patch |
@@ -70,7 +69,7 @@ def setup_vllm_config_mock(mocker: MockerFixture): |
70 | 69 | mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4) |
71 | 70 | mock_vllm_config.model_config.max_model_len = 2048 |
72 | 71 |
|
73 | | - mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config', |
| 72 | + mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config', |
74 | 73 | return_value=mock_vllm_config) |
75 | 74 | mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config', |
76 | 75 | return_value=mock_vllm_config) |
@@ -104,24 +103,24 @@ def mock_finalize(hidden_states, **kwargs): |
104 | 103 |
|
105 | 104 | with patch('torch.distributed.get_rank', return_value=0), \ |
106 | 105 | patch('torch.distributed.get_world_size', return_value=4), \ |
107 | | - patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ |
| 106 | + patch('vllm_ascend.ops.common_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ |
108 | 107 | patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ |
109 | | - patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \ |
110 | | - patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ |
| 108 | + patch('vllm_ascend.ops.common_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \ |
| 109 | + patch('vllm_ascend.ops.common_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ |
111 | 110 | patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ |
112 | | - patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ |
| 111 | + patch('vllm_ascend.ops.common_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ |
113 | 112 | patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ |
114 | 113 | patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', |
115 | 114 | return_value=mock_dp_and_tp_group(mocker)), \ |
116 | | - patch('vllm_ascend.ops.fused_moe.get_ascend_config', |
| 115 | + patch('vllm_ascend.ops.common_fused_moe.get_ascend_config', |
117 | 116 | return_value=MagicMock( |
118 | 117 | torchair_graph_config=MagicMock(enabled=False), |
119 | 118 | enable_multistream_moe=False, |
120 | 119 | expert_map_path=None |
121 | 120 | )), \ |
122 | | - patch('vllm_ascend.ops.fused_moe.determine_expert_map', |
| 121 | + patch('vllm_ascend.ops.common_fused_moe.determine_expert_map', |
123 | 122 | return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \ |
124 | | - patch('vllm_ascend.ops.fused_moe.get_forward_context', |
| 123 | + patch('vllm_ascend.ops.common_fused_moe.get_forward_context', |
125 | 124 | return_value=mock_forward_context_obj), \ |
126 | 125 | patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context', |
127 | 126 | return_value=mock_forward_context_obj), \ |
@@ -252,196 +251,6 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module): |
252 | 251 | pass |
253 | 252 |
|
254 | 253 |
|
255 | | -class TestAscendFusedMoe: |
256 | | - |
257 | | - def test_init_no_quant(self, mock_dist_env, default_moe_config): |
258 | | - layer = AscendFusedMoE(**default_moe_config) |
259 | | - |
260 | | - layer.w13_weight = nn.Parameter( |
261 | | - torch.randn(default_moe_config['num_experts'], |
262 | | - default_moe_config['intermediate_size'] * 2, |
263 | | - default_moe_config['hidden_size'])) |
264 | | - layer.w2_weight = nn.Parameter( |
265 | | - torch.randn(default_moe_config['num_experts'], |
266 | | - default_moe_config['hidden_size'], |
267 | | - default_moe_config['intermediate_size'])) |
268 | | - |
269 | | - assert layer.num_experts == default_moe_config['num_experts'] |
270 | | - assert layer.top_k == default_moe_config['top_k'] |
271 | | - assert hasattr(layer, 'w13_weight') |
272 | | - assert hasattr(layer, 'w2_weight') |
273 | | - |
274 | | - with pytest.raises(AssertionError): |
275 | | - error_config = default_moe_config.copy() |
276 | | - error_config['use_grouped_topk'] = True |
277 | | - layer = AscendFusedMoE(**error_config) |
278 | | - |
279 | | - with pytest.raises(ValueError): |
280 | | - error_config = default_moe_config.copy() |
281 | | - error_config['scoring_func'] = "random" |
282 | | - layer = AscendFusedMoE(**error_config) |
283 | | - |
284 | | - def test_init_with_quant(self, mock_dist_env, default_moe_config): |
285 | | - mock_quant_config = MagicMock() |
286 | | - mock_quant_method = MockFusedMoEMethod() |
287 | | - mock_quant_config.get_quant_method.return_value = mock_quant_method |
288 | | - |
289 | | - moe = AscendFusedMoE(**default_moe_config, |
290 | | - quant_config=mock_quant_config) |
291 | | - |
292 | | - assert moe.quant_method is not None |
293 | | - assert moe.quant_method == mock_quant_method |
294 | | - |
295 | | - @pytest.mark.parametrize( |
296 | | - "others_param", |
297 | | - [[None, |
298 | | - MagicMock(return_value=torch.randn(5, 32)), False, 5, None], |
299 | | - [2, None, False, 5, None], [None, None, True, 5, None], |
300 | | - [None, None, False, 1, None], [None, None, True, 5, 1], |
301 | | - [None, None, False, 5, 1]]) |
302 | | - def test_forward(self, mock_dist_env, default_moe_config, others_param): |
303 | | - |
304 | | - top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param |
305 | | - inputs = torch.randn(num_tokens, 32) |
306 | | - router_logits = torch.randn(num_tokens, 8) |
307 | | - moe = AscendFusedMoE(**default_moe_config) |
308 | | - |
309 | | - if ep_size == 1: |
310 | | - moe.moe_parallel_config.ep_size = 1 |
311 | | - |
312 | | - moe.quant_method = MockQuantMethod(shared_experts, num_tokens) |
313 | | - forward_context = mock_dist_env['mock_forward_context_obj'] |
314 | | - with patch("vllm_ascend.ops.fused_moe.get_forward_context", |
315 | | - return_value=forward_context): |
316 | | - output = moe.forward(inputs, |
317 | | - router_logits, |
318 | | - is_prefill=is_prefill, |
319 | | - top_k=top_k, |
320 | | - shared_experts=shared_experts) |
321 | | - |
322 | | - moe.quant_method.apply.assert_called_once() |
323 | | - |
324 | | - if shared_experts: |
325 | | - assert output[0].shape == (num_tokens, 32) |
326 | | - assert output[1].shape == (num_tokens, 10) |
327 | | - else: |
328 | | - assert output.shape == (num_tokens, 32) |
329 | | - |
330 | | - def test_forward_ms_fused_moe_comp(self, mock_dist_env, |
331 | | - default_moe_config): |
332 | | - inputs = torch.randn(5, 32) |
333 | | - router_logits = torch.randn(5, 8) |
334 | | - moe = AscendFusedMoE(**default_moe_config) |
335 | | - |
336 | | - moe.quant_method = MockQuantMethod(None, 5) |
337 | | - output = moe._forward_ms_fused_moe_comp(inputs, |
338 | | - router_logits, |
339 | | - is_prefill=False, |
340 | | - real_top_k=1) |
341 | | - |
342 | | - moe.quant_method.apply.assert_called_once() |
343 | | - |
344 | | - assert output.shape == (5, 32) |
345 | | - |
346 | | - |
347 | | -class TestAscendUnquantizedFusedMoEMethod: |
348 | | - |
349 | | - def test_process_weights_after_loading(self, moe_method, mock_dist_env): |
350 | | - layer = MagicMock() |
351 | | - layer.w13_weight.data = torch.randn(16, 32) |
352 | | - layer.w2_weight.data = torch.randn(16, 32) |
353 | | - |
354 | | - with patch('torch_npu.npu_format_cast', mock_npu_format_cast), \ |
355 | | - patch('vllm_ascend.utils.is_310p', return_value=False): |
356 | | - moe_method.process_weights_after_loading(layer) |
357 | | - |
358 | | - assert isinstance(layer.w13_weight, torch.nn.Parameter) |
359 | | - assert isinstance(layer.w2_weight, torch.nn.Parameter) |
360 | | - assert not layer.w13_weight.requires_grad |
361 | | - assert not layer.w2_weight.requires_grad |
362 | | - |
363 | | - @pytest.mark.parametrize("others_param", |
364 | | - [[256, 4], [128, 1], [128, 1], [128, 4]]) |
365 | | - def test_apply_without_expert_map(self, moe_method, mock_dist_env, |
366 | | - mock_moe_env, others_param): |
367 | | - global_num_experts, ep_size = others_param |
368 | | - is_prefill = False |
369 | | - |
370 | | - forward_context = mock_dist_env['mock_forward_context_obj'] |
371 | | - |
372 | | - with patch("vllm_ascend.ops.fused_moe.get_forward_context", |
373 | | - return_value=forward_context): |
374 | | - moe_method.ep_size = ep_size |
375 | | - x = torch.randn(8, 2, 2) |
376 | | - router_logits = torch.randn(8, 8) |
377 | | - layer = MagicMock() |
378 | | - local_num_experts = 2 |
379 | | - hidden_size = 2 |
380 | | - intermediate_size_per_partition = 4 |
381 | | - |
382 | | - layer.w13_weight = torch.randn(local_num_experts, |
383 | | - intermediate_size_per_partition * 2, |
384 | | - hidden_size) |
385 | | - layer.w2_weight = torch.randn(local_num_experts, hidden_size, |
386 | | - intermediate_size_per_partition) |
387 | | - |
388 | | - result = moe_method.apply(layer=layer, |
389 | | - x=x, |
390 | | - router_logits=router_logits, |
391 | | - top_k=2, |
392 | | - renormalize=True, |
393 | | - global_num_experts=global_num_experts, |
394 | | - is_prefill=is_prefill) |
395 | | - |
396 | | - mock_moe_comm_method = mock_dist_env['mock_moe_comm_method'] |
397 | | - mock_moe_comm_method.fused_experts.assert_called_once() |
398 | | - |
399 | | - expected_shape = (16, 2) |
400 | | - assert result.shape == expected_shape |
401 | | - |
402 | | - @pytest.mark.parametrize("others_param", [16, 1, 4]) |
403 | | - def test_apply_with_expert_map(self, moe_method, mock_dist_env, |
404 | | - mock_moe_env, others_param): |
405 | | - ep_size = others_param |
406 | | - is_prefill = False |
407 | | - |
408 | | - forward_context = mock_dist_env['mock_forward_context_obj'] |
409 | | - |
410 | | - with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \ |
411 | | - patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3): |
412 | | - expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]) |
413 | | - moe_method.ep_size = ep_size |
414 | | - x = torch.randn(8, 2, 2) |
415 | | - if ep_size == 1: |
416 | | - x = x.view(-1, 2) |
417 | | - router_logits = torch.randn(8, 8) |
418 | | - layer = MagicMock() |
419 | | - |
420 | | - local_num_experts = 2 |
421 | | - hidden_size = 2 |
422 | | - intermediate_size_per_partition = 4 |
423 | | - layer.w13_weight = torch.randn(local_num_experts, |
424 | | - intermediate_size_per_partition * 2, |
425 | | - hidden_size) |
426 | | - layer.w2_weight = torch.randn(local_num_experts, hidden_size, |
427 | | - intermediate_size_per_partition) |
428 | | - |
429 | | - result = moe_method.apply(layer=layer, |
430 | | - x=x, |
431 | | - router_logits=router_logits, |
432 | | - top_k=2, |
433 | | - renormalize=True, |
434 | | - global_num_experts=128, |
435 | | - expert_map=expert_map, |
436 | | - is_prefill=is_prefill) |
437 | | - |
438 | | - mock_moe_comm_method = mock_dist_env['mock_moe_comm_method'] |
439 | | - mock_moe_comm_method.fused_experts.assert_called_once() |
440 | | - |
441 | | - expected_shape = (16, 2) |
442 | | - assert result.shape == expected_shape |
443 | | - |
444 | | - |
445 | 254 | class TestExpertsSelector: |
446 | 255 |
|
447 | 256 | @pytest.mark.parametrize("global_num_experts", [[256], [128]]) |