@@ -133,8 +133,8 @@ def npu_rms_norm_w8a8(self, x, w, eps=1e-6, quant_dtype=torch.int8):
133133 )
134134 return rms_norm_w8a8
135135
136- @register_conversion ("torch.ops.lmdeploy .apply_rotary_pos_emb.default" )
137- def apply_rotary_pos_emb (self , q , k , cos , sin , q_out , k_out ):
136+ @register_conversion ("torch.ops.dlinfer .apply_rotary_pos_emb.default" )
137+ def apply_rotary_pos_emb (self , q , k , cos , sin ):
138138 q_shape = list (q .node .meta ["val" ].shape )
139139 k_shape = list (k .node .meta ["val" ].shape )
140140 is_qk_require_reshape = len (q_shape ) == 3
@@ -151,22 +151,6 @@ def apply_rotary_pos_emb(self, q, k, cos, sin, q_out, k_out):
151151 else self .get_proxy (atb_op .View , (k , (- 1 , k_shape [1 ] * k_shape [2 ])))
152152 )
153153 out = self .get_proxy (atb_op .Rope , (new_q , new_k , cos , sin , None ))
154- if is_qk_require_reshape :
155- out_q = self .get_proxy (atb_op .GetItem , (out , 0 ))
156- out_q = self .get_proxy (atb_op .View , (out_q , (- 1 , q_shape [1 ], q_shape [2 ])))
157- out_k = self .get_proxy (atb_op .GetItem , (out , 1 ))
158- out_k = self .get_proxy (atb_op .View , (out_k , (- 1 , k_shape [1 ], k_shape [2 ])))
159- out = self .get_proxy (atb_op .Tuple , (out_q , out_k ))
160- if (q_out is not None ) and (k_out is not None ):
161- self .get_proxy (
162- atb_op .AclNnInplaceCopy ,
163- (q_out , self .get_proxy (atb_op .GetItem , (out , 0 ))),
164- )
165- self .get_proxy (
166- atb_op .AclNnInplaceCopy ,
167- (k_out , self .get_proxy (atb_op .GetItem , (out , 1 ))),
168- )
169- out = self .get_proxy (atb_op .Tuple , (q_out , k_out ))
170154 return out
171155
172156 @register_conversion ("torch.ops.atb.inplace_div.default" )
0 commit comments