-
Notifications
You must be signed in to change notification settings - Fork 654
feat: implement high-performance Triton kernels for rejection sampling #4830
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
353facf
66f61e1
e304bc4
6fae243
2e49ca1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,16 @@ | |
| # step. This value is chosen to be large enough to handle typical use cases. | ||
| MAX_SPEC_LEN = 32 | ||
|
|
||
| vectorcore_num = None | ||
| device_properties = None | ||
|
|
||
| if HAS_TRITON: | ||
| from triton.runtime import driver # type: ignore | ||
| device_properties = driver.active.utils.get_device_properties( | ||
| torch.npu.current_device()) | ||
| vectorcore_num = device_properties['num_vectorcore'] | ||
| # Get the number of vector cores; it is used later for tiling. | ||
|
|
||
|
|
||
| class AscendRejectionSampler(RejectionSampler, nn.Module): | ||
| """ | ||
|
|
@@ -218,15 +228,36 @@ def rejection_sample( | |
| # Rejection sampling for greedy sampling requests. | ||
| target_argmax = target_probs.argmax(dim=-1) | ||
| if HAS_TRITON: | ||
| rejection_greedy_sample_kernel[(batch_size, )]( | ||
| output_token_ids, | ||
| cu_num_draft_tokens, | ||
| draft_token_ids, | ||
| target_argmax, | ||
| bonus_token_ids, | ||
| is_greedy, | ||
| max_spec_len, | ||
| ) | ||
| vec_len = batch_size | ||
| n = cu_num_draft_tokens.numel() | ||
| BLOCK_SIZE = 2 | ||
| grid = triton.cdiv(n, BLOCK_SIZE) | ||
| if n >= vectorcore_num: | ||
| grid = vectorcore_num # Empirically tuned value | ||
| BLOCK_SIZE = triton.next_power_of_2(n // grid) | ||
|
|
||
| if min(num_draft_tokens) == 1 and max( | ||
| num_draft_tokens) == 1 and sampling_metadata.all_greedy: | ||
| rejection_greedy_sample_spec_len_1_triton[(grid, )]( | ||
| output_token_ids, | ||
| draft_token_ids, | ||
| target_argmax, | ||
| bonus_token_ids, | ||
| vec_len, | ||
| BLOCK_SIZE=BLOCK_SIZE, | ||
| ) | ||
| else: | ||
| rejection_greedy_sample_triton[(grid, )]( | ||
| output_token_ids, | ||
| cu_num_draft_tokens, | ||
| draft_token_ids, | ||
| target_argmax, | ||
| bonus_token_ids, | ||
| is_greedy, | ||
| vec_len, | ||
| max_spec_len, | ||
| BLOCK_SIZE=BLOCK_SIZE, | ||
| ) | ||
| else: | ||
| if min(num_draft_tokens) == 1 and max( | ||
| num_draft_tokens) == 1 and sampling_metadata.all_greedy: | ||
|
|
@@ -337,13 +368,23 @@ def expand_batch_to_tokens( | |
| assert cu_num_tokens.shape[0] == batch_size | ||
| expanded_x = x.new_empty(num_tokens) | ||
| if HAS_TRITON: | ||
| expand_kernel[(batch_size, )]( | ||
| vec_len = batch_size | ||
| n = cu_num_tokens.numel() | ||
| BLOCK_SIZE = 2 | ||
| grid = triton.cdiv(n, BLOCK_SIZE) | ||
| if n >= vectorcore_num: | ||
| grid = vectorcore_num | ||
| BLOCK_SIZE = triton.next_power_of_2(n // grid) | ||
|
|
||
| expand_kernel[(grid, )]( | ||
| expanded_x, | ||
| x, | ||
| cu_num_tokens, | ||
| replace_from, | ||
| replace_to, | ||
| vec_len, | ||
| MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation. | ||
| BLOCK_SIZE=BLOCK_SIZE, | ||
| ) | ||
| else: | ||
| expand_pytorch( | ||
|
|
@@ -626,50 +667,115 @@ def sample_recovered_tokens_pytorch( | |
|
|
||
|
|
||
| @triton.jit(do_not_specialize=["max_spec_len"]) | ||
| def rejection_greedy_sample_kernel( | ||
| def bonus_renew_1( | ||
| bonus_token_ids_ptr, | ||
| position, | ||
| output_token_ids_ptr, | ||
| ): | ||
| bonus_token_id = tl.load(bonus_token_ids_ptr + position) | ||
| tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id) | ||
|
|
||
|
|
||
| @triton.jit(do_not_specialize=["max_spec_len"]) | ||
| def rejection_greedy_sample_spec_len_1_triton( | ||
| output_token_ids_ptr, # [batch_size, 2] | ||
| draft_token_ids_ptr, # [num_tokens] | ||
| target_argmax_ptr, # [num_tokens] | ||
| bonus_token_ids_ptr, | ||
| vec_len, | ||
| BLOCK_SIZE: tl.constexpr, | ||
| ): | ||
| block_idx = tl.program_id(0) | ||
| offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||
| mask = offset < vec_len | ||
|
|
||
| draft_token_id = tl.load(draft_token_ids_ptr + offset, mask) | ||
| target_argmax_id = tl.load(target_argmax_ptr + offset, mask) | ||
| tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask) | ||
|
|
||
| for pos in tl.range(0, BLOCK_SIZE): | ||
| draft_token_id1 = tl.get_element(draft_token_id, (pos, )) | ||
| target_argmax1 = tl.get_element(target_argmax_id, (pos, )) | ||
| position = block_idx * BLOCK_SIZE + pos | ||
| if draft_token_id1 == target_argmax1: | ||
| bonus_renew_1( | ||
| bonus_token_ids_ptr, | ||
| position, | ||
| output_token_ids_ptr, | ||
| ) | ||
|
|
||
|
|
||
| @triton.jit(do_not_specialize=["max_spec_len"]) | ||
| def bonus_renew( | ||
| bonus_token_ids_ptr, | ||
| position, | ||
| output_token_ids_ptr, | ||
| max_spec_len, | ||
| num_tokens1, | ||
| ): | ||
| bonus_token_id = tl.load(bonus_token_ids_ptr + position) | ||
| tl.store( | ||
| output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1, | ||
| bonus_token_id) | ||
|
|
||
|
|
||
| @triton.jit(do_not_specialize=["max_spec_len"]) | ||
| def rejection_greedy_sample_triton( | ||
| output_token_ids_ptr, # [batch_size, max_spec_len + 1] | ||
| cu_num_draft_tokens_ptr, # [batch_size] | ||
| draft_token_ids_ptr, # [num_tokens] | ||
| target_argmax_ptr, # [num_tokens] | ||
| bonus_token_ids_ptr, # [batch_size] | ||
| is_greedy_ptr, # [batch_size] or None | ||
| vec_len, | ||
| max_spec_len, | ||
| BLOCK_SIZE: tl.constexpr, | ||
| ): | ||
| req_idx = tl.program_id(0) | ||
| # Because is_greedy_ptr is not None during the profiling run, | ||
| # re-compilation may happen at runtime when is_greedy_ptr is None. | ||
| is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + | ||
| req_idx) | ||
| if not is_greedy: | ||
| # Early exit for non-greedy sampling requests | ||
| return | ||
| block_idx = tl.program_id(0) | ||
| offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||
| mask = offset < vec_len | ||
|
|
||
| start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + | ||
| req_idx - 1) | ||
| end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) | ||
| if is_greedy_ptr is None: | ||
| is_greedy_mask = mask | ||
| else: | ||
| is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0) | ||
| is_greedy_mask = mask & (is_greedy != 0) | ||
|
|
||
| start_idx = tl.where( | ||
| offset == 0, 0, | ||
| tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask)) | ||
| end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask) | ||
| num_draft_tokens = end_idx - start_idx | ||
|
|
||
| rejected = False | ||
| for pos in range(num_draft_tokens): | ||
| if not rejected: | ||
| draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) | ||
| target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos) | ||
| tl.store( | ||
| output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, | ||
| target_argmax_id, | ||
| ) | ||
| if draft_token_id != target_argmax_id: | ||
| # Reject | ||
| rejected = True | ||
| for pos in tl.range(0, BLOCK_SIZE): | ||
| num_tokens1 = tl.get_element(num_draft_tokens, (pos, )) | ||
| if num_tokens1 != 0: | ||
| rejected = False | ||
| start_idx1 = tl.get_element(start_idx, (pos, )) | ||
| position = block_idx * BLOCK_SIZE + pos | ||
| for i in range(num_tokens1): | ||
| if not rejected: | ||
| draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 + | ||
| i) | ||
| target_argmax_id = tl.load(target_argmax_ptr + start_idx1 + | ||
| i) | ||
| tl.store( | ||
| output_token_ids_ptr + position * (max_spec_len + 1) + | ||
| i, | ||
| target_argmax_id, | ||
| ) | ||
| if draft_token_id != target_argmax_id: | ||
| # Reject. | ||
| rejected = True | ||
|
|
||
| if not rejected: | ||
| # If all tokens are accepted, append the bonus token | ||
| bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx) | ||
| tl.store( | ||
| output_token_ids_ptr + req_idx * (max_spec_len + 1) + | ||
| num_draft_tokens, | ||
| bonus_token_id, | ||
| ) | ||
| if not rejected: | ||
| bonus_renew( | ||
| bonus_token_ids_ptr, | ||
| position, | ||
| output_token_ids_ptr, | ||
| max_spec_len, | ||
| num_tokens1, | ||
| ) | ||
yuxingcyx marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| @triton.jit(do_not_specialize=["max_spec_len"]) | ||
|
|
@@ -739,22 +845,30 @@ def expand_kernel( | |
| cu_num_tokens_ptr, # [batch_size] | ||
| replace_from, | ||
| replace_to, | ||
| vec_len, | ||
| MAX_NUM_TOKENS: tl.constexpr, | ||
| BLOCK_SIZE: tl.constexpr, | ||
| ): | ||
| req_idx = tl.program_id(0) | ||
| if req_idx == 0: | ||
| start_idx = 0 | ||
| else: | ||
| start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1) | ||
| end_idx = tl.load(cu_num_tokens_ptr + req_idx) | ||
| offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||
| len_mask = offset < vec_len | ||
|
|
||
| start_idx = tl.where(offset == 0, 0, | ||
| tl.load(cu_num_tokens_ptr + offset - 1, len_mask)) | ||
| end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask) | ||
| num_tokens = end_idx - start_idx | ||
|
|
||
| src_val = tl.load(input_ptr + req_idx) | ||
| src_val = tl.load(input_ptr + offset, len_mask) | ||
| src_val = tl.where(src_val == replace_from, replace_to, src_val) | ||
| offset = tl.arange(0, MAX_NUM_TOKENS) | ||
| tl.store(output_ptr + start_idx + offset, | ||
| src_val, | ||
| mask=offset < num_tokens) | ||
|
|
||
| for i in tl.range(0, BLOCK_SIZE): | ||
| num_tokens1 = tl.get_element(num_tokens, (i, )) | ||
| start_idx1 = tl.get_element(start_idx, (i, )) | ||
| src_val1 = tl.get_element(src_val, (i, )) | ||
| offset1 = tl.arange(0, MAX_NUM_TOKENS) | ||
| tl.store(output_ptr + start_idx1 + offset1, | ||
| src_val1, | ||
| mask=offset1 < num_tokens1) | ||
|
Comment on lines
+864
to
+871
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This loop |
||
|
|
||
|
|
||
| @triton.jit | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a comment explaining why this logic is encapsulated in a separate method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For certain inputs, none of the requests goes through the 'if not rejected' branch. Even so, profiling data collected after the optimization showed that this branch still accounted for a significant portion of MTE3 transfer time. For those inputs, removing the branch does not affect accuracy and improves performance by 40%. Because the branch may still be executed for other inputs, its steps were moved into a new operator, which is then called from within the main operator. Experiments show that performance improves to varying degrees for different inputs. After operator optimization, the setup is as follows: the part on the right represents the internal calls made by the 'rejection_greedy_sample_kernel' operator to the operator on the left, 'bonus_renew'.