|
3 | 3 |
|
4 | 4 | import torch |
5 | 5 | import torch.nn as nn |
| 6 | +import triton.runtime.driver as driver |
6 | 7 | import vllm.v1.sample.rejection_sampler as rs |
7 | 8 | from vllm.triton_utils import HAS_TRITON, tl, triton |
8 | 9 | from vllm.v1.sample.metadata import SamplingMetadata |
|
17 | 18 | # step. This value is chosen to be large enough to handle typical use cases. |
18 | 19 | MAX_SPEC_LEN = 32 |
19 | 20 |
|
| 21 | +device_properties = driver.active.utils.get_device_properties(torch.npu.current_device()) |
| 22 | +vectorcore_num = device_properties['num_vectorcore'] |
| 23 | +# Get the vector core count of the current device, used for grid tiling below |
20 | 24 |
|
21 | 25 | class AscendRejectionSampler(RejectionSampler, nn.Module): |
22 | 26 | """ |
@@ -155,8 +159,8 @@ def rejection_sample( |
155 | 159 | n = cu_num_draft_tokens.numel() |
156 | 160 | BLOCK_SIZE = 2 |
157 | 161 | grid = triton.cdiv(n, BLOCK_SIZE) |
158 | | - if n >= 40: |
159 | | - grid = 40 # Empirically tuned value |
| 162 | + if n >= vectorcore_num: |
| 163 | +        grid = vectorcore_num  # cap the grid at one program per vector core |
160 | 164 | BLOCK_SIZE = triton.next_power_of_2(n // grid) |
161 | 165 |
|
162 | 166 | if min(num_draft_tokens) == 1 and max( |
@@ -295,8 +299,8 @@ def expand_batch_to_tokens( |
295 | 299 | n = cu_num_tokens.numel() |
296 | 300 | BLOCK_SIZE = 2 |
297 | 301 | grid = triton.cdiv(n, BLOCK_SIZE) |
298 | | - if n >= 40: |
299 | | - grid = 40 |
| 302 | + if n >= vectorcore_num: |
| 303 | + grid = vectorcore_num |
300 | 304 | BLOCK_SIZE = triton.next_power_of_2(n // grid) |
301 | 305 |
|
302 | 306 | expand_kernel[(grid, )]( |
|
0 commit comments