
Conversation


@shiyuan680 shiyuan680 commented Nov 8, 2025

What this PR does / why we need it?

qwen3-next support triton chunk_gated_delta_rule ops

Does this PR introduce any user-facing change?

How was this patch tested?

TTFT reduced by more than half.

co-owners

@OsirisDuan


github-actions bot commented Nov 8, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by fulfilling the PR description to help reviewers and future developers understand.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds Triton kernel implementations for chunk_gated_delta_rule operations, seemingly for use with Huawei Ascend NPUs within the vLLM framework. The implementation is a substantial port from the fla library, introducing several new files for forward and backward passes. While the effort to optimize these operations is commendable, the current implementation has several critical issues that will prevent it from running correctly. My review has identified missing Python imports, undefined variables causing NameErrors, inconsistent and likely incorrect use of chunk_size, missing parameters in a kernel launch, and a leftover debugging statement. These issues must be addressed to ensure the correctness and performance of the new operations.

Comment on lines 3 to 14
import warnings

import torch

critical

The type hint Optional is used in this file (e.g., on line 188), but it is not imported from the typing module. This will cause a NameError at runtime.

Suggested change
 import warnings
+from typing import Optional
 
 import torch

output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)

critical

There is an inconsistency in the chunk_size used. Here, chunk_local_cumsum is called with chunk_size=64, but all subsequent chunked operations in both the forward and backward passes use chunk_size=16. Since chunk_local_cumsum performs a cumulative sum within chunks, this discrepancy will likely lead to incorrect calculations in later stages that expect data to be processed in chunks of 16. To ensure correctness, the chunk size should be consistent across all related operations.

Suggested change
-g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
+g = chunk_local_cumsum(g, chunk_size=16, cu_seqlens=cu_seqlens)
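One way to keep this consistent is to hoist the chunk size into a single module-level constant and pass it to every chunked op (a sketch; the constant name is hypothetical):

# Hypothetical module-level constant so the cumsum and all later
# chunked operations share the same block length.
CHUNK_SIZE = 16

g = chunk_local_cumsum(g, chunk_size=CHUNK_SIZE, cu_seqlens=cu_seqlens)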

g += bos * H + i_h
p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_o = b_o * exp(b_g)[:, None]

critical

The function exp is used within this Triton kernel, but it is not a standard tl function and has not been imported. This will result in a NameError. In other files within this PR, a custom exp function is imported from fla.ops.utils.op. The same import is needed here. Please add from fla.ops.utils.op import exp to the imports at the top of the file.
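A minimal sketch of the fix, mirroring the import used by the other files in this PR:

# At the top of the file, next to the existing triton imports:
from fla.ops.utils.op import exp  # Triton-compatible elementwise exp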

for num_stages in [2, 3, 4]
],
key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G', 'USE_G_GAMMA', 'USE_DW'],
**autotune_cache_kwargs,

critical

The variable autotune_cache_kwargs is used in the @triton.autotune decorator, but it is not defined anywhere in this file. This will cause a NameError during module loading. You should define it at the top of the file, for example, by adapting the definition from the fla library.
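A hedged sketch of such a definition (the exact gating used in fla may differ; the signature check below is one safe way to guard it):

import inspect

import triton

# Newer Triton releases let @triton.autotune cache tuning results via a
# `cache_results` keyword; older releases reject it, so only pass the
# keyword when the decorator's signature actually accepts it.
autotune_cache_kwargs = (
    {'cache_results': True}
    if 'cache_results' in inspect.signature(triton.autotune).parameters
    else {}
)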

Comment on lines 477 to 168
def chunk_fwd_o(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: torch.Tensor | None = None,
g_gamma: torch.Tensor | None = None,
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, H, K, V = *q.shape, v.shape[-1]
BT = chunk_size
chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
if scale is None:
scale = k.shape[-1] ** -0.5

o = torch.empty_like(v)
def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
chunk_fwd_kernel_o[grid](
q=q,
k=k,
v=v,
h=h,
g=g,
g_gamma=g_gamma,
o=o,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
scale=scale,
T=T,
H=H,
K=K,
V=V,
BT=BT,
)
return o

critical

The chunk_fwd_kernel_o kernel is not autotuned and requires BK and BV to be passed as constexpr arguments. However, these are missing from the kernel launch call, which will lead to a runtime error. You should define BK and BV and pass them to the kernel, similar to how it's done in other wrapper functions in this file.

def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: torch.Tensor | None = None,
    g_gamma: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
) -> torch.Tensor:
    B, T, H, K, V = *q.shape, v.shape[-1]
    BT = chunk_size
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    if scale is None:
        scale = k.shape[-1] ** -0.5

    if check_shared_mem('hopper', k.device.index):
        CONST_TILING = 128
    # Note: a bare `check_shared_mem` reference here would always be truthy;
    # it presumably should be called, e.g. with the 'ampere' tier.
    elif check_shared_mem('ampere', k.device.index):
        CONST_TILING = 64
    else:
        CONST_TILING = 32
    BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
    BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)

    o = torch.empty_like(v)
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H)
    chunk_fwd_kernel_o[grid](
        q=q,
        k=k,
        v=v,
        h=h,
        g=g,
        g_gamma=g_gamma,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV
    )
    return o
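For reference, a hypothetical smoke test of chunk_fwd_o under assumed shapes (B=1, T=128, H=8, K=V=128, so NT=2 chunks of BT=64); it presumes an NPU or other Triton-capable device:

# q/k/v follow the (B, T, H, D) layout the wrapper unpacks; h holds one
# (K, V) state per chunk, i.e. (B, NT, H, K, V).
q = torch.randn(1, 128, 8, 128, dtype=torch.bfloat16, device='npu')
k = torch.randn_like(q)
v = torch.randn_like(q)
h = torch.randn(1, 2, 8, 128, 128, dtype=torch.bfloat16, device='npu')
o = chunk_fwd_o(q, k, v, h, chunk_size=64)
assert o.shape == v.shape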

p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))

tl.debug_barrier()

high

A tl.debug_barrier() is present here. This is typically used for debugging and should be removed from production code as it forces synchronization and can negatively impact performance.

@shiyuan680 shiyuan680 force-pushed the triton branch 3 times, most recently from fe2a876 to 315ec77 on November 12, 2025 09:39
@github-actions

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@shiyuan680 shiyuan680 changed the title 【Draft】support triton chunk_gated_delta_rule ops 【OPS】qwen3-next support triton chunk_gated_delta_rule ops Nov 13, 2025
@shiyuan680 shiyuan680 force-pushed the triton branch 4 times, most recently from 352e435 to 7b4db5a on November 13, 2025 09:08
@shiyuan680 shiyuan680 force-pushed the triton branch 5 times, most recently from 48a8164 to abdaca1 on November 13, 2025 12:35
@MengqingCao MengqingCao added the ready (read for review) and ready-for-test (start test by label for PR) labels Nov 14, 2025
@shiyuan680 shiyuan680 force-pushed the triton branch 3 times, most recently from a4074c5 to a2c834e on November 18, 2025 11:04
assert last_recurrent_state.shape == (3, 8, 128, 128)


if __name__ == '__main__':
Collaborator

Remove these two lines.

Contributor Author

done

@pytest.fixture
def mock_moe_env():

    with patch("torch_npu.npu_moe_finalize_routing",
Collaborator

why patch it?

Contributor Author

done

import pytest
import torch

from tests.ut.base import PytestBase
Collaborator

Is this a UT or an e2e test? You created this file in the e2e module but import the UT base?

Contributor Author

This test uses triton_npu, and the UT environment does not install that package.

@@ -0,0 +1,51 @@
import unittest
Collaborator

If this is an e2e test, you should enable it in .github/workflows as well.

Contributor Author

This test uses triton_npu, and the UT environment does not install that package.

@@ -0,0 +1,226 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
Collaborator

So this file is copied from somewhere else? Where from? It would be better to add a link to the origin as well.

Contributor Author

Copied from the original vLLM file.

# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
Collaborator

Why are the ruff and mypy checks skipped?

Contributor Author

The Triton files produce check errors; I found that the other Triton files in the project also skip these checks.

@shiyuan680 shiyuan680 force-pushed the triton branch 2 times, most recently from 110ea91 to f569a9f on November 20, 2025 02:15
@shiyuan680 shiyuan680 force-pushed the triton branch 3 times, most recently from 20d201e to 2fee3b1 on November 21, 2025 03:46
@github-actions

This pull request has conflicts, please resolve those before we can evaluate the pull request.


run: |
  . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh
-  python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev20250914-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl"
+  python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl"
Collaborator

triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27.whl

@wangxiyuan wangxiyuan merged commit 1c4a046 into vllm-project:main Nov 28, 2025
21 of 22 checks passed
ChenCangtao pushed a commit to ChenCangtao/vllm-ascend that referenced this pull request Dec 3, 2025
…ct#4070)

### What this PR does / why we need it?
qwen3-next support triton chunk_gated_delta_rule ops

### co-owners
@OsirisDuan

- vLLM version: v0.11.2

Signed-off-by: shiyuan680 <[email protected]>
Mercykid-bash pushed a commit to Mercykid-bash/vllm-ascend that referenced this pull request Dec 4, 2025
…ct#4070)

### What this PR does / why we need it?
qwen3-next support triton chunk_gated_delta_rule ops

### co-owners
@OsirisDuan

- vLLM version: v0.11.2

Signed-off-by: shiyuan680 <[email protected]>
Signed-off-by: Che Ruan <[email protected]>

Labels

module:ops, module:tests, ready (read for review), ready-for-test (start test by label for PR)
