Skip to content

Commit a60e4c8

Browse files
committed
triton code vector core rejection_sampler.py
Signed-off-by: yuxingcyx <[email protected]>
1 parent 7f4e3b1 commit a60e4c8

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

vllm_ascend/sample/rejection_sampler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
import torch.nn as nn
6+
import triton.runtime.driver as driver
67
import vllm.v1.sample.rejection_sampler as rs
78
from vllm.triton_utils import HAS_TRITON, tl, triton
89
from vllm.v1.sample.metadata import SamplingMetadata
@@ -17,6 +18,9 @@
1718
# step. This value is chosen to be large enough to handle typical use cases.
1819
MAX_SPEC_LEN = 32
1920

21+
device_properties = driver.active.utils.get_device_properties(torch.npu.current_device())
22+
vectorcore_num = device_properties['num_vectorcore']
23+
#get vector core number in order for later tiling
2024

2125
class AscendRejectionSampler(RejectionSampler, nn.Module):
2226
"""
@@ -155,8 +159,8 @@ def rejection_sample(
155159
n = cu_num_draft_tokens.numel()
156160
BLOCK_SIZE = 2
157161
grid = triton.cdiv(n, BLOCK_SIZE)
158-
if n >= 40:
159-
grid = 40 # Empirically tuned value
162+
if n >= vectorcore_num:
163+
grid = vectorcore_num # Empirically tuned value
160164
BLOCK_SIZE = triton.next_power_of_2(n // grid)
161165

162166
if min(num_draft_tokens) == 1 and max(
@@ -295,8 +299,8 @@ def expand_batch_to_tokens(
295299
n = cu_num_tokens.numel()
296300
BLOCK_SIZE = 2
297301
grid = triton.cdiv(n, BLOCK_SIZE)
298-
if n >= 40:
299-
grid = 40
302+
if n >= vectorcore_num:
303+
grid = vectorcore_num
300304
BLOCK_SIZE = triton.next_power_of_2(n // grid)
301305

302306
expand_kernel[(grid, )](

0 commit comments

Comments
 (0)