Skip to content

Commit c1af16e

Browse files
committed
coreml: Add a few missing operators
1 parent 860d085 commit c1af16e

File tree

5 files changed

+36
-8
lines changed

5 files changed

+36
-8
lines changed

onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
118118
} else if (op_type == "PRelu") {
119119
coreml_op_type = "prelu";
120120
add_alpha = true;
121+
} else if (op_type == "Softplus") {
122+
coreml_op_type = "softplus";
123+
} else if (op_type == "Elu") {
124+
coreml_op_type = "elu";
125+
add_alpha = true;
121126
} else {
122127
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
123128
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
@@ -141,7 +146,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
141146
}
142147
} else {
143148
NodeAttrHelper helper(node);
144-
const auto alpha = helper.Get("alpha", 0.01f);
149+
const auto alpha = helper.Get("alpha", "Elu" == op_type ? 1.0f : 0.01f);
145150

146151
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
147152
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha));
@@ -259,8 +264,10 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp
259264
const logging::Logger& logger) const {
260265
const auto& op_type = node.OpType();
261266

262-
if (op_type == "Gelu" && !input_params.create_mlprogram) {
263-
return false;
267+
if (!input_params.create_mlprogram) {
268+
if (op_type == "Gelu" || op_type == "Softplus" || op_type == "Elu") {
269+
return false;
270+
}
264271
}
265272
if (op_type == "PRelu") {
266273
return IsPReluOpSupported(node, input_params, logger);
@@ -269,8 +276,13 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp
269276
return true;
270277
}
271278

272-
int ActivationOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const {
273-
// All ops opset 5- uses consumed_inputs attribute which is not supported for now
279+
int ActivationOpBuilder::GetMinSupportedOpSet(const Node& node) const {
280+
const auto& op_type(node.OpType());
281+
// Softplus was unmodified from opset 1 to 21 (with no attributes).
282+
if (op_type == "Softplus") {
283+
return 1;
284+
}
285+
// All other ops opset 5- uses consumed_inputs attribute which is not supported for now.
274286
return 6;
275287
}
276288

@@ -286,6 +298,8 @@ void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistration
286298
"PRelu",
287299
"LeakyRelu",
288300
"Gelu",
301+
"Softplus",
302+
"Elu",
289303
};
290304

291305
op_registrations.builders.push_back(std::make_unique<ActivationOpBuilder>());

onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInpu
105105
bool BaseOpBuilder::HasSupportedOpSet(const Node& node, const logging::Logger& logger) const {
106106
auto since_version = node.SinceVersion();
107107
if (since_version < GetMinSupportedOpSet(node) || since_version > GetMaxSupportedOpSet(node)) {
108-
LOGS(logger, VERBOSE) << node.OpType() << "is only supported for opset ["
108+
LOGS(logger, VERBOSE) << node.OpType() << " is only supported for opset ["
109109
<< GetMinSupportedOpSet(node) << ", "
110110
<< GetMaxSupportedOpSet(node) << "]";
111111
return false;

onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
3737
coreml_op_type = "erf";
3838
} else if (op_type == "Round") {
3939
coreml_op_type = "round";
40+
} else if (op_type == "Exp") {
41+
coreml_op_type = "exp";
42+
} else if (op_type == "Log") {
43+
coreml_op_type = "log";
4044
} else {
4145
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
4246
"UnaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type);
@@ -79,8 +83,10 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
7983

8084
bool UnaryOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
8185
const logging::Logger& /*logger*/) const {
82-
if (!input_params.create_mlprogram && (node.OpType() == "Erf" || node.OpType() == "Round")) {
83-
return false;
86+
if (!input_params.create_mlprogram) {
87+
if (node.OpType() == "Erf" || node.OpType() == "Round" || node.OpType() == "Exp" || node.OpType() == "Log") {
88+
return false;
89+
}
8490
}
8591
return true;
8692
}

onnxruntime/core/providers/coreml/builders/op_builder_factory.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,16 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
2222
CreateActivationOpBuilder("PRelu", op_registrations);
2323
CreateActivationOpBuilder("LeakyRelu", op_registrations);
2424
CreateActivationOpBuilder("Gelu", op_registrations);
25+
CreateActivationOpBuilder("Softplus", op_registrations);
26+
CreateActivationOpBuilder("Elu", op_registrations);
2527

2628
// Unary ops
2729
CreateUnaryOpBuilder("Erf", op_registrations);
2830
CreateUnaryOpBuilder("Reciprocal", op_registrations);
2931
CreateUnaryOpBuilder("Round", op_registrations);
3032
CreateUnaryOpBuilder("Sqrt", op_registrations);
33+
CreateUnaryOpBuilder("Exp", op_registrations);
34+
CreateUnaryOpBuilder("Log", op_registrations);
3135

3236
// Binary elementwise ops
3337
CreateBinaryOpBuilder("Add", op_registrations);

tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
1313
|ai.onnx:ConvTranspose|Weight and bias must be constant.<br/>padding_type of SAME_UPPER/SAME_LOWER is not supported.<br/>kernel_shape must have default values.<br/>output_shape is not supported.<br/>output_padding must have default values.|
1414
|ai.onnx:DepthToSpace|If 'mode' is 'CRD' the input must have a fixed shape.|
1515
|ai.onnx:Div||
16+
|ai.onnx:Elu||
1617
|ai.onnx:Erf||
18+
|ai.onnx:Exp||
1719
|ai.onnx:Gemm|Input B must be constant.|
1820
|ai.onnx:Gelu||
1921
|ai.onnx:GlobalAveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
@@ -23,6 +25,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
2325
|ai.onnx:InstanceNormalization||
2426
|ai.onnx:LayerNormalization||
2527
|ai.onnx:LeakyRelu||
28+
|ai.onnx:Log||
2629
|ai.onnx:MatMul|Only support for transA == 0, alpha == 1.0 and beta == 1.0 is currently implemented.|
2730
|ai.onnx:MaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
2831
|ai.onnx:Max||
@@ -39,6 +42,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
3942
|ai.onnx:Round||
4043
|ai.onnx:Shape||
4144
|ai.onnx:Slice|starts/ends/axes/steps must be constant initializers.|
45+
|ai.onnx:Softplus||
4246
|ai.onnx:Split|If provided, `splits` must be constant.|
4347
|ai.onnx:Sub||
4448
|ai.onnx:Sigmoid||

0 commit comments

Comments
 (0)