@@ -55,7 +55,8 @@ class XPUSingleEncoderFuser : public FuseBase {
5555 bool norm_before = false ,
5656 const std::string& relative_type = " " ,
5757 bool with_mask = true ,
58- bool smooth_quant = false )
58+ bool smooth_quant = false ,
59+ bool with_token_slice = false )
5960 : act_type_(act_type),
6061 input_pos_(input_pos),
6162 qkv_ln_2_out_pos_(qkv_ln_2_out_pos),
@@ -66,7 +67,8 @@ class XPUSingleEncoderFuser : public FuseBase {
6667 norm_before_(norm_before),
6768 relative_emb_type_(relative_type),
6869 with_mask_(with_mask),
69- smooth_quant_(smooth_quant) {}
70+ smooth_quant_(smooth_quant),
71+ with_token_slice_(with_token_slice) {}
7072
7173 void BuildPattern () override {
7274 PMNode* input = nullptr ;
@@ -79,11 +81,39 @@ class XPUSingleEncoderFuser : public FuseBase {
7981 PMNode* ln_before_out = nullptr ;
8082 PMNode* ln_before_mean = nullptr ;
8183 PMNode* ln_before_var = nullptr ;
84+ PMNode* input_slice = nullptr ;
85+ PMNode* input_slice_out = nullptr ;
8286 if (smooth_quant_ && !norm_before_) {
8387 VLOG (3 ) << " build first smooth_quant_scale" ;
84- input = VarNode (" input" )
85- ->assert_is_op_input (" elementwise_mul" , " X" )
86- ->AsInput ();
88+ if (with_token_slice_) {
89+ VLOG (3 ) << " build input_slice" ;
90+ input =
91+ VarNode (" input" )->assert_is_op_input (" slice" , " Input" )->AsInput ();
92+ input_slice = OpNode (" input_slice" , " slice" )
93+ ->assert_op_attr_satisfied <std::vector<int >>(
94+ " axes" ,
95+ [](const std::vector<int >& attr) {
96+ return attr.size () == 1 && attr[0 ] == 1 ;
97+ })
98+ ->assert_op_attr_satisfied <std::vector<int >>(
99+ " starts" ,
100+ [](const std::vector<int >& attr) {
101+ return attr.size () == 1 ;
102+ })
103+ ->assert_op_attr_satisfied <std::vector<int >>(
104+ " ends" ,
105+ [](const std::vector<int >& attr) {
106+ return attr.size () == 1 ;
107+ })
108+ ->AsIntermediate ();
109+ input_slice_out = VarNode (" input_slice_out" )
110+ ->assert_is_op_output (" slice" , " Out" )
111+ ->assert_is_op_input (" elementwise_mul" , " X" );
112+ } else {
113+ input = VarNode (" input" )
114+ ->assert_is_op_input (" elementwise_mul" , " X" )
115+ ->AsInput ();
116+ }
87117 smooth_scale_1_weight = VarNode (" smooth_scale_1_weight" )
88118 ->assert_is_op_input (" elementwise_mul" , " Y" )
89119 ->AsInput ();
@@ -92,6 +122,26 @@ class XPUSingleEncoderFuser : public FuseBase {
92122 smooth_scale_1_out = VarNode (" smooth_scale_1_out" )
93123 ->assert_is_op_output (" elementwise_mul" , " Out" )
94124 ->AsIntermediate ();
125+ } else if (with_token_slice_) {
126+ VLOG (3 ) << " build input_slice" ;
127+ input = VarNode (" input" )->assert_is_op_input (" slice" , " Input" )->AsInput ();
128+ input_slice =
129+ OpNode (" input_slice" , " slice" )
130+ ->assert_op_attr_satisfied <std::vector<int >>(
131+ " axes" ,
132+ [](const std::vector<int >& attr) {
133+ return attr.size () == 1 && attr[0 ] == 1 ;
134+ })
135+ ->assert_op_attr_satisfied <std::vector<int >>(
136+ " starts" ,
137+ [](const std::vector<int >& attr) { return attr.size () == 1 ; })
138+ ->assert_op_attr_satisfied <std::vector<int >>(
139+ " ends" ,
140+ [](const std::vector<int >& attr) { return attr.size () == 1 ; })
141+ ->AsIntermediate ();
142+ input_slice_out = VarNode (" input_slice_out" )
143+ ->assert_is_op_output (" slice" , " Out" )
144+ ->assert_is_op_input (" elementwise_add" , input_pos_);
95145 } else {
96146 input = VarNode (" input" )
97147 ->assert_is_op_input (" elementwise_add" , input_pos_)
@@ -311,11 +361,40 @@ class XPUSingleEncoderFuser : public FuseBase {
311361 VarNode (" qkv_transpose2_xshape" )
312362 ->assert_is_op_output (" transpose2" , " XShape" )
313363 ->AsIntermediate ();
364+ PMNode* qkv_slice = nullptr ;
365+ PMNode* qkv_slice_out = nullptr ;
366+ PMNode* qkv_reshape2_out = nullptr ;
314367 auto * qkv_reshape2 = OpNode (" qkv_reshape2" , " reshape2" )->AsIntermediate ();
315- auto * qkv_reshape2_out = VarNode (" qkv_reshape2_out" )
316- ->assert_is_op_output (" reshape2" , " Out" )
317- ->assert_is_op_input (mul_type_, " X" )
318- ->AsIntermediate ();
368+ if (with_token_slice_) {
369+ qkv_reshape2_out = VarNode (" qkv_reshape2_out" )
370+ ->assert_is_op_output (" reshape2" , " Out" )
371+ ->assert_is_op_input (" slice" , " Input" )
372+ ->AsIntermediate ();
373+ VLOG (3 ) << " build qkv_slice" ;
374+ qkv_slice =
375+ OpNode (" qkv_slice" , " slice" )
376+ ->assert_op_attr_satisfied <std::vector<int >>(
377+ " axes" ,
378+ [](const std::vector<int >& attr) {
379+ return attr.size () == 1 && attr[0 ] == 1 ;
380+ })
381+ ->assert_op_attr_satisfied <std::vector<int >>(
382+ " starts" ,
383+ [](const std::vector<int >& attr) { return attr.size () == 1 ; })
384+ ->assert_op_attr_satisfied <std::vector<int >>(
385+ " ends" ,
386+ [](const std::vector<int >& attr) { return attr.size () == 1 ; })
387+ ->AsIntermediate ();
388+ qkv_slice_out = VarNode (" qkv_slice_out" )
389+ ->assert_is_op_output (" slice" , " Out" )
390+ ->assert_is_op_input (mul_type_, " X" )
391+ ->AsIntermediate ();
392+ } else {
393+ qkv_reshape2_out = VarNode (" qkv_reshape2_out" )
394+ ->assert_is_op_output (" reshape2" , " Out" )
395+ ->assert_is_op_input (mul_type_, " X" )
396+ ->AsIntermediate ();
397+ }
319398 auto * qkv_reshape2_xshape = VarNode (" qkv_reshape2_xshape" )
320399 ->assert_is_op_output (" reshape2" , " XShape" )
321400 ->AsIntermediate ();
@@ -531,18 +610,33 @@ class XPUSingleEncoderFuser : public FuseBase {
531610 *v_transpose2 >> *v_transpose2_xshape;
532611
533612 *qkv_matmul >> *qkv_matmul_out >> *qkv_transpose2 >> *qkv_transpose2_out >>
534- *qkv_reshape2 >> *qkv_reshape2_out >> *qkv_mul >> *qkv_mul_out >>
535- *qkv_add >> *qkv_add_out >> *qkv_add_2;
613+ *qkv_reshape2 >> *qkv_reshape2_out;
614+ if (with_token_slice_) {
615+ *qkv_reshape2_out >> *qkv_slice >> *qkv_slice_out >> *qkv_mul;
616+ } else {
617+ *qkv_reshape2_out >> *qkv_mul;
618+ }
619+ *qkv_mul >> *qkv_mul_out >> *qkv_add >> *qkv_add_out >> *qkv_add_2;
536620 *qkv_transpose2 >> *qkv_transpose2_xshape;
537621 *qkv_reshape2 >> *qkv_reshape2_xshape;
538622 *qkv_mul_y >> *qkv_mul;
539623 *qkv_add_y >> *qkv_add;
540624 if (smooth_quant_ && !norm_before_) {
541625 *smooth_scale_1_weight >> *smooth_scale_1;
542- *input >> *smooth_scale_1 >> *smooth_scale_1_out >> *qkv_add_2 >>
543- *qkv_add_2_out >> *qkv_ln_2 >> *qkv_ln_2_out;
626+ if (with_token_slice_) {
627+ *input >> *input_slice >> *input_slice_out >> *smooth_scale_1;
628+ } else {
629+ *input >> *smooth_scale_1;
630+ }
631+ *smooth_scale_1 >> *smooth_scale_1_out >> *qkv_add_2 >> *qkv_add_2_out >>
632+ *qkv_ln_2 >> *qkv_ln_2_out;
544633 } else {
545- *input >> *qkv_add_2 >> *qkv_add_2_out >> *qkv_ln_2 >> *qkv_ln_2_out;
634+ if (with_token_slice_) {
635+ *input >> *input_slice >> *input_slice_out >> *qkv_add_2;
636+ } else {
637+ *input >> *qkv_add_2;
638+ }
639+ *qkv_add_2 >> *qkv_add_2_out >> *qkv_ln_2 >> *qkv_ln_2_out;
546640 }
547641 *qkv_ln_2_scale >> *qkv_ln_2;
548642 *qkv_ln_2_bias >> *qkv_ln_2;
@@ -619,6 +713,8 @@ class XPUSingleEncoderFuser : public FuseBase {
619713 // the model is smooth_quant or not, we don't need to do anything
620714 // so, set is_smooth_quant as false.
621715 op_desc.SetAttr <bool >(" is_smooth_quant" , false );
716+ // Token slice is not yet supported in the pre-layernorm case, so it is disabled here.
717+ op_desc.SetAttr <bool >(" with_token_slice" , false );
622718 } else {
623719 op_desc.SetInput (" LNScale" ,
624720 {
@@ -643,6 +739,33 @@ class XPUSingleEncoderFuser : public FuseBase {
643739 } else {
644740 op_desc.SetAttr <bool >(" is_smooth_quant" , false );
645741 }
742+ if (with_token_slice_) {
743+ op_desc.SetAttr <bool >(" with_token_slice" , true );
744+ int token_sliced_length = -1 ;
745+ auto * qkv_slice_op_info = matched.at (" qkv_slice" )->stmt ()->op_info ();
746+ auto * input_slice_op_info =
747+ matched.at (" input_slice" )->stmt ()->op_info ();
748+ if (qkv_slice_op_info->HasAttr (" starts" ) &&
749+ qkv_slice_op_info->HasAttr (" ends" ) &&
750+ input_slice_op_info->HasAttr (" starts" ) &&
751+ input_slice_op_info->HasAttr (" ends" )) {
752+ auto qkv_slice_starts =
753+ qkv_slice_op_info->GetAttr <std::vector<int >>(" starts" );
754+ auto qkv_slice_ends =
755+ qkv_slice_op_info->GetAttr <std::vector<int >>(" ends" );
756+ auto input_slice_starts =
757+ input_slice_op_info->GetAttr <std::vector<int >>(" starts" );
758+ auto input_slice_ends =
759+ input_slice_op_info->GetAttr <std::vector<int >>(" ends" );
760+ CHECK_EQ (qkv_slice_starts.size (), input_slice_starts.size ());
761+ CHECK_EQ (qkv_slice_ends.size (), input_slice_ends.size ());
762+ CHECK_EQ (qkv_slice_starts[0 ], input_slice_starts[0 ]);
763+ CHECK_EQ (qkv_slice_ends[0 ], input_slice_ends[0 ]);
764+ token_sliced_length = qkv_slice_ends[0 ] - qkv_slice_starts[0 ];
765+ CHECK_GT (token_sliced_length, 0 );
766+ op_desc.SetAttr <int >(" token_sliced_length" , token_sliced_length);
767+ }
768+ }
646769 }
647770 // XXX: keep these to fool SubgraphOp::AttachImpl()
648771 op_desc.SetAttr <int >(" sub_block" , 0 );
@@ -792,6 +915,7 @@ class XPUSingleEncoderFuser : public FuseBase {
792915 const std::string relative_emb_type_;
793916 bool with_mask_;
794917 bool smooth_quant_;
918+ bool with_token_slice_;
795919 // quant_info: mul input_max, output_max * 6 + matmul x_max:y_max,
796920 // output_max
797921 // * 2
@@ -1845,6 +1969,8 @@ class XPUMultiEncoderFuser {
18451969 std::vector<float > fc_input_max;
18461970 std::vector<float > softmax_max;
18471971 std::vector<std::string> quant_types;
1972+ bool has_token_sliced_layer = false ;
1973+ std::vector<int > token_sliced_length (all_encoders.size (), -1 );
18481974
18491975 for (size_t i = 0 ; i < all_encoders.size (); ++i) {
18501976 Node* cur_encoder = all_encoders[i];
@@ -1889,6 +2015,12 @@ class XPUMultiEncoderFuser {
18892015 }
18902016 }
18912017
2018+ if (op_info->HasAttr (" with_token_slice" ) &&
2019+ op_info->HasAttr (" token_sliced_length" )) {
2020+ has_token_sliced_layer = true ;
2021+ token_sliced_length[i] = op_info->GetAttr <int >(" token_sliced_length" );
2022+ }
2023+
18922024 auto * cur_out =
18932025 graph->RetrieveArgument (op_info->Output (" Outputs" ).front ());
18942026 if (all_encoders.size () == 1 ) {
@@ -1930,6 +2062,9 @@ class XPUMultiEncoderFuser {
19302062 }
19312063 op_desc.SetOutput (" Output" , {out_name});
19322064 op_desc.SetAttr <int >(" xpu" , 1 );
2065+ op_desc.SetAttr <std::vector<int >>(" token_sliced_length" ,
2066+ token_sliced_length);
2067+ op_desc.SetAttr <bool >(" has_token_sliced_layer" , has_token_sliced_layer);
19332068 op_desc.SetAttr <bool >(
19342069 " is_smooth_quant" ,
19352070 first_encoder_op_info->GetAttr <bool >(" is_smooth_quant" ));
@@ -2325,6 +2460,7 @@ class XPUMultiEncoderFusePass : public ProgramPass {
23252460 std::vector<std::string> relative_embedding_type{
23262461 " " , " __xpu__roformer_relative_embedding" };
23272462 std::vector<bool > with_smooth_quant{true , false };
2463+ std::vector<bool > with_token_slice{true , false };
23282464
23292465 std::string fc_precision;
23302466 bool adaptive_seqlen = false ;
@@ -2384,19 +2520,25 @@ class XPUMultiEncoderFusePass : public ProgramPass {
23842520 // so remove one
23852521 continue ;
23862522 }
2387- fusion::XPUSingleEncoderFuser single_encoder_fuser (
2388- act_type,
2389- input_pos,
2390- qkv_ln_2_out_pos,
2391- matmul_type,
2392- matmul2_type,
2393- mul_type,
2394- with_q_scale,
2395- norm_before,
2396- relative_type,
2397- mask,
2398- smooth_quant);
2399- single_encoder_fuser (graph.get ());
2523+ for (auto token_slice : with_token_slice) {
2524+ fusion::XPUSingleEncoderFuser single_encoder_fuser (
2525+ act_type,
2526+ input_pos,
2527+ qkv_ln_2_out_pos,
2528+ matmul_type,
2529+ matmul2_type,
2530+ mul_type,
2531+ with_q_scale,
2532+ norm_before,
2533+ relative_type,
2534+ mask,
2535+ smooth_quant,
2536+ token_slice);
2537+ single_encoder_fuser (graph.get ());
2538+ }
2539+ // Both single-encoder variants (with and without token_slice)
2540+ // must be fused before the multi-encoder fuser runs, so it is
2541+ // invoked only after the loop above has finished.
24002542 fusion::XPUMultiEncoderFuser multi_encoder_fuser (
24012543 fc_precision, adaptive_seqlen);
24022544 multi_encoder_fuser (graph.get ());
0 commit comments