2626 this output is scaled to the proper range.
2727*/
2828
29- #ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
30- #define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
29+ #ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
30+ #define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
3131#if MXNET_USE_ONEDNN == 1
3232
3333#include < memory>
@@ -55,27 +55,21 @@ inline bool EndsWith(std::string const& value, std::string const& ending) {
5555class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
5656 private:
5757 bool quantized_;
58- SelectStatus status_ = kFail ;
59- std::vector<const BiDirectedNode*> matched_list_;
58+ bool patternFound = false ;
6059
6160 public:
6261 explicit SgDNNLFCSumFuseSelector (bool quantized) : quantized_(quantized) {}
6362
6463 bool Select (const BiDirectedNode& seed_node,
6564 const std::shared_ptr<NodeAttr>& node_attr) override {
6665 const auto n = seed_node.node ;
67- if (n->op () == Op::Get (" _sg_onednn_fully_connected" )) {
68- if (SupportDNNLAttr (node_attr) && (seed_node.outputs .size () == 1 )) {
69- auto const & fc_param = nnvm::get<DNNLFCFullParam>(n->attrs .parsed );
70- if ((!quantized_) || (fc_param.dnnl_param .quantized && !fc_param.dnnl_param .with_eltwise )) {
71- // Start subgraph when fusing for floats (quantized_ is false for ONEDNN backend) or
72- // when FC is already quantized (second pass for ONEDNN_QUANTIZE) but not already fuzed
73- // with elemwise operator.
74- status_ = kStart ;
75- matched_list_.clear ();
76- matched_list_.push_back (&seed_node);
77- return true ;
78- }
66+ if (n->op () == Op::Get (" _sg_onednn_fully_connected" ) && seed_node.outputs .size () == 1 ) {
67+ auto const & fc_param = nnvm::get<DNNLFCFullParam>(n->attrs .parsed );
68+ if (!quantized_ || (fc_param.dnnl_param .quantized && !fc_param.dnnl_param .with_eltwise )) {
69+ // Start subgraph when fusing for floats (quantized_ is false for ONEDNN backend) or
70+ // when FC is already quantized (second pass for ONEDNN_QUANTIZE) but not already fused
71+ // with elemwise operator.
72+ return true ;
7973 }
8074 }
8175 return false ;
@@ -88,57 +82,37 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
8882 bool SelectOutput (const BiDirectedNode& cur_node, const BiDirectedNode& output_node) override {
8983 const auto cur_n = cur_node.node ;
9084 const auto output_n = output_node.node ;
91- if (status_ == kFail || status_ == kSuccess || output_n->is_variable ()) {
85+ if (patternFound || output_n->is_variable ()) {
9286 return false ;
9387 }
94- // If n isn't the last matched node, then we encoutered an internal
95- // branch, we should pop out the node behind n and stop fusion.
96- if (matched_list_.back () != &cur_node) {
97- if (std::find (matched_list_.begin (), matched_list_.end (), &cur_node) != matched_list_.end ()) {
98- while (matched_list_.back () != &cur_node) {
99- matched_list_.pop_back ();
88+
89+ // Find _contrib_quantized_elemwise_add or elemwise_add
90+ if (EndsWith (output_n->op ()->name , " elemwise_add" )) {
91+ if (quantized_) {
92+ auto const & fc_param = nnvm::get<DNNLFCFullParam>(cur_n->attrs .parsed );
93+ if (!fc_param.dnnl_param .enable_float_output ) {
94+ // For quantized graph, when FC floating point output is not enabled elementwise add must
95+ // also be quantized (min and max value have to be already stored in elementwise add).
96+ CHECK_EQ (output_n->attrs .dict .count (" min_calib_range" ), 1 );
10097 }
10198 }
102- status_ = kSuccess ;
99+ patternFound = true ;
100+ return true ;
101+ } else {
103102 return false ;
104103 }
105-
106- switch (status_) {
107- case kStart :
108- // Find _contrib_quantized_elemwise_add or elemwise_add
109- if (EndsWith (output_n->op ()->name , " elemwise_add" )) {
110- if (quantized_) {
111- auto const & fc_param = nnvm::get<DNNLFCFullParam>(cur_n->attrs .parsed );
112- if (!fc_param.dnnl_param .enable_float_output ) {
113- // For quantized graph, when FC floating point output is not enabled
114- // elementwise add must also be quantized (min and max value have to be already stored
115- // in elementwise add).
116- CHECK_EQ (output_n->attrs .dict .count (" min_calib_range" ), 1 );
117- }
118- }
119- matched_list_.push_back (&output_node);
120- status_ = kSuccess ;
121- return true ;
122- }
123- default :
124- status_ = kFail ;
125- return false ;
126- }
127104 }
128105
129106 std::vector<BiDirectedNode*> Filter (const std::vector<BiDirectedNode*>& candidates) override {
130- if (status_ == kSuccess ) {
107+ if (patternFound ) {
131108 return candidates;
132109 } else {
133110 return std::vector<BiDirectedNode*>(0 );
134111 }
135112 }
136113
137114 void Reset () override {
138- CHECK_GE (matched_list_.size (), 1 );
139- auto new_selector = SgDNNLFCSumFuseSelector (quantized_);
140- new_selector.Select (*matched_list_[0 ], nullptr );
141- *this = new_selector;
115+ patternFound = false ;
142116 }
143117};
144118
@@ -147,11 +121,11 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
147121 SgDNNLFCSumFuseProperty () {}
148122
149123 static SubgraphPropertyPtr Create () {
150- static const std::string& name = " DNNL fuse FullyConnected with sum" ;
124+ static const std::string& name = " oneDNN fuse FullyConnected with sum" ;
151125 auto property = std::make_shared<SgDNNLFCSumFuseProperty>();
152126 property->SetAttr <std::string>(" property_name" , name);
153127 property->SetAttr <bool >(" inference_only" , true );
154- if (dmlc::GetEnv (" MXNET_DISABLE_DNNL_FC_SUM " , 0 )) {
128+ if (dmlc::GetEnv (" MXNET_DISABLE_ONEDNN_FC_SUM " , 0 )) {
155129 property->SetAttr <bool >(" disable" , true );
156130 }
157131 return property;
@@ -207,33 +181,33 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
207181 return selector;
208182 }
209183
210- void ConnectSubgraphOutputs (const nnvm::ObjectPtr n ,
184+ void ConnectSubgraphOutputs (const nnvm::ObjectPtr subgraph_node ,
211185 std::vector<nnvm::NodeEntry*>* output_entries) const override {
212186 // Connect all extern output entries to output[0]
213187 for (size_t i = 0 ; i < output_entries->size (); ++i) {
214188 auto entry_ptr = output_entries->at (i);
215- *entry_ptr = nnvm::NodeEntry{n , entry_ptr->index , 0 };
189+ *entry_ptr = nnvm::NodeEntry{subgraph_node , entry_ptr->index , 0 };
216190 }
217191 }
218192
219- void ConnectSubgraphInputs (const nnvm::ObjectPtr n ,
193+ void ConnectSubgraphInputs (const nnvm::ObjectPtr subgraph_node ,
220194 std::vector<nnvm::NodeEntry*>* input_entries,
221195 std::vector<nnvm::NodeEntry>* orig_input_entries) const override {
222- auto sym = n ->attrs .subgraphs [0 ];
223- auto const & fc_param = nnvm::get<DNNLFCFullParam>(n ->attrs .parsed );
224- std::unordered_set<const nnvm::Node*> node_sets ;
196+ auto sym = subgraph_node ->attrs .subgraphs [0 ];
197+ auto const & fc_param = nnvm::get<DNNLFCFullParam>(subgraph_node ->attrs .parsed );
198+ std::unordered_set<const nnvm::Node*> node_set ;
225199 DFSVisit (sym->outputs , [&](const nnvm::ObjectPtr& node) {
226200 if (node->is_variable ()) {
227201 return ;
228202 }
229- node_sets .insert (node.get ());
203+ node_set .insert (node.get ());
230204 if (EndsWith (node->op ()->name , " elemwise_add" )) {
231205 const size_t base_inputs = fc_param.default_param .no_bias ? 3 : 4 ;
232206 // Make sure fc output is the left operand of the add operator, if not:
233207 // - swap inputs of add operator
234208 // - switch add operands sequence to ensure that
235209 // the tensor (sum_tensor) to which FC output is added is the last input.
236- if (node_sets .count (node->inputs [1 ].node .get ())) {
210+ if (node_set .count (node->inputs [1 ].node .get ())) {
237211 // Example of input_entries reordering for channel-wise quantized graph:
238212 // sum_tensor.data --> fc.data
239213 // fc.data --> fc.weight0
@@ -272,12 +246,12 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
272246 }
273247 }
274248 });
275- n ->inputs = *orig_input_entries;
249+ subgraph_node ->inputs = *orig_input_entries;
276250 }
277251};
278252
279253} // namespace op
280254} // namespace mxnet
281255
282256#endif // if MXNET_USE_ONEDNN == 1
283- #endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
257+ #endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
0 commit comments