Commit 1b5513a

[performance] Enhance performance after enabling min_p (#4529)
### What this PR does / why we need it?
When the min_p post-processing parameter is enabled, the original vLLM implementation introduces the aclnnIndexPutImpl operator, which performs poorly on NPU.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Profiling was collected with min_p enabled; performance is greatly improved.

- vLLM version: v0.11.2

---------

Signed-off-by: funanyang <[email protected]>
1 parent eabedf4 commit 1b5513a
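For context, here is a minimal sketch (not part of the commit; tensor shapes and variable names are illustrative) of the two ways to apply a min_p cutoff to a logits tensor in PyTorch. Boolean-index assignment dispatches to index_put_, which the PR description identifies as the slow aclnnIndexPutImpl path on NPU; masked_fill_ produces the same result without it.

```python
import torch

# Illustrative inputs: 4 sequences over an 8-token vocabulary.
logits = torch.randn(4, 8)
min_p = torch.full((4, 1), 0.05)   # per-sequence min_p thresholds

probs = torch.softmax(logits, dim=-1)
threshold = probs.amax(dim=-1, keepdim=True) * min_p
invalid = probs < threshold

# Upstream-style masking: boolean-index assignment dispatches to index_put_,
# which this PR reports maps to aclnnIndexPutImpl on NPU and performs poorly there.
slow = logits.clone()
slow[invalid] = -float('inf')

# The commit's approach: masked_fill_ writes -inf in place, no index_put involved.
fast = logits.clone()
fast.masked_fill_(invalid, -float('inf'))

assert torch.equal(slow, fast)  # both yield identical filtered logits
```

masked_fill_ stays a dense elementwise kernel over the logits tensor, with no index gathering or scattering, which matches the PR's rationale for avoiding aclnnIndexPutImpl.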

File tree

1 file changed: +17 −0 lines


vllm_ascend/sample/logits_processor/builtin.py

Lines changed: 17 additions & 0 deletions
@@ -33,3 +33,20 @@ def __init__(self, vllm_config: "VllmConfig", device: torch.device,
         self.min_p_device = self.min_p_cpu_tensor
         # Current slice of the device tensor
         self.min_p: torch.Tensor = self.min_p_device[:0]
+
+    def apply(self, logits: torch.Tensor) -> torch.Tensor:
+        if not self.min_p_count:
+            return logits
+        # Convert logits to probability distribution
+        probability_values = torch.nn.functional.softmax(logits, dim=-1)
+        # Calculate maximum probabilities per sequence
+        max_probabilities = torch.amax(probability_values,
+                                       dim=-1,
+                                       keepdim=True)
+        # Adjust min_p
+        adjusted_min_p = max_probabilities.mul_(self.min_p)
+        # Identify valid tokens using threshold comparison
+        invalid_token_mask = probability_values < adjusted_min_p
+        # Apply mask in place with masked_fill_ (avoids the slow index_put path on NPU)
+        logits.masked_fill_(invalid_token_mask, -float('inf'))
+        return logits
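To exercise the new apply() outside the plugin, a small self-contained harness is sketched below; the MinPDemo class and the way it fills min_p and min_p_count are illustrative stand-ins for processor state that builtin.py sets up elsewhere, not part of the commit.

```python
import torch


class MinPDemo:
    """Illustrative stand-in for the processor state used by apply()."""

    def __init__(self, min_p: torch.Tensor):
        # Per-request thresholds; shape (num_seqs, 1) is assumed for this demo.
        self.min_p = min_p
        self.min_p_count = int((min_p > 0).sum())

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.min_p_count:
            return logits
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
        adjusted_min_p = max_probabilities.mul_(self.min_p)
        invalid_token_mask = probability_values < adjusted_min_p
        logits.masked_fill_(invalid_token_mask, -float('inf'))
        return logits


logits = torch.tensor([[2.0, 1.0, 0.0, -1.0]])
proc = MinPDemo(min_p=torch.tensor([[0.2]]))
print(proc.apply(logits))
# tensor([[2., 1., -inf, -inf]]): tokens whose probability falls below
# 0.2 * the sequence's max probability are masked out before sampling.
```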
