Skip to content

Commit 66f61e1

Browse files
committed
Fix lint and formatting of the Triton vector-core code in rejection_sampler.py
Signed-off-by: yuxingcyx <[email protected]>
1 parent 353facf commit 66f61e1

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

vllm_ascend/sample/rejection_sampler.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,13 @@
2020
# step. This value is chosen to be large enough to handle typical use cases.
2121
MAX_SPEC_LEN = 32
2222

23-
2423
vectorcore_num = None
2524
device_properties = None
2625

27-
2826
if HAS_TRITON:
29-
from triton.runtime import driver
30-
device_properties = driver.active.utils.get_device_properties(torch.npu.current_device())
27+
from triton.runtime import driver #type: ignore
28+
device_properties = driver.active.utils.get_device_properties(
29+
torch.npu.current_device())
3130
vectorcore_num = device_properties['num_vectorcore']
3231
# Get the number of vector cores; it is used later for tiling.
3332

@@ -239,7 +238,7 @@ def rejection_sample(
239238

240239
if min(num_draft_tokens) == 1 and max(
241240
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
242-
rejection_greedy_sample_spec_len_1_triton[(grid,)](
241+
rejection_greedy_sample_spec_len_1_triton[(grid, )](
243242
output_token_ids,
244243
draft_token_ids,
245244
target_argmax,
@@ -248,7 +247,7 @@ def rejection_sample(
248247
BLOCK_SIZE=BLOCK_SIZE,
249248
)
250249
else:
251-
rejection_greedy_sample_triton[(grid,)](
250+
rejection_greedy_sample_triton[(grid, )](
252251
output_token_ids,
253252
cu_num_draft_tokens,
254253
draft_token_ids,
@@ -377,7 +376,7 @@ def expand_batch_to_tokens(
377376
grid = vectorcore_num
378377
BLOCK_SIZE = triton.next_power_of_2(n // grid)
379378

380-
expand_kernel[(grid,)](
379+
expand_kernel[(grid, )](
381380
expanded_x,
382381
x,
383382
cu_num_tokens,

0 commit comments

Comments
 (0)