
Commit a6ef3ac

Authored by weijinqian0 and weijinqian_v1
[Performance] Pre-issued exponential distribution operator. (#4908)
Pre-issue the exponential distribution operator so its kernel is launched on a separate stream ahead of the point where its result is consumed.

Result: a single inference step saves 200-300 microseconds.

Before: https://github.com/user-attachments/assets/c1da19e2-a439-42cb-9d7c-c0218e61fd4c
After: https://github.com/user-attachments/assets/03c84292-c802-4755-949c-4266a9a72fc0

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: weijinqian_v1 <[email protected]>
Co-authored-by: weijinqian_v1 <[email protected]>
1 parent 0fbe083 commit a6ef3ac
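The saving comes from enqueuing the exponential-noise kernel on an auxiliary stream, so it can run concurrently with work still pending on the main stream; the main stream only blocks at the point where the noise tensor is actually needed. Below is a minimal sketch of that pattern using CUDA streams as a portable stand-in for the NPU stream helpers in this commit; the function and variable names are illustrative, not part of the change.

import torch

def sample_with_preissued_noise(probs: torch.Tensor,
                                side_stream: torch.cuda.Stream) -> torch.Tensor:
    # Enqueue the exponential-noise kernel on a secondary stream so it can
    # execute concurrently with kernels still pending on the main stream.
    with torch.cuda.stream(side_stream):
        q = torch.empty_like(probs)
        q.exponential_()
    # Block the main stream only at the point where q is actually consumed.
    torch.cuda.current_stream().wait_stream(side_stream)
    return probs.div_(q).argmax(dim=-1).view(-1)

if __name__ == "__main__":  # requires a CUDA device
    probs = torch.softmax(torch.randn(4, 32, device="cuda"), dim=-1)
    print(sample_with_preissued_noise(probs, torch.cuda.Stream()))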

File tree

3 files changed: +43 −3 lines

tests/ut/sample/test_sampler.py
vllm_ascend/sample/sampler.py
vllm_ascend/utils.py

tests/ut/sample/test_sampler.py

Lines changed: 4 additions & 1 deletion
@@ -17,9 +17,12 @@ def test_init_with_raw_logprobs(self):
 
 class TestAscendTopKTopPSampler(TestBase):
 
+    @mock.patch("vllm_ascend.sample.sampler.random_sample")
     @mock.patch("torch_npu.npu_top_k_top_p")
-    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
+    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op,
+                                                 mock_random_sample):
         mock_npu_op.return_value = (torch.randn(1, 3))
+        mock_random_sample.return_value = torch.randn(3)
         sampler = AscendTopKTopPSampler()
 
         logits = torch.tensor([[1.0, 2.0, 3.0]])

vllm_ascend/sample/sampler.py

Lines changed: 29 additions & 2 deletions
@@ -1,13 +1,40 @@
 import torch
 import torch_npu
-from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
 from vllm.v1.sample.sampler import Sampler
 
-from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
+from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
+                               global_stream, npu_stream_switch)
 
 DEFAULT_LOGPROBS_MODE = "raw_logprobs"
 
 
+def random_sample(
+    probs: torch.Tensor,
+    generators: dict[int, torch.Generator],
+) -> torch.Tensor:
+    """Randomly sample from the probabilities.
+
+    We use this function instead of torch.multinomial because torch.multinomial
+    causes CPU-NPU synchronization.
+    """
+    # NOTE(woosuk): To batch-process the requests without their own seeds,
+    # which is the common case, we first assume that every request does
+    # not have its own seed. Then, we overwrite the values for the requests
+    # that have their own seeds.
+    with npu_stream_switch(global_stream()):
+        q = torch.empty_like(probs)
+        if len(generators) != probs.shape[0]:
+            q.exponential_()
+        if generators:
+            # TODO(woosuk): This can be slow because we handle each request
+            # one by one. Optimize this.
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
+    torch.npu.current_stream().wait_stream(global_stream())
+    return probs.div_(q).argmax(dim=-1).view(-1)
+
+
 class AscendSampler(Sampler):
 
     def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
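The final line, probs.div_(q).argmax(dim=-1), relies on the exponential-race (Gumbel-max style) trick: if each q_i is drawn independently from Exp(1), then argmax_i p_i / q_i selects index i with probability proportional to p_i, which is why this replaces torch.multinomial without changing the sampling distribution. A small CPU-only check, not part of the commit, illustrates the equivalence:

import torch

# Empirical check (illustrative): argmax(p / Exp(1)) samples index i with
# probability proportional to p[i], matching a categorical draw.
torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
n = 200_000

q = torch.empty(n, probs.numel()).exponential_()      # q ~ Exp(1)
samples = (probs / q).argmax(dim=-1)                   # exponential-race sampling
freq = torch.bincount(samples, minlength=probs.numel()) / n

print(freq)  # approximately tensor([0.1, 0.2, 0.3, 0.4])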

vllm_ascend/utils.py

Lines changed: 10 additions & 0 deletions
@@ -51,6 +51,7 @@
 _CUSTOM_OP_ENABLED = None
 _CURRENT_STREAM = None
 _PREFETCH_STREAM = None
+_GLOBAL_STREAM = None
 _SHARED_EXPERTS_CALCULATION_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 _DEFAULT_BUFFER_SIZE = 200
@@ -292,6 +293,15 @@ def prefetch_stream() -> torch.npu.Stream:
     return _PREFETCH_STREAM
 
 
+def global_stream() -> torch.npu.Stream:
+    global _GLOBAL_STREAM
+    if _GLOBAL_STREAM is None:
+        # when this function is called before any stream is set,
+        # we return the default stream.
+        _GLOBAL_STREAM = torch_npu.npu.Stream()
+    return _GLOBAL_STREAM
+
+
 def shared_experts_calculation_stream() -> torch.npu.Stream:
     global _SHARED_EXPERTS_CALCULATION_STREAM
     if _SHARED_EXPERTS_CALCULATION_STREAM is None:
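global_stream() follows the same lazily-initialized module-level singleton pattern as the other stream helpers in vllm_ascend/utils.py: one auxiliary stream is created on first use and reused by every subsequent call, so the sampler pays no per-step stream-creation cost. The npu_stream_switch context manager it is paired with in sampler.py lives in this same module but is not shown in the diff; the sketch below is a rough CUDA-based stand-in for both helpers, and its guess at npu_stream_switch's behavior is an assumption, not the vllm_ascend implementation.

from contextlib import contextmanager
from typing import Optional

import torch

_GLOBAL_STREAM: Optional[torch.cuda.Stream] = None

def global_stream() -> torch.cuda.Stream:
    # Lazily create one shared auxiliary stream and reuse it on every call,
    # mirroring the _GLOBAL_STREAM helper added in this commit.
    global _GLOBAL_STREAM
    if _GLOBAL_STREAM is None:
        _GLOBAL_STREAM = torch.cuda.Stream()
    return _GLOBAL_STREAM

@contextmanager
def stream_switch(stream: torch.cuda.Stream):
    # Assumed behavior of npu_stream_switch: issue the enclosed ops on
    # `stream` instead of the current stream; its real body is not in this diff.
    with torch.cuda.stream(stream):
        yield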
