|
21 | 21 | MAX_SPEC_LEN = 32 |
22 | 22 |
|
23 | 23 |
|
# Query the NPU's vector-core count once at import time; the Triton kernel
# launches below use it to choose their grid size and tiling.
device_properties = None
vectorcore_num = None


if HAS_TRITON:
    from triton.runtime import driver
    device_properties = driver.active.utils.get_device_properties(
        torch.npu.current_device())
    vectorcore_num = device_properties['num_vectorcore']
| 33 | + |
| 34 | + |
24 | 35 | class AscendRejectionSampler(RejectionSampler, nn.Module): |
25 | 36 | """ |
26 | 37 | The implementation strictly follows the algorithm described in |
@@ -218,15 +229,36 @@ def rejection_sample( |
218 | 229 | # Rejection sampling for greedy sampling requests. |
219 | 230 | target_argmax = target_probs.argmax(dim=-1) |
220 | 231 | if HAS_TRITON: |
221 | | - rejection_greedy_sample_kernel[(batch_size, )]( |
222 | | - output_token_ids, |
223 | | - cu_num_draft_tokens, |
224 | | - draft_token_ids, |
225 | | - target_argmax, |
226 | | - bonus_token_ids, |
227 | | - is_greedy, |
228 | | - max_spec_len, |
229 | | - ) |
| 232 | + vec_len = batch_size |
| 233 | + n = cu_num_draft_tokens.numel() |
| 234 | + BLOCK_SIZE = 2 |
| 235 | + grid = triton.cdiv(n, BLOCK_SIZE) |
| 236 | + if n >= vectorcore_num: |
| 237 | + grid = vectorcore_num # Empirically tuned value |
| 238 | + BLOCK_SIZE = triton.next_power_of_2(n // grid) |
| 239 | + |
| 240 | + if min(num_draft_tokens) == 1 and max( |
| 241 | + num_draft_tokens) == 1 and sampling_metadata.all_greedy: |
| 242 | + rejection_greedy_sample_spec_len_1_triton[(grid,)]( |
| 243 | + output_token_ids, |
| 244 | + draft_token_ids, |
| 245 | + target_argmax, |
| 246 | + bonus_token_ids, |
| 247 | + vec_len, |
| 248 | + BLOCK_SIZE=BLOCK_SIZE, |
| 249 | + ) |
| 250 | + else: |
| 251 | + rejection_greedy_sample_triton[(grid,)]( |
| 252 | + output_token_ids, |
| 253 | + cu_num_draft_tokens, |
| 254 | + draft_token_ids, |
| 255 | + target_argmax, |
| 256 | + bonus_token_ids, |
| 257 | + is_greedy, |
| 258 | + vec_len, |
| 259 | + max_spec_len, |
| 260 | + BLOCK_SIZE=BLOCK_SIZE, |
| 261 | + ) |
230 | 262 | else: |
231 | 263 | if min(num_draft_tokens) == 1 and max( |
232 | 264 | num_draft_tokens) == 1 and sampling_metadata.all_greedy: |
@@ -337,13 +369,23 @@ def expand_batch_to_tokens( |
337 | 369 | assert cu_num_tokens.shape[0] == batch_size |
338 | 370 | expanded_x = x.new_empty(num_tokens) |
339 | 371 | if HAS_TRITON: |
340 | | - expand_kernel[(batch_size, )]( |
| 372 | + vec_len = batch_size |
| 373 | + n = cu_num_tokens.numel() |
| 374 | + BLOCK_SIZE = 2 |
| 375 | + grid = triton.cdiv(n, BLOCK_SIZE) |
| 376 | + if n >= vectorcore_num: |
| 377 | + grid = vectorcore_num |
| 378 | + BLOCK_SIZE = triton.next_power_of_2(n // grid) |
| 379 | + |
| 380 | + expand_kernel[(grid,)]( |
341 | 381 | expanded_x, |
342 | 382 | x, |
343 | 383 | cu_num_tokens, |
344 | 384 | replace_from, |
345 | 385 | replace_to, |
| 386 | + vec_len, |
346 | 387 | MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. |
| 388 | + BLOCK_SIZE=BLOCK_SIZE, |
347 | 389 | ) |
348 | 390 | else: |
349 | 391 | expand_pytorch( |
@@ -626,50 +668,115 @@ def sample_recovered_tokens_pytorch( |
626 | 668 |
|
627 | 669 |
|
# Append the bonus token for a fully-accepted request in the spec-len-1
# fast path.  The output row stride is 2 (one draft slot + one bonus slot),
# so the bonus token lands at output_token_ids[position, 1].
# NOTE: the previous `do_not_specialize=["max_spec_len"]` was a copy-paste
# leftover — this kernel has no `max_spec_len` parameter, so the entry was
# dead configuration and has been dropped.
@triton.jit
def bonus_renew_1(
    bonus_token_ids_ptr,  # [batch_size]
    position,  # scalar request index within the batch
    output_token_ids_ptr,  # [batch_size, 2], flattened
):
    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
    tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id)
| 679 | + |
# Greedy rejection sampling fast path for the all-greedy, spec-len == 1 case.
# Each program handles BLOCK_SIZE requests; position 0 of every output row
# always receives the target model's argmax, and the bonus token is appended
# only when the draft token was accepted.
# (Dropped the leftover `do_not_specialize=["max_spec_len"]` — this kernel
# has no `max_spec_len` parameter.)
@triton.jit
def rejection_greedy_sample_spec_len_1_triton(
    output_token_ids_ptr,  # [batch_size, 2]
    draft_token_ids_ptr,  # [num_tokens] (== batch_size in this path)
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    vec_len,  # batch_size
    BLOCK_SIZE: tl.constexpr,
):
    block_idx = tl.program_id(0)
    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < vec_len

    # A masked tl.load without `other` yields *undefined* values in the
    # masked-out lanes; use two different sentinels so the accept test in
    # the scalar loop can never fire (and store out of bounds) for a lane
    # past vec_len.
    draft_token_id = tl.load(draft_token_ids_ptr + offset, mask, other=-1)
    target_argmax_id = tl.load(target_argmax_ptr + offset, mask, other=-2)
    tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)

    for pos in tl.range(0, BLOCK_SIZE):
        draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
        target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
        position = block_idx * BLOCK_SIZE + pos
        # Draft accepted -> also emit the bonus token at row slot 1.
        if draft_token_id1 == target_argmax1:
            bonus_renew_1(
                bonus_token_ids_ptr,
                position,
                output_token_ids_ptr,
            )
| 708 | + |
# Write the bonus token for request `position` immediately after its
# `num_tokens1` accepted draft tokens; output rows have stride
# max_spec_len + 1.
@triton.jit(do_not_specialize=["max_spec_len"])
def bonus_renew(
    bonus_token_ids_ptr,
    position,
    output_token_ids_ptr,
    max_spec_len,
    num_tokens1,
):
    row_start = position * (max_spec_len + 1)
    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
    tl.store(output_token_ids_ptr + row_start + num_tokens1, bonus_token_id)
| 722 | + |
# Greedy rejection sampling for variable spec lengths.  Each program handles
# BLOCK_SIZE requests: accepted positions copy the target model's argmax,
# the first mismatching position keeps the corrected token and rejects the
# rest, and fully-accepted requests get the bonus token appended.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_triton(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    is_greedy_ptr,  # [batch_size] or None
    vec_len,  # batch_size
    max_spec_len,
    BLOCK_SIZE: tl.constexpr,
):
    block_idx = tl.program_id(0)
    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < vec_len

    # is_greedy_ptr is not None at the profiling run, so specializing on it
    # may trigger a recompilation at runtime when it becomes None.
    if is_greedy_ptr is None:
        is_greedy_mask = mask
    else:
        is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
        is_greedy_mask = mask & (is_greedy != 0)

    # `other=0` is required: a masked tl.load without it yields *undefined*
    # values, so non-greedy / out-of-range lanes would get garbage
    # num_draft_tokens and the scalar loop below could store out of bounds.
    # The extra `offset > 0` term keeps lane 0 from reading index -1 (the
    # tl.where discards that value, but the OOB load itself still happens).
    start_idx = tl.where(
        offset == 0, 0,
        tl.load(cu_num_draft_tokens_ptr + offset - 1,
                mask=is_greedy_mask & (offset > 0),
                other=0))
    end_idx = tl.load(cu_num_draft_tokens_ptr + offset,
                      mask=is_greedy_mask,
                      other=0)
    num_draft_tokens = end_idx - start_idx

    for pos in tl.range(0, BLOCK_SIZE):
        num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
        # Masked lanes have num_tokens1 == 0 (thanks to other=0) and are
        # skipped here.
        if num_tokens1 != 0:
            rejected = False
            start_idx1 = tl.get_element(start_idx, (pos, ))
            position = block_idx * BLOCK_SIZE + pos
            for i in range(num_tokens1):
                if not rejected:
                    draft_token_id = tl.load(draft_token_ids_ptr +
                                             start_idx1 + i)
                    target_argmax_id = tl.load(target_argmax_ptr +
                                               start_idx1 + i)
                    # Accepted (or first rejected) position always holds the
                    # target model's token.
                    tl.store(
                        output_token_ids_ptr + position * (max_spec_len + 1) +
                        i,
                        target_argmax_id,
                    )
                    if draft_token_id != target_argmax_id:
                        # Reject: drop all remaining draft tokens.
                        rejected = True

            if not rejected:
                # All drafts accepted -> append the bonus token.
                bonus_renew(
                    bonus_token_ids_ptr,
                    position,
                    output_token_ids_ptr,
                    max_spec_len,
                    num_tokens1,
                )
673 | 780 |
|
674 | 781 |
|
675 | 782 | @triton.jit(do_not_specialize=["max_spec_len"]) |
@@ -739,22 +846,30 @@ def expand_kernel( |
739 | 846 | cu_num_tokens_ptr, # [batch_size] |
740 | 847 | replace_from, |
741 | 848 | replace_to, |
| 849 | + vec_len, |
742 | 850 | MAX_NUM_TOKENS: tl.constexpr, |
| 851 | + BLOCK_SIZE: tl.constexpr, |
743 | 852 | ): |
744 | 853 | req_idx = tl.program_id(0) |
745 | | - if req_idx == 0: |
746 | | - start_idx = 0 |
747 | | - else: |
748 | | - start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1) |
749 | | - end_idx = tl.load(cu_num_tokens_ptr + req_idx) |
| 854 | + offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) |
| 855 | + len_mask = offset < vec_len |
| 856 | + |
| 857 | + start_idx = tl.where(offset == 0, 0, |
| 858 | + tl.load(cu_num_tokens_ptr + offset - 1, len_mask)) |
| 859 | + end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask) |
750 | 860 | num_tokens = end_idx - start_idx |
751 | 861 |
|
752 | | - src_val = tl.load(input_ptr + req_idx) |
| 862 | + src_val = tl.load(input_ptr + offset, len_mask) |
753 | 863 | src_val = tl.where(src_val == replace_from, replace_to, src_val) |
754 | | - offset = tl.arange(0, MAX_NUM_TOKENS) |
755 | | - tl.store(output_ptr + start_idx + offset, |
756 | | - src_val, |
757 | | - mask=offset < num_tokens) |
| 864 | + |
| 865 | + for i in tl.range(0, BLOCK_SIZE): |
| 866 | + num_tokens1 = tl.get_element(num_tokens, (i, )) |
| 867 | + start_idx1 = tl.get_element(start_idx, (i, )) |
| 868 | + src_val1 = tl.get_element(src_val, (i, )) |
| 869 | + offset1 = tl.arange(0, MAX_NUM_TOKENS) |
| 870 | + tl.store(output_ptr + start_idx1 + offset1, |
| 871 | + src_val1, |
| 872 | + mask=offset1 < num_tokens1) |
758 | 873 |
|
759 | 874 |
|
760 | 875 | @triton.jit |
|
0 commit comments