Commit fc568d6

[XPU] support token-slice for encoder
1 parent 3c61295 commit fc568d6

File tree

7 files changed: +358 -28 lines


lite/api/paddle_use_passes.h

Lines changed: 1 addition & 0 deletions
@@ -99,6 +99,7 @@ USE_MIR_PASS(__xpu__spatial_transformer_resblock_fuse_pass);
 USE_MIR_PASS(__xpu__matmul_scale_softmax_v1_fuse_pass);
 USE_MIR_PASS(__xpu__up_decoder_fuse_pass);
 USE_MIR_PASS(__xpu__multi_up_decoder_fuse_pass);
+USE_MIR_PASS(__xpu__remove_mask_slice_pass);
 USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_fuse_pass);
 USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_v2_fuse_pass);
 USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_v3_fuse_pass);

lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc

Lines changed: 169 additions & 27 deletions
@@ -55,7 +55,8 @@ class XPUSingleEncoderFuser : public FuseBase {
                          bool norm_before = false,
                          const std::string& relative_type = "",
                          bool with_mask = true,
-                         bool smooth_quant = false)
+                         bool smooth_quant = false,
+                         bool with_token_slice = false)
       : act_type_(act_type),
         input_pos_(input_pos),
         qkv_ln_2_out_pos_(qkv_ln_2_out_pos),
@@ -66,7 +67,8 @@ class XPUSingleEncoderFuser : public FuseBase {
         norm_before_(norm_before),
         relative_emb_type_(relative_type),
         with_mask_(with_mask),
-        smooth_quant_(smooth_quant) {}
+        smooth_quant_(smooth_quant),
+        with_token_slice_(with_token_slice) {}
 
   void BuildPattern() override {
     PMNode* input = nullptr;
@@ -79,11 +81,39 @@ class XPUSingleEncoderFuser : public FuseBase {
     PMNode* ln_before_out = nullptr;
     PMNode* ln_before_mean = nullptr;
     PMNode* ln_before_var = nullptr;
+    PMNode* input_slice = nullptr;
+    PMNode* input_slice_out = nullptr;
     if (smooth_quant_ && !norm_before_) {
       VLOG(3) << "build first smooth_quant_scale";
-      input = VarNode("input")
-                  ->assert_is_op_input("elementwise_mul", "X")
-                  ->AsInput();
+      if (with_token_slice_) {
+        VLOG(3) << "build input_slice";
+        input =
+            VarNode("input")->assert_is_op_input("slice", "Input")->AsInput();
+        input_slice = OpNode("input_slice", "slice")
+                          ->assert_op_attr_satisfied<std::vector<int>>(
+                              "axes",
+                              [](const std::vector<int>& attr) {
+                                return attr.size() == 1 && attr[0] == 1;
+                              })
+                          ->assert_op_attr_satisfied<std::vector<int>>(
+                              "starts",
+                              [](const std::vector<int>& attr) {
+                                return attr.size() == 1;
+                              })
+                          ->assert_op_attr_satisfied<std::vector<int>>(
+                              "ends",
+                              [](const std::vector<int>& attr) {
+                                return attr.size() == 1;
+                              })
+                          ->AsIntermediate();
+        input_slice_out = VarNode("input_slice_out")
+                              ->assert_is_op_output("slice", "Out")
+                              ->assert_is_op_input("elementwise_mul", "X");
+      } else {
+        input = VarNode("input")
+                    ->assert_is_op_input("elementwise_mul", "X")
+                    ->AsInput();
+      }
       smooth_scale_1_weight = VarNode("smooth_scale_1_weight")
                                   ->assert_is_op_input("elementwise_mul", "Y")
                                   ->AsInput();
@@ -92,6 +122,26 @@ class XPUSingleEncoderFuser : public FuseBase {
       smooth_scale_1_out = VarNode("smooth_scale_1_out")
                                ->assert_is_op_output("elementwise_mul", "Out")
                                ->AsIntermediate();
+    } else if (with_token_slice_) {
+      VLOG(3) << "build input_slice";
+      input = VarNode("input")->assert_is_op_input("slice", "Input")->AsInput();
+      input_slice =
+          OpNode("input_slice", "slice")
+              ->assert_op_attr_satisfied<std::vector<int>>(
+                  "axes",
+                  [](const std::vector<int>& attr) {
+                    return attr.size() == 1 && attr[0] == 1;
+                  })
+              ->assert_op_attr_satisfied<std::vector<int>>(
+                  "starts",
+                  [](const std::vector<int>& attr) { return attr.size() == 1; })
+              ->assert_op_attr_satisfied<std::vector<int>>(
+                  "ends",
+                  [](const std::vector<int>& attr) { return attr.size() == 1; })
+              ->AsIntermediate();
+      input_slice_out = VarNode("input_slice_out")
+                            ->assert_is_op_output("slice", "Out")
+                            ->assert_is_op_input("elementwise_add", input_pos_);
     } else {
       input = VarNode("input")
                   ->assert_is_op_input("elementwise_add", input_pos_)
@@ -311,11 +361,40 @@ class XPUSingleEncoderFuser : public FuseBase {
         VarNode("qkv_transpose2_xshape")
             ->assert_is_op_output("transpose2", "XShape")
             ->AsIntermediate();
+    PMNode* qkv_slice = nullptr;
+    PMNode* qkv_slice_out = nullptr;
+    PMNode* qkv_reshape2_out = nullptr;
     auto* qkv_reshape2 = OpNode("qkv_reshape2", "reshape2")->AsIntermediate();
-    auto* qkv_reshape2_out = VarNode("qkv_reshape2_out")
-                                 ->assert_is_op_output("reshape2", "Out")
-                                 ->assert_is_op_input(mul_type_, "X")
-                                 ->AsIntermediate();
+    if (with_token_slice_) {
+      qkv_reshape2_out = VarNode("qkv_reshape2_out")
+                             ->assert_is_op_output("reshape2", "Out")
+                             ->assert_is_op_input("slice", "Input")
+                             ->AsIntermediate();
+      VLOG(3) << "build qkv_slice";
+      qkv_slice =
+          OpNode("qkv_slice", "slice")
+              ->assert_op_attr_satisfied<std::vector<int>>(
+                  "axes",
+                  [](const std::vector<int>& attr) {
+                    return attr.size() == 1 && attr[0] == 1;
+                  })
+              ->assert_op_attr_satisfied<std::vector<int>>(
+                  "starts",
+                  [](const std::vector<int>& attr) { return attr.size() == 1; })
+              ->assert_op_attr_satisfied<std::vector<int>>(
+                  "ends",
+                  [](const std::vector<int>& attr) { return attr.size() == 1; })
+              ->AsIntermediate();
+      qkv_slice_out = VarNode("qkv_slice_out")
+                          ->assert_is_op_output("slice", "Out")
+                          ->assert_is_op_input(mul_type_, "X")
+                          ->AsIntermediate();
+    } else {
+      qkv_reshape2_out = VarNode("qkv_reshape2_out")
+                             ->assert_is_op_output("reshape2", "Out")
+                             ->assert_is_op_input(mul_type_, "X")
+                             ->AsIntermediate();
+    }
     auto* qkv_reshape2_xshape = VarNode("qkv_reshape2_xshape")
                                     ->assert_is_op_output("reshape2", "XShape")
                                     ->AsIntermediate();
@@ -531,18 +610,33 @@ class XPUSingleEncoderFuser : public FuseBase {
     *v_transpose2 >> *v_transpose2_xshape;
 
     *qkv_matmul >> *qkv_matmul_out >> *qkv_transpose2 >> *qkv_transpose2_out >>
-        *qkv_reshape2 >> *qkv_reshape2_out >> *qkv_mul >> *qkv_mul_out >>
-        *qkv_add >> *qkv_add_out >> *qkv_add_2;
+        *qkv_reshape2 >> *qkv_reshape2_out;
+    if (with_token_slice_) {
+      *qkv_reshape2_out >> *qkv_slice >> *qkv_slice_out >> *qkv_mul;
+    } else {
+      *qkv_reshape2_out >> *qkv_mul;
+    }
+    *qkv_mul >> *qkv_mul_out >> *qkv_add >> *qkv_add_out >> *qkv_add_2;
     *qkv_transpose2 >> *qkv_transpose2_xshape;
     *qkv_reshape2 >> *qkv_reshape2_xshape;
     *qkv_mul_y >> *qkv_mul;
     *qkv_add_y >> *qkv_add;
     if (smooth_quant_ && !norm_before_) {
       *smooth_scale_1_weight >> *smooth_scale_1;
-      *input >> *smooth_scale_1 >> *smooth_scale_1_out >> *qkv_add_2 >>
-          *qkv_add_2_out >> *qkv_ln_2 >> *qkv_ln_2_out;
+      if (with_token_slice_) {
+        *input >> *input_slice >> *input_slice_out >> *smooth_scale_1;
+      } else {
+        *input >> *smooth_scale_1;
+      }
+      *smooth_scale_1 >> *smooth_scale_1_out >> *qkv_add_2 >> *qkv_add_2_out >>
+          *qkv_ln_2 >> *qkv_ln_2_out;
     } else {
-      *input >> *qkv_add_2 >> *qkv_add_2_out >> *qkv_ln_2 >> *qkv_ln_2_out;
+      if (with_token_slice_) {
+        *input >> *input_slice >> *input_slice_out >> *qkv_add_2;
+      } else {
+        *input >> *qkv_add_2;
+      }
+      *qkv_add_2 >> *qkv_add_2_out >> *qkv_ln_2 >> *qkv_ln_2_out;
     }
     *qkv_ln_2_scale >> *qkv_ln_2;
     *qkv_ln_2_bias >> *qkv_ln_2;
@@ -619,6 +713,8 @@ class XPUSingleEncoderFuser : public FuseBase {
       // the model is smooth_quant or not, we don't need to do anything
       // so, set is_smooth_quant as false.
       op_desc.SetAttr<bool>("is_smooth_quant", false);
+      // temporarily does not support token slice in the case of pre-layernorm
+      op_desc.SetAttr<bool>("with_token_slice", false);
     } else {
       op_desc.SetInput("LNScale",
                        {
@@ -643,6 +739,33 @@ class XPUSingleEncoderFuser : public FuseBase {
       } else {
         op_desc.SetAttr<bool>("is_smooth_quant", false);
       }
+      if (with_token_slice_) {
+        op_desc.SetAttr<bool>("with_token_slice", true);
+        int token_sliced_length = -1;
+        auto* qkv_slice_op_info = matched.at("qkv_slice")->stmt()->op_info();
+        auto* input_slice_op_info =
+            matched.at("input_slice")->stmt()->op_info();
+        if (qkv_slice_op_info->HasAttr("starts") &&
+            qkv_slice_op_info->HasAttr("ends") &&
+            input_slice_op_info->HasAttr("starts") &&
+            input_slice_op_info->HasAttr("ends")) {
+          auto qkv_slice_starts =
+              qkv_slice_op_info->GetAttr<std::vector<int>>("starts");
+          auto qkv_slice_ends =
+              qkv_slice_op_info->GetAttr<std::vector<int>>("ends");
+          auto input_slice_starts =
+              input_slice_op_info->GetAttr<std::vector<int>>("starts");
+          auto input_slice_ends =
+              input_slice_op_info->GetAttr<std::vector<int>>("ends");
+          CHECK_EQ(qkv_slice_starts.size(), input_slice_starts.size());
+          CHECK_EQ(qkv_slice_ends.size(), input_slice_ends.size());
+          CHECK_EQ(qkv_slice_starts[0], input_slice_starts[0]);
+          CHECK_EQ(qkv_slice_ends[0], input_slice_ends[0]);
+          token_sliced_length = qkv_slice_ends[0] - qkv_slice_starts[0];
+          CHECK_GT(token_sliced_length, 0);
+          op_desc.SetAttr<int>("token_sliced_length", token_sliced_length);
+        }
+      }
     }
     // XXX: keep these to fool SubgraphOp::AttachImpl()
     op_desc.SetAttr<int>("sub_block", 0);
@@ -792,6 +915,7 @@ class XPUSingleEncoderFuser : public FuseBase {
   const std::string relative_emb_type_;
   bool with_mask_;
   bool smooth_quant_;
+  bool with_token_slice_;
   // quant_info: mul input_max, output_max * 6 + matmul x_max:y_max,
   // output_max
   // * 2
@@ -1845,6 +1969,8 @@ class XPUMultiEncoderFuser {
     std::vector<float> fc_input_max;
     std::vector<float> softmax_max;
     std::vector<std::string> quant_types;
+    bool has_token_sliced_layer = false;
+    std::vector<int> token_sliced_length(all_encoders.size(), -1);
 
     for (size_t i = 0; i < all_encoders.size(); ++i) {
       Node* cur_encoder = all_encoders[i];
@@ -1889,6 +2015,12 @@ class XPUMultiEncoderFuser {
         }
       }
 
+      if (op_info->HasAttr("with_token_slice") &&
+          op_info->HasAttr("token_sliced_length")) {
+        has_token_sliced_layer = true;
+        token_sliced_length[i] = op_info->GetAttr<int>("token_sliced_length");
+      }
+
       auto* cur_out =
           graph->RetrieveArgument(op_info->Output("Outputs").front());
       if (all_encoders.size() == 1) {
@@ -1930,6 +2062,9 @@ class XPUMultiEncoderFuser {
     }
     op_desc.SetOutput("Output", {out_name});
     op_desc.SetAttr<int>("xpu", 1);
+    op_desc.SetAttr<std::vector<int>>("token_sliced_length",
+                                      token_sliced_length);
+    op_desc.SetAttr<bool>("has_token_sliced_layer", has_token_sliced_layer);
    op_desc.SetAttr<bool>(
        "is_smooth_quant",
        first_encoder_op_info->GetAttr<bool>("is_smooth_quant"));
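 
In the fused multi-encoder op, token_sliced_length becomes a per-layer vector: one entry per fused encoder, -1 for layers fused without a token slice and the kept token count otherwise, with has_token_sliced_layer flagging whether any layer sliced at all. A small illustrative sketch of reading such a vector follows; the example values are made up, not taken from this commit.

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical example: three fused encoder layers, only the last one
  // slices tokens and keeps a single token.
  std::vector<int> token_sliced_length = {-1, -1, 1};
  bool has_token_sliced_layer = false;
  for (size_t i = 0; i < token_sliced_length.size(); ++i) {
    if (token_sliced_length[i] > 0) {
      has_token_sliced_layer = true;
      std::printf("layer %zu keeps %d token(s)\n", i, token_sliced_length[i]);
    } else {
      std::printf("layer %zu has no token slice\n", i);
    }
  }
  std::printf("has_token_sliced_layer = %d\n", has_token_sliced_layer);
  return 0;
}
 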
@@ -2325,6 +2460,7 @@ class XPUMultiEncoderFusePass : public ProgramPass {
     std::vector<std::string> relative_embedding_type{
         "", "__xpu__roformer_relative_embedding"};
     std::vector<bool> with_smooth_quant{true, false};
+    std::vector<bool> with_token_slice{true, false};
 
     std::string fc_precision;
     bool adaptive_seqlen = false;
@@ -2384,19 +2520,25 @@ class XPUMultiEncoderFusePass : public ProgramPass {
              // so remove one
              continue;
            }
-           fusion::XPUSingleEncoderFuser single_encoder_fuser(
-               act_type,
-               input_pos,
-               qkv_ln_2_out_pos,
-               matmul_type,
-               matmul2_type,
-               mul_type,
-               with_q_scale,
-               norm_before,
-               relative_type,
-               mask,
-               smooth_quant);
-           single_encoder_fuser(graph.get());
+           for (auto token_slice : with_token_slice) {
+             fusion::XPUSingleEncoderFuser single_encoder_fuser(
+                 act_type,
+                 input_pos,
+                 qkv_ln_2_out_pos,
+                 matmul_type,
+                 matmul2_type,
+                 mul_type,
+                 with_q_scale,
+                 norm_before,
+                 relative_type,
+                 mask,
+                 smooth_quant,
+                 token_slice);
+             single_encoder_fuser(graph.get());
+           }
+           // must wait for both cases of whether single_encoders
+           // have token_slice
+           // to be detected before multi_encoder detecting.
            fusion::XPUMultiEncoderFuser multi_encoder_fuser(
                fc_precision, adaptive_seqlen);
            multi_encoder_fuser(graph.get());
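 
The driver change above runs the single-encoder fuser once per token_slice value and only afterwards runs the multi-encoder fuser, so both variants get a chance to match before the layers are stitched together. A toy sketch of that ordering follows; RunSingleEncoderFuser and RunMultiEncoderFuser are stand-ins for illustration, not the real pass entry points.

#include <cstdio>
#include <vector>

// Stand-in functions; the real pass invokes fusion::XPUSingleEncoderFuser
// and fusion::XPUMultiEncoderFuser on the graph.
void RunSingleEncoderFuser(bool token_slice) {
  std::printf("single-encoder fuse, with_token_slice=%d\n", token_slice);
}
void RunMultiEncoderFuser() { std::printf("multi-encoder fuse\n"); }

int main() {
  const std::vector<bool> with_token_slice = {true, false};
  // Detect single encoders for both variants first ...
  for (bool token_slice : with_token_slice) {
    RunSingleEncoderFuser(token_slice);
  }
  // ... and only then fuse them into a multi-encoder op.
  RunMultiEncoderFuser();
  return 0;
}
 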
