# This file is a part of the vllm-ascend project.
#

- from typing import Optional, Tuple, Union, cast
+ from typing import Optional, Tuple, Union, cast, Dict, Any

import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
+ from vllm.triton_utils import tl, triton
+ from functools import cache


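+ # The helper below asks the Triton runtime driver for the Ascend device's core
+ # counts; the result is cached because it is fixed for the lifetime of the process.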
+ @cache
+ def get_device_properties() -> Tuple[int, int]:
+     device = torch.npu.current_device()
+     device_properties: Dict[str, Any] = (
+         triton.runtime.driver.active.utils.get_device_properties(device)
+     )
+
+     num_aicore = device_properties.get("num_aicore", -1)
+     num_vectorcore = device_properties.get("num_vectorcore", -1)
+
+     assert num_aicore > 0 and num_vectorcore > 0, "Failed to detect device properties."
+     return num_aicore, num_vectorcore
+
+
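+ # Fused RMSNorm forward kernel: each program normalizes rows of X in a
+ # grid-stride loop, optionally adding the residual Z first (the sum is written
+ # to Z_Out) and the bias B last. HAS_BIAS / HAS_Z are constexpr flags resolved
+ # by the heuristics below, so the unused branches are compiled out.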
+ @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+ @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
+ @triton.jit
+ def rms_norm_fwd_kernel(
+     X,  # pointer to the input
+     Y,  # pointer to the output
+     W,  # pointer to the weights
+     B,  # pointer to the biases
+     Z,  # pointer to the residual
+     Z_Out,  # pointer to the residual output
+     stride_x_row,  # how much to increase the pointer when moving by 1 row
+     stride_y_row,
+     stride_z_row,
+     stride_z_out_row,
+     n_rows,  # number of rows in X
+     n_cols,  # number of columns in X
+     eps,  # epsilon to avoid division by zero
+     BLOCK_N: tl.constexpr,
+     HAS_BIAS: tl.constexpr,
+     HAS_Z: tl.constexpr,
+ ):
+     # Map the program id to the rows of X and Y it should compute: each
+     # program walks over its share of rows in a grid-stride loop.
+     row_start = tl.program_id(0)
+     for row_idx in tl.range(row_start, n_rows, tl.num_programs(0)):
+         start_x = X + row_idx * stride_x_row
+         start_y = Y + row_idx * stride_y_row
+         if HAS_Z:
+             start_z = Z + row_idx * stride_z_row
+             start_z_out = Z_Out + row_idx * stride_z_out_row
+         offsets = tl.arange(0, BLOCK_N)
+         mask = offsets < n_cols
+         x = tl.load(start_x + offsets, mask=mask, other=0.0)
+         original_dtype = x.dtype
+         x = x.to(tl.float32)
+         if HAS_Z:
+             # Fused residual add: accumulate in fp32 and write the sum back out.
+             z = tl.load(start_z + offsets, mask=mask, other=0.0).to(tl.float32)
+             x = x + z
+             tl.store(start_z_out + offsets, x, mask=mask)
+         var = tl.sum(x * x, axis=0) / n_cols
+         rstd = 1 / tl.sqrt(var + eps)
+         w = tl.load(W + offsets, mask=mask).to(tl.float32)
+         if HAS_BIAS:
+             bias = tl.load(B + offsets, mask=mask).to(tl.float32)
+
+         x_hat = x * rstd
+         x_hat = x_hat.to(original_dtype)
+         y = x_hat * w
+         if HAS_BIAS:
+             y = y + bias
+         tl.store(start_y + offsets, y, mask=mask)
+
+
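+ # For reference, the kernel above computes, per row, the same thing as this
+ # eager sketch (illustrative only; `rms_norm_ref` is not part of this module):
+ #
+ #     def rms_norm_ref(x, w, eps, residual=None, bias=None):
+ #         x = x.float() + (residual.float() if residual is not None else 0.0)
+ #         y = (x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)).to(w.dtype) * w
+ #         return y + bias if bias is not None else y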
+ def _rms_norm_fwd_triton(
+     x,
+     weight,
+     eps,
+     residual=None,
+     bias=None,
+     out=None,
+     residual_out=None,
+ ):
+     M, N = x.shape
+     assert x.stride(-1) == 1
+     assert weight.shape == (N,)
+     assert weight.stride(-1) == 1
+     if bias is not None:
+         assert bias.stride(-1) == 1
+         assert bias.shape == (N,)
+     if residual is not None:
+         assert residual.shape == x.shape
+         assert residual.stride(-1) == 1
+         if residual_out is None:
+             residual_out = torch.empty_like(x)
+     # allocate output
+     if out is not None:
+         assert out.shape == x.shape
+     else:
+         out = torch.empty_like(x)
+     assert out.stride(-1) == 1
+     # Less than 64KB per feature: enqueue fused kernel
+     MAX_FUSED_SIZE = 65536 // x.element_size()
+     BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+     if N > BLOCK_N:
+         raise RuntimeError(
+             "This rms norm doesn't support feature dim >= 64KB.")
+     # heuristics for number of warps
+     num_warps = min(max(BLOCK_N // 256, 1), 8)
+     _, num_vectorcore = get_device_properties()
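+     # Launch at most one program per vector core; each program then loops
+     # over its share of the rows inside the kernel.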
+     grid = (M if M < num_vectorcore else num_vectorcore,)
+     # with torch.npu.device(x.device.index):
+     rms_norm_fwd_kernel[grid](
+         x,
+         out,
+         weight,
+         bias,
+         residual,
+         residual_out,
+         x.stride(0),
+         out.stride(0),
+         residual.stride(0) if residual is not None else None,
+         residual_out.stride(0) if residual is not None else None,
+         M,
+         N,
+         eps,
+         BLOCK_N=BLOCK_N,
+         num_warps=num_warps,
+         # multibuffer=True,
+     )
+     return out, residual_out
+
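+ # Minimal usage sketch (assumes an NPU device and contiguous (num_tokens,
+ # hidden_size) tensors; values are illustrative):
+ #
+ #     hidden = torch.randn(16, 4096, dtype=torch.bfloat16, device="npu")
+ #     weight = torch.ones(4096, dtype=torch.bfloat16, device="npu")
+ #     out, _ = _rms_norm_fwd_triton(hidden, weight, eps=1e-6)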
def _addrmsnorm_forward_oot(
    self,
    x: torch.Tensor,
@@ -32,10 +164,9 @@ def _addrmsnorm_forward_oot(
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    import torch_npu

-     from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
+     from vllm_ascend.utils import is_310p

-     if layer is not None and get_ascend_device_type(
-     ) != AscendDeviceType._310P:
+     if layer is not None and not is_310p():
        layer_cls_name = layer.__class__.__name__
        try:
            weight_prefetch_method = get_forward_context(
@@ -68,17 +199,14 @@ def _addrmsnorm_forward_oot(
            )
    else:
-         if get_ascend_device_type() == AscendDeviceType._310P:
+         if is_310p():
            orig_dtype = residual.dtype
            x = x + residual.to(x.dtype)
            residual = x.to(orig_dtype)
            x, _ = torch_npu.npu_rms_norm(x, self.weight,
                                          self.variance_epsilon)
        else:
-             x, _, residual = torch_npu.npu_add_rms_norm(
-                 x, residual, self.weight, self.variance_epsilon)
-             if bias is not None:
-                 x.add_(bias)
+             x, residual = _rms_norm_fwd_triton(x, self.weight,
+                                                self.variance_epsilon,
+                                                residual, bias)
        torch.ops.vllm.maybe_wait_prefetch_done(x)
    return x, residual

@@ -115,10 +243,12 @@ def forward_oot(
                self, x, residual, self.next_need_quant_fusion_linear,
                self.bias)
            return x, residual
-         x, residual = torch_npu.npu_rms_norm(x, self.weight,
-                                              self.variance_epsilon)
+
        if self.bias is not None:
-             x.add_(self.bias)
+             x, _ = _rms_norm_fwd_triton(x, self.weight,
+                                         self.variance_epsilon, None,
+                                         self.bias)
+         else:
+             x, _ = torch_npu.npu_rms_norm(x, self.weight,
+                                           self.variance_epsilon)
        return x

    @property
@@ -196,9 +326,9 @@ def forward_oot(
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        import torch_npu

-         from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
+         from vllm_ascend.utils import is_310p
        if residual is not None:
-             if get_ascend_device_type() == AscendDeviceType._310P:
+             if is_310p():
                orig_dtype = residual.dtype
                x = x + residual.to(x.dtype)
                residual = x.to(orig_dtype)