@@ -4,6 +4,7 @@
 import torch
 import torch.distributed as dist

+from vllm import _custom_ops as custom_ops
 from flash_attn import flash_attn_varlen_func
 from vllm.attention.ops.prefix_prefill import context_attention_fwd

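The whole commit is this one-line import plus the mechanical rename in the hunks below; roughly, the before/after pattern at a call site looks like this (illustrative sketch, not part of the diff):

from vllm import _custom_ops as custom_ops

# Before: call sites spelled out the full module path each time.
#     vllm._custom_ops.rms_norm(output, hidden_states, weight, epsilon)
# After: the alias is bound once at import time and call sites shorten to:
#     custom_ops.rms_norm(output, hidden_states, weight, epsilon)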
@@ -59,7 +60,7 @@ def add_rms_norm(
     weight: Tensor,
     epsilon: float,
 ) -> Tuple[Tensor, Tensor]:
-    vllm._custom_ops.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
+    custom_ops.fused_add_rms_norm(hidden_states, residual, weight, epsilon)
     return hidden_states, residual


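For reviewers unfamiliar with the op: a minimal pure-PyTorch sketch of what fused_add_rms_norm computes, following vLLM's RMSNorm semantics (residual add folded into the norm); fused_add_rms_norm_ref is a hypothetical name for illustration only.

import torch

def fused_add_rms_norm_ref(hidden_states, residual, weight, epsilon):
    # Fold the residual connection into the norm: update the residual
    # stream first, then RMS-normalize it (fp32 math for stability).
    residual = residual + hidden_states
    variance = residual.float().pow(2).mean(-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(variance + epsilon)
    return (normed * weight.float()).to(weight.dtype), residual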
@@ -188,7 +189,7 @@ def fill_kv_cache(
     quant_bits: int,
 ) -> Tuple[Tensor, Tensor]:
     kv_indices = kv_indices.squeeze(-1)
-    vllm._custom_ops.reshape_and_cache_new(
+    custom_ops.reshape_and_cache_new(
         key, value, key_cache, value_cache, kv_indices, "auto", 1.0, 1.0
     )
     return key_cache, value_cache
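reshape_and_cache_new looks like a fork-local variant of vLLM's reshape_and_cache; assuming the usual semantics, it scatters the new tokens' K/V into the paged caches at the given slot indices. A simplified reference, ignoring the real kernel's blocked cache layout and the quantization arguments:

import torch

def reshape_and_cache_ref(key, value, key_cache, value_cache, kv_indices):
    # key/value: [num_tokens, num_heads, head_size]
    # *_cache: flattened here to [num_slots, num_heads, head_size];
    # the real caches are additionally tiled into fixed-size blocks.
    key_cache[kv_indices] = key
    value_cache[kv_indices] = value
    return key_cache, value_cache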
@@ -220,7 +221,7 @@ def paged_decode_attention(
     num_kv_heads = value_cache.size(1)
     block_size = value_cache.size(2)
     output = torch.empty_like(query)
-    vllm._custom_ops.paged_attention_v1(
+    custom_ops.paged_attention_v1(
         output,
         query,
         key_cache,
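As a sanity reference for the hunk above: a naive single-sequence sketch of what paged_attention_v1 computes at decode time, under assumed simplified layouts (the real kernel's cache layout and argument list differ, and all names here are illustrative):

import torch

def paged_decode_attention_ref(query, key_cache, value_cache, block_table,
                               seq_len, scale):
    # Assumed layouts (sketch only):
    #   query:       [num_heads, head_size]  -- one new token
    #   key_cache:   [num_blocks, block_size, num_heads, head_size]
    #   value_cache: same layout as key_cache
    #   block_table: int tensor of block ids for this sequence
    block_size = key_cache.size(1)
    num_blocks = (seq_len + block_size - 1) // block_size
    # Gather this sequence's K/V from its cache blocks, then do
    # one-token attention against the full prefix.
    k = key_cache[block_table[:num_blocks]].flatten(0, 1)[:seq_len]
    v = value_cache[block_table[:num_blocks]].flatten(0, 1)[:seq_len]
    scores = torch.einsum("hd,shd->hs", query.float(), k.float()) * scale
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("hs,shd->hd", probs, v.float()).to(query.dtype)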
@@ -304,7 +305,7 @@ def rms_norm(
     epsilon: float,
 ) -> Tensor:
     output = torch.empty_like(hidden_states)
-    vllm._custom_ops.rms_norm(output, hidden_states, weight, epsilon)
+    custom_ops.rms_norm(output, hidden_states, weight, epsilon)
     return output


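This is the same normalization as in fused_add_rms_norm above, just without the residual fold; the kernel writes into the preallocated output tensor. A sketch:

import torch

def rms_norm_ref(hidden_states, weight, epsilon):
    # out = x / sqrt(mean(x^2) + eps) * weight, accumulated in fp32
    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
    normed = hidden_states.float() * torch.rsqrt(variance + epsilon)
    return (normed * weight.float()).to(hidden_states.dtype)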
@@ -322,7 +323,7 @@ def moe_gating_topk_softmax(

     token_expert_indicies = torch.empty_like(topk_ids)

-    vllm._custom_ops.topk_softmax(
+    custom_ops.topk_softmax(
         topk_weights,
         topk_ids,
         token_expert_indicies,
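topk_softmax fuses MoE router gating: softmax over the expert logits, then per-token top-k selection. A hedged reference (whether the kernel renormalizes the selected weights is version-specific; this sketch returns the raw softmax probabilities):

import torch

def moe_gating_topk_softmax_ref(gating_logits, topk):
    # gating_logits: [num_tokens, num_experts]
    probs = torch.softmax(gating_logits.float(), dim=-1)
    topk_weights, topk_ids = probs.topk(topk, dim=-1)
    return topk_weights, topk_ids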
@@ -344,7 +345,7 @@ def silu_and_mul(x: Tensor, dim: int = -1) -> Tensor:
     d = x.shape[-1] // 2
     output_shape = x.shape[:-1] + (d,)
     out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-    vllm._custom_ops.silu_and_mul(out, x)
+    custom_ops.silu_and_mul(out, x)
     return out


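silu_and_mul is the standard gated-activation kernel: split the last dimension in half, apply SiLU to the first half, and multiply elementwise by the second. Equivalent PyTorch:

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x):
    # x: [..., 2 * d]  ->  [..., d]
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]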
@@ -398,7 +399,7 @@ def weight_quant_matmul(
     group_size: Optional[int] = 0,
 ):
     offset = None if (offset is None or offset.numel() == 0) else offset
-    output = vllm._custom_ops.awq_gemm(x, qweight, scale, offset, group_size)
+    output = custom_ops.awq_gemm(x, qweight, scale, offset, group_size)
     if bias is not None:
         output += bias
     return output
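awq_gemm fuses per-group dequantization of the packed weights with the GEMM; note that the argument order here (scale, offset, group_size) differs from upstream vLLM's awq_gemm signature, so treat the sketch below as conceptual only. Packing details of the real 4-bit layout are omitted, and all names are illustrative:

import torch

def awq_gemm_ref(x, qweight, scale, offset, group_size):
    # qweight: [in_features, out_features] already-unpacked int weights
    # scale/offset: [in_features // group_size, out_features] per-group params
    g = torch.arange(qweight.size(0), device=qweight.device) // group_size
    w = (qweight.float() - offset.float()[g]) * scale.float()[g]
    return x @ w.to(x.dtype)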