 
 import torch
 import torch.nn as nn
-import torch_npu
+import triton.runtime.driver as driver
 import vllm.v1.sample.rejection_sampler as rs
 from vllm.triton_utils import HAS_TRITON, tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.sample.rejection_sampler import (RejectionSampler,
+                                              apply_sampling_constraints,
                                               generate_uniform_probs)
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
-from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
-
 PLACEHOLDER_TOKEN_ID = -1
 GREEDY_TEMPERATURE = -1
 # Maximum number of speculative draft tokens allowed per request in a single
 # step. This value is chosen to be large enough to handle typical use cases.
 MAX_SPEC_LEN = 32
 
+# Read the device properties to get the number of vector cores, which the
+# kernels below use for tiling.
+device_properties = driver.active.utils.get_device_properties(
+    torch.npu.current_device())
+vectorcore_num = device_properties['num_vectorcore']
 
 class AscendRejectionSampler(RejectionSampler, nn.Module):
     """
@@ -107,70 +108,6 @@ def forward(
         return output_token_ids
 
 
-def apply_sampling_constraints(
-    logits: torch.Tensor,  # [num_tokens, vocab_size]
-    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    """Process logits based on sampling metadata.
-
-    This function applies temperature scaling to the logits,
-    as well as top-k and top-p. For greedy decoding, it returns
-    the original logits.
-
-    Args:
-        logits: Input logits tensor to be processed.
-        cu_num_draft_tokens: Cumulative number of draft tokens.
-        sampling_metadata: Metadata containing sampling parameters such as
-            temperature and whether greedy sampling is used.
-
-    Returns:
-        torch.Tensor: Processed logits if non-greedy sampling is used,
-        otherwise returns the original logits.
-    """
-    assert logits.ndim == 2
-    assert cu_num_draft_tokens.ndim == 1
-    if sampling_metadata.all_greedy:
-        return logits
-
-    num_tokens = logits.shape[0]
-    temperature = expand_batch_to_tokens(
-        sampling_metadata.temperature,
-        cu_num_draft_tokens,
-        num_tokens,
-        replace_from=GREEDY_TEMPERATURE,
-        replace_to=1,
-    )
-    # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
-    logits.div_(temperature.unsqueeze(-1))
-
-    # Get expanded top_k and top_p tensors.
-    top_k = None
-    if sampling_metadata.top_k is not None:
-        top_k = expand_batch_to_tokens(
-            sampling_metadata.top_k,
-            cu_num_draft_tokens,
-            num_tokens,
-        )
-    top_p = None
-    if sampling_metadata.top_p is not None:
-        top_p = expand_batch_to_tokens(
-            sampling_metadata.top_p,
-            cu_num_draft_tokens,
-            num_tokens,
-        )
-
-    if get_ascend_device_type(
-    ) != AscendDeviceType._310P and top_p is not None and top_k is not None and 1 <= int(
-            top_k.max()) <= 1024:
-        return torch_npu.npu_top_k_top_p(logits, top_p.to(torch.bfloat16),
-                                         top_k)
-    else:
-        # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
-        # which is slow for large vocab sizes. This may cause performance issues.
-        return apply_top_k_top_p(logits, top_k, top_p)
-
-
 def rejection_sample(
     # [num_tokens]
     draft_token_ids: torch.Tensor,
@@ -218,15 +155,36 @@ def rejection_sample(
     # Rejection sampling for greedy sampling requests.
     target_argmax = target_probs.argmax(dim=-1)
     if HAS_TRITON:
-        rejection_greedy_sample_kernel[(batch_size, )](
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            target_argmax,
-            bonus_token_ids,
-            is_greedy,
-            max_spec_len,
-        )
+        vec_len = batch_size
+        n = cu_num_draft_tokens.numel()
+        BLOCK_SIZE = 2
+        grid = triton.cdiv(n, BLOCK_SIZE)
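+        # Tiling: each Triton program handles a BLOCK_SIZE-wide slice of the
+        # batch. Once the batch is large enough, cap the grid at the number
+        # of vector cores and grow BLOCK_SIZE instead.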
+        if n >= vectorcore_num:
+            grid = vectorcore_num  # Empirically tuned value
+            BLOCK_SIZE = triton.next_power_of_2(n // grid)
+
+        if min(num_draft_tokens) == 1 and max(
+                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
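+            # Fast path: every request carries exactly one draft token, so
+            # the output rows are laid out as [batch_size, 2].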
+            rejection_greedy_sample_spec_len_1_triton[(grid, )](
+                output_token_ids,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+                vec_len,
+                BLOCK_SIZE=BLOCK_SIZE,
+            )
+        else:
+            rejection_greedy_sample_triton[(grid, )](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+                is_greedy,
+                vec_len,
+                max_spec_len,
+                BLOCK_SIZE=BLOCK_SIZE,
+            )
     else:
         if min(num_draft_tokens) == 1 and max(
                 num_draft_tokens) == 1 and sampling_metadata.all_greedy:
@@ -337,13 +295,23 @@ def expand_batch_to_tokens(
     assert cu_num_tokens.shape[0] == batch_size
     expanded_x = x.new_empty(num_tokens)
     if HAS_TRITON:
-        expand_kernel[(batch_size, )](
+        vec_len = batch_size
+        n = cu_num_tokens.numel()
+        BLOCK_SIZE = 2
+        grid = triton.cdiv(n, BLOCK_SIZE)
+        if n >= vectorcore_num:
+            grid = vectorcore_num
+            BLOCK_SIZE = triton.next_power_of_2(n // grid)
+
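+        # Same tiling scheme as rejection_sample: one program per vector
+        # core once the batch is large enough.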
+        expand_kernel[(grid, )](
             expanded_x,
             x,
             cu_num_tokens,
             replace_from,
             replace_to,
+            vec_len,
             MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
+            BLOCK_SIZE=BLOCK_SIZE,
         )
     else:
         expand_pytorch(
@@ -626,50 +594,115 @@ def sample_recovered_tokens_pytorch(
 
 
 @triton.jit(do_not_specialize=["max_spec_len"])
-def rejection_greedy_sample_kernel(
+def bonus_renew_1(
+    bonus_token_ids_ptr,
+    position,
+    output_token_ids_ptr,
+):
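+    # Output rows are 2 wide in the spec_len == 1 layout, so the bonus token
+    # goes into column 1 of row `position`.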
+    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
+    tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id)
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_greedy_sample_spec_len_1_triton(
+    output_token_ids_ptr,  # [batch_size, 2]
+    draft_token_ids_ptr,  # [num_tokens]
+    target_argmax_ptr,  # [num_tokens]
+    bonus_token_ids_ptr,
+    vec_len,
+    BLOCK_SIZE: tl.constexpr,
+):
+    block_idx = tl.program_id(0)
+    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offset < vec_len
+
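+    # Vectorized part: load BLOCK_SIZE draft/argmax tokens at once and write
+    # the target argmax into column 0 of each output row.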
+    draft_token_id = tl.load(draft_token_ids_ptr + offset, mask)
+    target_argmax_id = tl.load(target_argmax_ptr + offset, mask)
+    tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
+
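+    # Scalar tail: per element, append the bonus token only when the draft
+    # token was accepted (it matched the target argmax).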
+    for pos in tl.range(0, BLOCK_SIZE):
+        draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
+        target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
+        position = block_idx * BLOCK_SIZE + pos
+        if draft_token_id1 == target_argmax1:
+            bonus_renew_1(
+                bonus_token_ids_ptr,
+                position,
+                output_token_ids_ptr,
+            )
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def bonus_renew(
+    bonus_token_ids_ptr,
+    position,
+    output_token_ids_ptr,
+    max_spec_len,
+    num_tokens1,
+):
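+    # All draft tokens were accepted: place the bonus token right after the
+    # last accepted token in this request's output row.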
+    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
+    tl.store(
+        output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1,
+        bonus_token_id)
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_greedy_sample_triton(
     output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
     cu_num_draft_tokens_ptr,  # [batch_size]
     draft_token_ids_ptr,  # [num_tokens]
     target_argmax_ptr,  # [num_tokens]
     bonus_token_ids_ptr,  # [batch_size]
     is_greedy_ptr,  # [batch_size] or None
+    vec_len,
     max_spec_len,
+    BLOCK_SIZE: tl.constexpr,
 ):
-    req_idx = tl.program_id(0)
-    # Because is_greedy_ptr is not Nonr at profiling run,
-    # re-comilation may happen during runtime when is_greedy_ptr is None.
-    is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
-                                                           req_idx)
-    if not is_greedy:
-        # Early exit for non-greedy sampling requests
-        return
+    block_idx = tl.program_id(0)
+    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offset < vec_len
 
-    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
-                                               req_idx - 1)
-    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
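+    # When is_greedy_ptr is None, every lane in the tile is treated as
+    # greedy; otherwise load the flags and mask out non-greedy lanes instead
+    # of early-exiting the whole program.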
+    if is_greedy_ptr is None:
+        is_greedy_mask = mask
+    else:
+        is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
+        is_greedy_mask = mask & (is_greedy != 0)
+
+    start_idx = tl.where(
+        offset == 0, 0,
+        tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
+    end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask)
     num_draft_tokens = end_idx - start_idx
 
-    rejected = False
-    for pos in range(num_draft_tokens):
-        if not rejected:
-            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
-            tl.store(
-                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
-                target_argmax_id,
-            )
-            if draft_token_id != target_argmax_id:
-                # Reject
-                rejected = True
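+    # Per-element scalar loop over this program's tile: copy accepted target
+    # argmax tokens until the first draft/target mismatch rejects the rest.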
+    for pos in tl.range(0, BLOCK_SIZE):
+        num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
+        if num_tokens1 != 0:
+            rejected = False
+            start_idx1 = tl.get_element(start_idx, (pos, ))
+            position = block_idx * BLOCK_SIZE + pos
+            for i in range(num_tokens1):
+                if not rejected:
+                    draft_token_id = tl.load(draft_token_ids_ptr +
+                                             start_idx1 + i)
+                    target_argmax_id = tl.load(target_argmax_ptr +
+                                               start_idx1 + i)
+                    tl.store(
+                        output_token_ids_ptr + position * (max_spec_len + 1) +
+                        i,
+                        target_argmax_id,
+                    )
+                    if draft_token_id != target_argmax_id:
+                        # Reject.
+                        rejected = True
 
-    if not rejected:
-        # If all tokens are accepted, append the bonus token
-        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
-        tl.store(
-            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
-            num_draft_tokens,
-            bonus_token_id,
-        )
+            if not rejected:
+                bonus_renew(
+                    bonus_token_ids_ptr,
+                    position,
+                    output_token_ids_ptr,
+                    max_spec_len,
+                    num_tokens1,
+                )
 
 
 @triton.jit(do_not_specialize=["max_spec_len"])
@@ -739,22 +772,30 @@ def expand_kernel(
     cu_num_tokens_ptr,  # [batch_size]
     replace_from,
     replace_to,
+    vec_len,
     MAX_NUM_TOKENS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
-    if req_idx == 0:
-        start_idx = 0
-    else:
-        start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
-    end_idx = tl.load(cu_num_tokens_ptr + req_idx)
+    offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    len_mask = offset < vec_len
+
+    start_idx = tl.where(offset == 0, 0,
+                         tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
+    end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
     num_tokens = end_idx - start_idx
 
-    src_val = tl.load(input_ptr + req_idx)
+    src_val = tl.load(input_ptr + offset, len_mask)
     src_val = tl.where(src_val == replace_from, replace_to, src_val)
-    offset = tl.arange(0, MAX_NUM_TOKENS)
-    tl.store(output_ptr + start_idx + offset,
-             src_val,
-             mask=offset < num_tokens)
+
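+    # Scatter per element: broadcast each request's (possibly replaced)
+    # value across its num_tokens output slots.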
+    for i in tl.range(0, BLOCK_SIZE):
+        num_tokens1 = tl.get_element(num_tokens, (i, ))
+        start_idx1 = tl.get_element(start_idx, (i, ))
+        src_val1 = tl.get_element(src_val, (i, ))
+        offset1 = tl.arange(0, MAX_NUM_TOKENS)
+        tl.store(output_ptr + start_idx1 + offset1,
+                 src_val1,
+                 mask=offset1 < num_tokens1)
 
 
 @triton.jit