
Commit 842930e

Author: zxwang

addrmsnorm_bias
Signed-off-by: zxwang <[email protected]>
1 parent 136ea9f commit 842930e

File tree

1 file changed: +144 −14 lines


vllm_ascend/ops/layernorm.py

Lines changed: 144 additions & 14 deletions
@@ -15,14 +15,146 @@
 # This file is a part of the vllm-ascend project.
 #
 
-from typing import Optional, Tuple, Union, cast
+from typing import Optional, Tuple, Union, cast, Dict, Any
 
 import torch
 from vllm.config import get_current_vllm_config
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
+from vllm.triton_utils import tl, triton
+from functools import cache
 
 
+@cache
+def get_device_properties() -> Tuple[int, int]:
+    device = torch.npu.current_device()
+    device_properties: Dict[str, Any] = (
+        triton.runtime.driver.active.utils.get_device_properties(device)
+    )
+
+    num_aicore = device_properties.get("num_aicore", -1)
+    num_vectorcore = device_properties.get("num_vectorcore", -1)
+
+    assert num_aicore > 0 and num_vectorcore > 0, "Failed to detect device properties."
+    return num_aicore, num_vectorcore
+
+
+@triton.heuristics({
+    "HAS_BIAS": lambda args: args["B"] is not None
+})
+@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
+@triton.jit
+def rms_norm_fwd_kernel(
+    X,  # pointer to the input
+    Y,  # pointer to the output
+    W,  # pointer to the weights
+    B,  # pointer to the biases
+    Z,  # pointer to the residual
+    Z_Out,  # pointer to the residual output
+    stride_x_row,  # how much to increase the pointer when moving by 1 row
+    stride_y_row,
+    stride_z_row,
+    stride_z_out_row,
+    n_rows,  # number of rows in X
+    n_cols,  # number of columns in X
+    eps,  # epsilon to avoid division by zero
+    BLOCK_N: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    HAS_Z: tl.constexpr,
+):
+    # Map each program id to the rows of X and Y it should compute;
+    # programs stride over the rows so a fixed-size grid covers any batch.
+    row_start = tl.program_id(0)
+    for row_idx in tl.range(row_start, n_rows, tl.num_programs(0)):
+        start_x = X + row_idx * stride_x_row
+        start_y = Y + row_idx * stride_y_row
+        if HAS_Z:
+            start_z = Z + row_idx * stride_z_row
+            start_z_out = Z_Out + row_idx * stride_z_out_row
+        offsets = tl.arange(0, BLOCK_N)
+        mask = offsets < n_cols
+        x = tl.load(start_x + offsets, mask=mask, other=0.0)
+        original_dtype = x.dtype
+        x = x.to(tl.float32)
+        if HAS_Z:
+            # Fused residual add: accumulate in fp32 and write the
+            # pre-norm sum back so the caller can reuse it as the residual.
+            z = tl.load(start_z + offsets, mask=mask, other=0.0).to(tl.float32)
+            x = x + z
+            tl.store(start_z_out + offsets, x, mask=mask)
+        var = tl.sum(x * x, axis=0) / n_cols
+        rstd = 1 / tl.sqrt(var + eps)
+        w = tl.load(W + offsets, mask=mask).to(tl.float32)
+        if HAS_BIAS:
+            bias = tl.load(B + offsets, mask=mask).to(tl.float32)
+
+        x_hat = x * rstd
+        x_hat = x_hat.to(original_dtype)
+        y = x_hat * w
+        if HAS_BIAS:
+            y = y + bias
+        tl.store(start_y + offsets, y, mask=mask)
+
+
+def _rms_norm_fwd_triton(
+    x,
+    weight,
+    eps,
+    residual=None,
+    bias=None,
+    out=None,
+    residual_out=None,
+):
+    M, N = x.shape
+    assert x.stride(-1) == 1
+    assert weight.shape == (N,)
+    assert weight.stride(-1) == 1
+    if bias is not None:
+        assert bias.stride(-1) == 1
+        assert bias.shape == (N,)
+    if residual is not None:
+        assert residual.shape == x.shape
+        assert residual.stride(-1) == 1
+        if residual_out is None:
+            residual_out = torch.empty_like(x)
+    # allocate output
+    if out is not None:
+        assert out.shape == x.shape
+    else:
+        out = torch.empty_like(x)
+    assert out.stride(-1) == 1
+    # Less than 64KB per feature: enqueue fused kernel
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+    if N > BLOCK_N:
+        raise RuntimeError(
+            "This rms norm doesn't support feature dim >= 64KB.")
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK_N // 256, 1), 8)
+    # Cap the grid at the device's vector-core count; the kernel strides
+    # over rows, so a smaller grid still covers all M rows.
+    _, num_vectorcore = get_device_properties()
+    grid = (M if M < num_vectorcore else num_vectorcore,)
+    rms_norm_fwd_kernel[grid](
+        x,
+        out,
+        weight,
+        bias,
+        residual,
+        residual_out,
+        x.stride(0),
+        out.stride(0),
+        residual.stride(0) if residual is not None else None,
+        residual_out.stride(0) if residual_out is not None else None,
+        M,
+        N,
+        eps,
+        BLOCK_N=BLOCK_N,
+        num_warps=num_warps,
+    )
+    return out, residual_out
+
 def _addrmsnorm_forward_oot(
     self,
     x: torch.Tensor,
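
Note: the new kernel fuses the residual add, the RMS normalization, and the bias add into a single pass. Below is a minimal, unfused PyTorch reference of the intended semantics, handy for unit-testing the kernel; the helper name rms_norm_ref is an illustrative assumption, not part of this commit.

from typing import Optional

import torch


def rms_norm_ref(x: torch.Tensor,
                 weight: torch.Tensor,
                 eps: float,
                 residual: Optional[torch.Tensor] = None,
                 bias: Optional[torch.Tensor] = None):
    # Unfused reference for rms_norm_fwd_kernel: accumulate in fp32,
    # normalize, cast back to the input dtype, then scale and shift.
    x_f32 = x.float()
    residual_out = None
    if residual is not None:
        x_f32 = x_f32 + residual.float()
        # The pre-norm sum is handed back so callers can reuse it as
        # the residual stream, mirroring Z_Out in the kernel.
        residual_out = x_f32.to(x.dtype)
    var = x_f32.pow(2).mean(dim=-1, keepdim=True)
    x_hat = (x_f32 * torch.rsqrt(var + eps)).to(x.dtype)
    y = x_hat * weight
    if bias is not None:
        y = y + bias
    return y, residual_out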
@@ -32,10 +164,9 @@ def _addrmsnorm_forward_oot
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     import torch_npu
 
-    from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
+    from vllm_ascend.utils import is_310p
 
-    if layer is not None and get_ascend_device_type(
-    ) != AscendDeviceType._310P:
+    if layer is not None and not is_310p():
         layer_cls_name = layer.__class__.__name__
         try:
             weight_prefetch_method = get_forward_context(
@@ -68,17 +199,14 @@ def _addrmsnorm_forward_oot
         )
 
     else:
-        if get_ascend_device_type() == AscendDeviceType._310P:
+        if is_310p():
             orig_dtype = residual.dtype
             x = x + residual.to(x.dtype)
             residual = x.to(orig_dtype)
             x, _ = torch_npu.npu_rms_norm(x, self.weight,
                                           self.variance_epsilon)
         else:
-            x, _, residual = torch_npu.npu_add_rms_norm(
-                x, residual, self.weight, self.variance_epsilon)
-            if bias is not None:
-                x.add_(bias)
+            x, residual = _rms_norm_fwd_triton(x, self.weight, self.variance_epsilon, residual, bias)
     torch.ops.vllm.maybe_wait_prefetch_done(x)
     return x, residual
 
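
Note: the removed path ran npu_add_rms_norm and then a separate in-place bias add; the Triton path folds the bias into the same launch. A hedged usage sketch, assuming an Ascend build of Triton and the module above (shapes, dtypes, and the "npu" device string are illustrative):

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="npu")
residual = torch.randn_like(x)
weight = torch.ones(4096, dtype=torch.bfloat16, device="npu")
bias = torch.zeros(4096, dtype=torch.bfloat16, device="npu")

# y = rmsnorm(x + residual) * weight + bias; new_residual = x + residual.
y, new_residual = _rms_norm_fwd_triton(x, weight, 1e-6, residual, bias)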
@@ -115,10 +243,12 @@ def forward_oot(
                 self, x, residual, self.next_need_quant_fusion_linear,
                 self.bias)
             return x, residual
-        x, residual = torch_npu.npu_rms_norm(x, self.weight,
-                                             self.variance_epsilon)
+
         if self.bias is not None:
-            x.add_(self.bias)
+            x, _ = _rms_norm_fwd_triton(x, self.weight, self.variance_epsilon, None, self.bias)
+        else:
+            x, _ = torch_npu.npu_rms_norm(x, self.weight,
+                                          self.variance_epsilon)
         return x
 
     @property
@@ -196,9 +326,9 @@ def forward_oot(
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         import torch_npu
 
-        from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
+        from vllm_ascend.utils import is_310p
         if residual is not None:
-            if get_ascend_device_type() == AscendDeviceType._310P:
+            if is_310p():
                 orig_dtype = residual.dtype
                 x = x + residual.to(x.dtype)
                 residual = x.to(orig_dtype)
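
Note on grid sizing: _rms_norm_fwd_triton launches at most num_vectorcore programs (grid = (M if M < num_vectorcore else num_vectorcore,)), and the tl.range stride loop in the kernel covers the remaining rows; e.g. with 40 vector cores and M = 100, program 0 handles rows 0, 40, and 80. A hedged end-to-end check, reusing rms_norm_ref and the tensors from the sketches above (tolerances are illustrative):

y_ref, res_ref = rms_norm_ref(x, weight, 1e-6, residual, bias)
torch.testing.assert_close(y, y_ref, rtol=2e-2, atol=2e-2)
torch.testing.assert_close(new_residual, res_ref, rtol=2e-2, atol=2e-2)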
