Skip to content

Commit 883f351

Browse files
committed
fix ut by triton torch_npu._inductor
Signed-off-by: Meihan-chen <[email protected]>
1 parent c7a42f8 commit 883f351

File tree

3 files changed

+13
-12
lines changed

3 files changed

+13
-12
lines changed

tests/e2e/nightly/ops/triton/test_rope.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import gc
2+
import sys
3+
from unittest.mock import MagicMock
24

35
import pytest
46
import torch
57

8+
if 'torch_npu._inductor' not in sys.modules:
9+
sys.modules['torch_npu._inductor'] = MagicMock()
610
from vllm_ascend.ops.triton.rope import rope_forward_triton
711
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
812

tests/ut/attention/test_sfa_v1.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
import sys
12
from unittest.mock import MagicMock
23

34
import torch
45
from vllm.v1.attention.backends.utils import AttentionCGSupport
56

67
from tests.ut.base import TestBase
78
from vllm_ascend.attention.attention_v1 import AscendAttentionState
9+
10+
if 'torch_npu._inductor' not in sys.modules:
11+
sys.modules['torch_npu._inductor'] = MagicMock()
12+
813
from vllm_ascend.attention.sfa_v1 import (AscendSFABackend, AscendSFAImpl,
914
AscendSFAMetadata,
1015
AscendSFAMetadataBuilder)

tests/ut/sample/test_rejection_sampler.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -128,18 +128,10 @@ def test_expand_batch_to_tokens(self):
128128
cu_num_tokens = torch.tensor([2, 5, 7])
129129
num_tokens = 7
130130

131-
with patch("vllm_ascend.sample.rejection_sampler.expand_pytorch"
132-
) as mock_kernel:
133-
expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
134-
mock_kernel.assert_called_once()
135-
args = mock_kernel.call_args[0]
136-
assert (args[1] == x).all()
137-
assert (args[2] == cu_num_tokens).all()
138-
139-
# Run actual function
140-
result = expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
141-
expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
142-
assert torch.equal(result, expected)
131+
with patch("vllm_ascend.sample.rejection_sampler.HAS_TRITON", False):
132+
result = expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
133+
expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
134+
assert torch.equal(result, expected)
143135

144136
def test_sample_recovered_tokens_pytorch_ngram(self):
145137
"""Test recovered token sampling under n-gram mode"""

0 commit comments

Comments (0)