Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit e36c9f0

Browse files
authored
Refactor fc_sum_fuse (#21077)
* Refactor fc_sum_fuse * Fix sanity * Simplify Selector * Restore ConnectSubgraphOutputs in fc_sum_fuse_property * Fix node name
1 parent 26243ee commit e36c9f0

File tree

2 files changed

+38
-64
lines changed

2 files changed

+38
-64
lines changed

src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h renamed to src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h

Lines changed: 37 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
this output is scaled to the proper range.
2727
*/
2828

29-
#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
30-
#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
29+
#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
30+
#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
3131
#if MXNET_USE_ONEDNN == 1
3232

3333
#include <memory>
@@ -55,27 +55,21 @@ inline bool EndsWith(std::string const& value, std::string const& ending) {
5555
class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
5656
private:
5757
bool quantized_;
58-
SelectStatus status_ = kFail;
59-
std::vector<const BiDirectedNode*> matched_list_;
58+
bool patternFound = false;
6059

6160
public:
6261
explicit SgDNNLFCSumFuseSelector(bool quantized) : quantized_(quantized) {}
6362

6463
bool Select(const BiDirectedNode& seed_node,
6564
const std::shared_ptr<NodeAttr>& node_attr) override {
6665
const auto n = seed_node.node;
67-
if (n->op() == Op::Get("_sg_onednn_fully_connected")) {
68-
if (SupportDNNLAttr(node_attr) && (seed_node.outputs.size() == 1)) {
69-
auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
70-
if ((!quantized_) || (fc_param.dnnl_param.quantized && !fc_param.dnnl_param.with_eltwise)) {
71-
// Start subgraph when fusing for floats (quantized_ is false for ONEDNN backend) or
72-
// when FC is already quantized (second pass for ONEDNN_QUANTIZE) but not already fuzed
73-
// with elemwise operator.
74-
status_ = kStart;
75-
matched_list_.clear();
76-
matched_list_.push_back(&seed_node);
77-
return true;
78-
}
66+
if (n->op() == Op::Get("_sg_onednn_fully_connected") && seed_node.outputs.size() == 1) {
67+
auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
68+
if (!quantized_ || (fc_param.dnnl_param.quantized && !fc_param.dnnl_param.with_eltwise)) {
69+
// Start subgraph when fusing for floats (quantized_ is false for ONEDNN backend) or
70+
// when FC is already quantized (second pass for ONEDNN_QUANTIZE) but not already fused
71+
// with elemwise operator.
72+
return true;
7973
}
8074
}
8175
return false;
@@ -88,57 +82,37 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
8882
bool SelectOutput(const BiDirectedNode& cur_node, const BiDirectedNode& output_node) override {
8983
const auto cur_n = cur_node.node;
9084
const auto output_n = output_node.node;
91-
if (status_ == kFail || status_ == kSuccess || output_n->is_variable()) {
85+
if (patternFound || output_n->is_variable()) {
9286
return false;
9387
}
94-
// If n isn't the last matched node, then we encoutered an internal
95-
// branch, we should pop out the node behind n and stop fusion.
96-
if (matched_list_.back() != &cur_node) {
97-
if (std::find(matched_list_.begin(), matched_list_.end(), &cur_node) != matched_list_.end()) {
98-
while (matched_list_.back() != &cur_node) {
99-
matched_list_.pop_back();
88+
89+
// Find _contrib_quantized_elemwise_add or elemwise_add
90+
if (EndsWith(output_n->op()->name, "elemwise_add")) {
91+
if (quantized_) {
92+
auto const& fc_param = nnvm::get<DNNLFCFullParam>(cur_n->attrs.parsed);
93+
if (!fc_param.dnnl_param.enable_float_output) {
94+
// For quantized graph, when FC floating point output is not enabled elementwise add must
95+
// also be quantized (min and max value have to be already stored in elementwise add).
96+
CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
10097
}
10198
}
102-
status_ = kSuccess;
99+
patternFound = true;
100+
return true;
101+
} else {
103102
return false;
104103
}
105-
106-
switch (status_) {
107-
case kStart:
108-
// Find _contrib_quantized_elemwise_add or elemwise_add
109-
if (EndsWith(output_n->op()->name, "elemwise_add")) {
110-
if (quantized_) {
111-
auto const& fc_param = nnvm::get<DNNLFCFullParam>(cur_n->attrs.parsed);
112-
if (!fc_param.dnnl_param.enable_float_output) {
113-
// For quantized graph, when FC floating point output is not enabled
114-
// elementwise add must also be quantized (min and max value have to be already stored
115-
// in elementwise add).
116-
CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
117-
}
118-
}
119-
matched_list_.push_back(&output_node);
120-
status_ = kSuccess;
121-
return true;
122-
}
123-
default:
124-
status_ = kFail;
125-
return false;
126-
}
127104
}
128105

129106
std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
130-
if (status_ == kSuccess) {
107+
if (patternFound) {
131108
return candidates;
132109
} else {
133110
return std::vector<BiDirectedNode*>(0);
134111
}
135112
}
136113

137114
void Reset() override {
138-
CHECK_GE(matched_list_.size(), 1);
139-
auto new_selector = SgDNNLFCSumFuseSelector(quantized_);
140-
new_selector.Select(*matched_list_[0], nullptr);
141-
*this = new_selector;
115+
patternFound = false;
142116
}
143117
};
144118

@@ -147,11 +121,11 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
147121
SgDNNLFCSumFuseProperty() {}
148122

149123
static SubgraphPropertyPtr Create() {
150-
static const std::string& name = "DNNL fuse FullyConnected with sum";
124+
static const std::string& name = "oneDNN fuse FullyConnected with sum";
151125
auto property = std::make_shared<SgDNNLFCSumFuseProperty>();
152126
property->SetAttr<std::string>("property_name", name);
153127
property->SetAttr<bool>("inference_only", true);
154-
if (dmlc::GetEnv("MXNET_DISABLE_DNNL_FC_SUM", 0)) {
128+
if (dmlc::GetEnv("MXNET_DISABLE_ONEDNN_FC_SUM", 0)) {
155129
property->SetAttr<bool>("disable", true);
156130
}
157131
return property;
@@ -207,33 +181,33 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
207181
return selector;
208182
}
209183

210-
void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
184+
void ConnectSubgraphOutputs(const nnvm::ObjectPtr subgraph_node,
211185
std::vector<nnvm::NodeEntry*>* output_entries) const override {
212186
// Connect all extern output entries to output[0]
213187
for (size_t i = 0; i < output_entries->size(); ++i) {
214188
auto entry_ptr = output_entries->at(i);
215-
*entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
189+
*entry_ptr = nnvm::NodeEntry{subgraph_node, entry_ptr->index, 0};
216190
}
217191
}
218192

219-
void ConnectSubgraphInputs(const nnvm::ObjectPtr n,
193+
void ConnectSubgraphInputs(const nnvm::ObjectPtr subgraph_node,
220194
std::vector<nnvm::NodeEntry*>* input_entries,
221195
std::vector<nnvm::NodeEntry>* orig_input_entries) const override {
222-
auto sym = n->attrs.subgraphs[0];
223-
auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
224-
std::unordered_set<const nnvm::Node*> node_sets;
196+
auto sym = subgraph_node->attrs.subgraphs[0];
197+
auto const& fc_param = nnvm::get<DNNLFCFullParam>(subgraph_node->attrs.parsed);
198+
std::unordered_set<const nnvm::Node*> node_set;
225199
DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr& node) {
226200
if (node->is_variable()) {
227201
return;
228202
}
229-
node_sets.insert(node.get());
203+
node_set.insert(node.get());
230204
if (EndsWith(node->op()->name, "elemwise_add")) {
231205
const size_t base_inputs = fc_param.default_param.no_bias ? 3 : 4;
232206
// Make sure fc output is the left operand of the add operator, if not:
233207
// - swap inputs of add operator
234208
// - switch add operands sequence to ensure that
235209
// the tensor (sum_tensor) to which FC output is added is the last input.
236-
if (node_sets.count(node->inputs[1].node.get())) {
210+
if (node_set.count(node->inputs[1].node.get())) {
237211
// Example of input_entries reordering for channel-wise quantized graph:
238212
// sum_tensor.data --> fc.data
239213
// fc.data --> fc.weight0
@@ -272,12 +246,12 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
272246
}
273247
}
274248
});
275-
n->inputs = *orig_input_entries;
249+
subgraph_node->inputs = *orig_input_entries;
276250
}
277251
};
278252

279253
} // namespace op
280254
} // namespace mxnet
281255

282256
#endif // if MXNET_USE_ONEDNN == 1
283-
#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
257+
#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_

src/operator/subgraph/dnnl/dnnl_subgraph_property.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
#include "dnnl_pow_mul_scalar_property.h"
3131
#include "dnnl_transformer_qk_property.h"
3232
#include "dnnl_transformer_valatt_property.h"
33-
#include "dnnl_fc_sum_fuse.h"
33+
#include "dnnl_fc_sum_fuse_property.h"
3434
#include "dnnl_remove_casts_property.h"
3535

3636
namespace mxnet {

0 commit comments

Comments (0)