Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 5a87a0c

Browse files
committed
ci test
1 parent e36c9f0 commit 5a87a0c

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/operator/subgraph/dnnl/dnnl_remove_casts_property.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class SgDNNLRemoveCastsSelector : public SubgraphSelectorV2 {
9595
}
9696

9797
void Reset() override {
98-
status_ = kFail;
98+
status_ = kExpand;
9999
castDtype = -1;
100100
}
101101
};
@@ -105,7 +105,7 @@ class SgDNNLRemoveCastsProperty : public SubgraphProperty {
105105
SgDNNLRemoveCastsProperty() {}
106106

107107
static SubgraphPropertyPtr Create() {
108-
static const std::string& name = "Remove casts optimization pass";
108+
static const std::string& name = "Remove Casts optimization pass";
109109
auto property = std::make_shared<SgDNNLRemoveCastsProperty>();
110110
property->SetAttr<std::string>("property_name", name);
111111
property->SetAttr<bool>("inference_only", true);
@@ -137,6 +137,15 @@ class SgDNNLRemoveCastsProperty : public SubgraphProperty {
137137
auto selector = std::make_shared<SgDNNLRemoveCastsSelector>();
138138
return selector;
139139
}
140+
141+
void ConnectSubgraphOutputs(const nnvm::ObjectPtr subgraph_node,
142+
std::vector<nnvm::NodeEntry*>* output_entries) const override {
143+
// Connect all extern output entries to output[0]
144+
for (size_t i = 0; i < output_entries->size(); ++i) {
145+
auto entry_ptr = output_entries->at(i);
146+
*entry_ptr = nnvm::NodeEntry{subgraph_node, entry_ptr->index, 0};
147+
}
148+
}
140149
};
141150

142151
} // namespace op

0 commit comments

Comments
 (0)