diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index c1ef10db5a9..efde3ed7ed1 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -20,6 +20,16 @@
 # step. This value is chosen to be large enough to handle typical use cases.
 MAX_SPEC_LEN = 32
 
+vectorcore_num = None
+device_properties = None
+
+# Query the vector core count; the kernel launches below tile across it.
+if HAS_TRITON:
+    from triton.runtime import driver  # type: ignore
+    device_properties = driver.active.utils.get_device_properties(
+        torch.npu.current_device())
+    vectorcore_num = device_properties['num_vectorcore']
+
 
 class AscendRejectionSampler(RejectionSampler, nn.Module):
     """
@@ -218,15 +228,36 @@ def rejection_sample(
     # Rejection sampling for greedy sampling requests.
     target_argmax = target_probs.argmax(dim=-1)
     if HAS_TRITON:
-        rejection_greedy_sample_kernel[(batch_size, )](
-            output_token_ids,
-            cu_num_draft_tokens,
-            draft_token_ids,
-            target_argmax,
-            bonus_token_ids,
-            is_greedy,
-            max_spec_len,
-        )
+        vec_len = batch_size
+        n = cu_num_draft_tokens.numel()
+        BLOCK_SIZE = 2
+        grid = triton.cdiv(n, BLOCK_SIZE)
+        if n >= vectorcore_num:
+            grid = vectorcore_num  # One program per vector core.
+            BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(n, grid))
+
+        if min(num_draft_tokens) == 1 and max(
+                num_draft_tokens) == 1 and sampling_metadata.all_greedy:
+            rejection_greedy_sample_spec_len_1_triton[(grid, )](
+                output_token_ids,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+                vec_len,
+                BLOCK_SIZE=BLOCK_SIZE,
+            )
+        else:
+            rejection_greedy_sample_triton[(grid, )](
+                output_token_ids,
+                cu_num_draft_tokens,
+                draft_token_ids,
+                target_argmax,
+                bonus_token_ids,
+                is_greedy,
+                vec_len,
+                max_spec_len,
+                BLOCK_SIZE=BLOCK_SIZE,
+            )
     else:
         if min(num_draft_tokens) == 1 and max(
                 num_draft_tokens) == 1 and sampling_metadata.all_greedy:
@@ -337,13 +368,23 @@ def expand_batch_to_tokens(
     assert cu_num_tokens.shape[0] == batch_size
     expanded_x = x.new_empty(num_tokens)
     if HAS_TRITON:
-        expand_kernel[(batch_size, )](
+        vec_len = batch_size
+        n = cu_num_tokens.numel()
+        BLOCK_SIZE = 2
+        grid = triton.cdiv(n, BLOCK_SIZE)
+        if n >= vectorcore_num:
+            grid = vectorcore_num
+            BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(n, grid))
+
+        expand_kernel[(grid, )](
             expanded_x,
             x,
             cu_num_tokens,
             replace_from,
             replace_to,
+            vec_len,
             MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
+            BLOCK_SIZE=BLOCK_SIZE,
         )
     else:
         expand_pytorch(
@@ -626,50 +667,115 @@ def sample_recovered_tokens_pytorch(
 
 
 @triton.jit(do_not_specialize=["max_spec_len"])
-def rejection_greedy_sample_kernel(
+def bonus_renew_1(
+    bonus_token_ids_ptr,
+    position,
+    output_token_ids_ptr,
+):
+    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
+    tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id)
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_greedy_sample_spec_len_1_triton(
+    output_token_ids_ptr,  # [batch_size, 2]
+    draft_token_ids_ptr,  # [num_tokens]
+    target_argmax_ptr,  # [num_tokens]
+    bonus_token_ids_ptr,
+    vec_len,
+    BLOCK_SIZE: tl.constexpr,
+):
+    block_idx = tl.program_id(0)
+    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offset < vec_len
+
+    # Distinct `other` fills: out-of-range lanes can never compare equal.
+    draft_token_id = tl.load(draft_token_ids_ptr + offset, mask, other=0)
+    target_argmax_id = tl.load(target_argmax_ptr + offset, mask, other=-1)
+    tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
+
+    for pos in tl.range(0, BLOCK_SIZE):
+        draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
+        target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
+        position = block_idx * BLOCK_SIZE + pos
+        if draft_token_id1 == target_argmax1:
+            bonus_renew_1(
+                bonus_token_ids_ptr,
+                position,
+                output_token_ids_ptr,
+            )
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def bonus_renew(
+    bonus_token_ids_ptr,
+    position,
+    output_token_ids_ptr,
+    max_spec_len,
+    num_tokens1,
+):
+    bonus_token_id = tl.load(bonus_token_ids_ptr + position)
+    tl.store(
+        output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1,
+        bonus_token_id)
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_greedy_sample_triton(
     output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
     cu_num_draft_tokens_ptr,  # [batch_size]
     draft_token_ids_ptr,  # [num_tokens]
     target_argmax_ptr,  # [num_tokens]
     bonus_token_ids_ptr,  # [batch_size]
     is_greedy_ptr,  # [batch_size] or None
+    vec_len,
     max_spec_len,
+    BLOCK_SIZE: tl.constexpr,
 ):
-    req_idx = tl.program_id(0)
-    # Because is_greedy_ptr is not Nonr at profiling run,
-    # re-comilation may happen during runtime when is_greedy_ptr is None.
-    is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
-                                                           req_idx)
-    if not is_greedy:
-        # Early exit for non-greedy sampling requests
-        return
+    block_idx = tl.program_id(0)
+    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offset < vec_len
 
-    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
-                                               req_idx - 1)
-    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
+    if is_greedy_ptr is None:
+        is_greedy_mask = mask
+    else:
+        is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
+        is_greedy_mask = mask & (is_greedy != 0)
+
+    # other=0: masked and non-greedy lanes see zero draft tokens below.
+    start_idx = tl.where(
+        offset == 0, 0,
+        tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask, other=0))
+    end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask, other=0)
     num_draft_tokens = end_idx - start_idx
-    rejected = False
-    for pos in range(num_draft_tokens):
-        if not rejected:
-            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
-            tl.store(
-                output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
-                target_argmax_id,
-            )
-            if draft_token_id != target_argmax_id:
-                # Reject
-                rejected = True
+    for pos in tl.range(0, BLOCK_SIZE):
+        num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
+        if num_tokens1 != 0:
+            rejected = False
+            start_idx1 = tl.get_element(start_idx, (pos, ))
+            position = block_idx * BLOCK_SIZE + pos
+            for i in range(num_tokens1):
+                if not rejected:
+                    draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 +
                                             i)
+                    target_argmax_id = tl.load(target_argmax_ptr + start_idx1 +
+                                               i)
+                    tl.store(
+                        output_token_ids_ptr + position * (max_spec_len + 1) +
+                        i,
+                        target_argmax_id,
+                    )
+                    if draft_token_id != target_argmax_id:
+                        # Reject.
+                        rejected = True
 
-    if not rejected:
-        # If all tokens are accepted, append the bonus token
-        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
-        tl.store(
-            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
-            num_draft_tokens,
-            bonus_token_id,
-        )
+            if not rejected:
+                # All draft tokens accepted: append the bonus token.
+                bonus_renew(
+                    bonus_token_ids_ptr,
+                    position,
+                    output_token_ids_ptr,
+                    max_spec_len,
+                    num_tokens1,
+                )
 
 
 @triton.jit(do_not_specialize=["max_spec_len"])
@@ -739,22 +845,30 @@ def expand_kernel(
     cu_num_tokens_ptr,  # [batch_size]
     replace_from,
     replace_to,
+    vec_len,
     MAX_NUM_TOKENS: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
-    if req_idx == 0:
-        start_idx = 0
-    else:
-        start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
-    end_idx = tl.load(cu_num_tokens_ptr + req_idx)
+    offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    len_mask = offset < vec_len
+
+    # other=0: masked lanes get num_tokens == 0, so nothing is stored.
+    start_idx = tl.where(offset == 0, 0,
+                         tl.load(cu_num_tokens_ptr + offset - 1, len_mask,
                                  other=0))
+    end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask, other=0)
     num_tokens = end_idx - start_idx
-    src_val = tl.load(input_ptr + req_idx)
+    src_val = tl.load(input_ptr + offset, len_mask)
     src_val = tl.where(src_val == replace_from, replace_to, src_val)
-    offset = tl.arange(0, MAX_NUM_TOKENS)
-    tl.store(output_ptr + start_idx + offset,
-             src_val,
-             mask=offset < num_tokens)
+
+    for i in tl.range(0, BLOCK_SIZE):
+        num_tokens1 = tl.get_element(num_tokens, (i, ))
+        start_idx1 = tl.get_element(start_idx, (i, ))
+        src_val1 = tl.get_element(src_val, (i, ))
+        offset1 = tl.arange(0, MAX_NUM_TOKENS)
+        tl.store(output_ptr + start_idx1 + offset1,
                 src_val1,
                 mask=offset1 < num_tokens1)
 
 
 @triton.jit
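
Note on the launch heuristic: both Triton call sites above compute grid/BLOCK_SIZE the same way. Below is a host-side sketch of that heuristic for review purposes only, in plain Python; the 48-core figure in the asserts is an assumed example, while the patch reads the real count from the NPU driver at import time.

    import math


    def tile_launch(n: int, vectorcore_num: int) -> tuple[int, int]:
        """Sketch of the patch's grid/BLOCK_SIZE heuristic (host side only).

        Small batches launch cdiv(n, 2) programs with a fixed block of 2;
        once n reaches the vector-core count, the grid is pinned to one
        program per core and the block is rounded up to a power of two so
        that grid * BLOCK_SIZE still covers every row.
        """
        block_size = 2
        grid = math.ceil(n / block_size)  # triton.cdiv(n, BLOCK_SIZE)
        if n >= vectorcore_num:
            grid = vectorcore_num
            # triton.next_power_of_2(triton.cdiv(n, grid))
            block_size = 1 << (math.ceil(n / grid) - 1).bit_length()
        return grid, block_size


    assert tile_launch(6, 48) == (3, 2)    # small batch: 3 programs of 2
    assert tile_launch(100, 48) == (48, 4)  # capped grid: 48 programs of 4

Note the ceiling division in the capped branch: with floor division, a batch like n=49 on 48 cores would round the block down to 1 and leave the last row unprocessed.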
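For reference while reviewing the two greedy kernels, this is the accept/reject contract they implement, restated in plain PyTorch. It is a clarity sketch, not code from the module; the function name and the placeholder_id default are mine.

    import torch


    def greedy_rejection_reference(draft_token_ids: torch.Tensor,
                                   target_argmax: torch.Tensor,
                                   bonus_token_ids: torch.Tensor,
                                   cu_num_draft_tokens: torch.Tensor,
                                   max_spec_len: int,
                                   placeholder_id: int = -1) -> torch.Tensor:
        """Per-request greedy rejection sampling, written for clarity."""
        batch_size = cu_num_draft_tokens.numel()
        out = torch.full((batch_size, max_spec_len + 1), placeholder_id,
                         dtype=torch.long)
        start = 0
        for req in range(batch_size):
            end = int(cu_num_draft_tokens[req])
            accepted = 0
            for pos in range(start, end):
                # The target model's argmax always lands, even at the
                # position where the draft token is rejected.
                out[req, pos - start] = target_argmax[pos]
                if draft_token_ids[pos] != target_argmax[pos]:
                    break  # first mismatch rejects everything after it
                accepted += 1
            if accepted == end - start:
                # All draft tokens accepted: append the bonus token.
                out[req, accepted] = bonus_token_ids[req]
            start = end
        return out

The spec-len-1 fast path is this same contract specialized to max_spec_len == 1, which is why its output rows have stride 2: slot 0 holds the target argmax and slot 1 holds the bonus token when the single draft token is accepted.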