[Kernel] add triton kernels for sampling #4550
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to the Contributing and Testing guides.
Code Review
This pull request replaces several PyTorch-based sampling functions with Triton kernels, which should improve performance. The new kernels for greedy and random rejection sampling, batch expansion, and recovered token sampling appear to correctly replicate the logic of the original PyTorch functions. I've found a critical syntax error in the new Triton-related imports that will prevent the code from running. I've left a specific comment with a suggested fix. Additionally, it appears the unit tests in tests/ut/sample/test_rejection_sampler.py have not been updated to reflect the removal of the PyTorch functions and the introduction of the Triton kernels. These tests will fail and should be updated to validate the new implementations.
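When the unit tests are updated, a small pure-Python reference can serve as the expected-value oracle for the greedy path. The sketch below is not the PR's kernel or the removed PyTorch function. It assumes the commonly described greedy rule: accept draft tokens while they match the target model's argmax, emit the target token at the first mismatch, and append a bonus token only when every draft token was accepted. The names `greedy_rejection_sample` and `PLACEHOLDER_TOKEN_ID` are hypothetical and chosen for illustration.

```python
from typing import List

# Hypothetical marker for output slots that were never filled.
PLACEHOLDER_TOKEN_ID = -1


def greedy_rejection_sample(
    draft_token_ids: List[int],
    target_argmax: List[int],
    bonus_token_id: int,
) -> List[int]:
    """Reference greedy rejection sampling for a single request.

    draft_token_ids: tokens proposed by the draft model.
    target_argmax:   the target model's greedy token at each draft position.
    bonus_token_id:  token emitted only if every draft token is accepted.
    """
    num_draft = len(draft_token_ids)
    output = [PLACEHOLDER_TOKEN_ID] * (num_draft + 1)
    for pos, (draft, target) in enumerate(zip(draft_token_ids, target_argmax)):
        # The target model's token is always what gets written out.
        output[pos] = target
        if draft != target:
            # First mismatch: the draft is rejected and sampling stops here.
            return output
    # All draft tokens matched, so the bonus token fills the last slot.
    output[num_draft] = bonus_token_id
    return output


if __name__ == "__main__":
    print(greedy_rejection_sample([5, 7, 9], [5, 8, 2], bonus_token_id=11))
    print(greedy_rejection_sample([5, 7], [5, 7], bonus_token_id=11))
```

A test for the Triton kernel could run both implementations on the same inputs and compare the outputs row by row, assuming the kernel uses the same placeholder convention.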
from vllm.triton_utils import HAS_TRITON, triton
from vllm.triton_utils import triton.language as tl
This import statement contains a syntax error. The from ... import ... syntax does not allow dot notation like triton.language in the imported names. To fix this and follow the common pattern in vLLM for handling the optional Triton dependency, you should import tl directly from vllm.triton_utils and combine the imports into a single line for conciseness.
from vllm.triton_utils import HAS_TRITON, triton, tl
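The syntax point in this comment can be checked without installing vLLM or Triton, because `ast.parse` only parses source text and never executes the import. The helper name `is_valid_python` below is hypothetical. The check confirms only that the suggested line is syntactically valid, not that `vllm.triton_utils` actually exports `tl`.

```python
import ast


def is_valid_python(source: str) -> bool:
    """Return True if the source parses as Python, without executing it."""
    try:
        ast.parse(source)
        return True
    except SyntaxError:
        return False


# The import from the PR: a dotted name is not allowed after "import"
# in a "from ... import ..." statement.
broken = "from vllm.triton_utils import triton.language as tl"

# The reviewer's suggested single-line replacement.
fixed = "from vllm.triton_utils import HAS_TRITON, triton, tl"

if __name__ == "__main__":
    print("broken parses:", is_valid_python(broken))
    print("fixed parses:", is_valid_python(fixed))
```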
Signed-off-by: Lord_of_Ironhill <[email protected]>
Signed-off-by: whx-sjtu <[email protected]>
### What this PR does / why we need it?
Replace the PyTorch implementation of sampling with Triton kernels.

### Does this PR introduce _any_ user-facing change?
No

- vLLM version: v0.11.2

---------

Signed-off-by: Lord_of_Ironhill <[email protected]>
Signed-off-by: whx-sjtu <[email protected]>
Co-authored-by: Lord_of_Ironhill <[email protected]>
Co-authored-by: whx-sjtu <[email protected]>
### What this PR does / why we need it?
Replace the PyTorch implementation of sampling with Triton kernels.

### Does this PR introduce _any_ user-facing change?
No

- vLLM version: v0.11.2

---------

Signed-off-by: Lord_of_Ironhill <[email protected]>
Signed-off-by: whx-sjtu <[email protected]>
Co-authored-by: Lord_of_Ironhill <[email protected]>
Co-authored-by: whx-sjtu <[email protected]>
Signed-off-by: Che Ruan <[email protected]>
What this PR does / why we need it?
Replace the PyTorch implementation of sampling with Triton kernels.
Does this PR introduce any user-facing change?
No