Skip to content

Commit ec703ca

Browse files
committed
triton code pre commit for rejection_sampler.py
Signed-off-by: yuxingcyx <[email protected]>
1 parent e2bc5c9 commit ec703ca

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

vllm_ascend/sample/rejection_sampler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def rejection_sample(
156156
BLOCK_SIZE = 2
157157
grid = triton.cdiv(n, BLOCK_SIZE)
158158
if n >= 40:
159-
grid = 40 # Empirically tuned value
159+
grid = 40 # Empirically tuned value
160160
BLOCK_SIZE = triton.next_power_of_2(n // grid)
161161

162162
if min(num_draft_tokens) == 1 and max(
@@ -167,7 +167,7 @@ def rejection_sample(
167167
target_argmax,
168168
bonus_token_ids,
169169
vec_len,
170-
BLOCK_SIZE = BLOCK_SIZE,
170+
BLOCK_SIZE=BLOCK_SIZE,
171171
)
172172
else:
173173
rejection_greedy_sample_triton[(grid, )](
@@ -307,7 +307,7 @@ def expand_batch_to_tokens(
307307
replace_to,
308308
vec_len,
309309
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
310-
BLOCK_SIZE = BLOCK_SIZE,
310+
BLOCK_SIZE=BLOCK_SIZE,
311311
)
312312
else:
313313
expand_pytorch(

0 commit comments

Comments
 (0)