Skip to content

Commit 989491c

Browse files
authored
[NNAPI EP] Make partitioning stop ops configurable. (#8444)
Enable NNAPI EP partitioning stop ops to be overridden by a session configuration option.
1 parent 892ac9f commit 989491c

File tree

6 files changed

+160
-13
lines changed

6 files changed

+160
-13
lines changed

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,16 @@ static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "
6060
// "1": default, thread will spin a number of times before blocking
6161
static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning";
6262
static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning";
63+
64+
// NNAPI EP keys begin
65+
// Note: These options should be specified prior to appending the NNAPI EP to the session options object in order for
66+
// them to take effect.
67+
68+
// Specifies a list of stop op types. Nodes of a type in the stop op types and nodes downstream from them will not be
69+
// run by the NNAPI EP.
70+
// The value should be a ","-delimited list of op types. For example, "Add,Sub".
71+
// If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op
72+
// exclusion, set the value to "".
73+
static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops";
74+
75+
// NNAPI EP keys end
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <string_view>
7+
#include <vector>
8+
9+
#include "core/common/common.h"
10+
11+
namespace onnxruntime {
12+
namespace utils {
13+
14+
/**
15+
* Splits a string into substrings delimited by the given delimiter string.
16+
* @param string_to_split The string to split.
17+
* @param delimiter The delimiter string.
18+
* @param keep_empty Whether to keep empty substrings.
19+
* @return The split substrings.
20+
*/
21+
inline std::vector<std::string_view> SplitString(std::string_view string_to_split, std::string_view delimiter,
22+
bool keep_empty = false) {
23+
ORT_ENFORCE(!delimiter.empty(), "delimiter must not be empty");
24+
std::vector<std::string_view> result{};
25+
std::string_view::size_type segment_begin_pos = 0;
26+
while (segment_begin_pos != std::string_view::npos) {
27+
const std::string_view::size_type segment_end_pos = string_to_split.find(delimiter, segment_begin_pos);
28+
const bool is_segment_empty = segment_begin_pos == segment_end_pos || segment_begin_pos == string_to_split.size();
29+
if (!is_segment_empty || keep_empty) {
30+
result.push_back(string_to_split.substr(segment_begin_pos, segment_end_pos - segment_begin_pos));
31+
}
32+
segment_begin_pos = (segment_end_pos == std::string_view::npos)
33+
? segment_end_pos
34+
: segment_end_pos + delimiter.size();
35+
}
36+
return result;
37+
}
38+
39+
} // namespace utils
40+
} // namespace onnxruntime

onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33

44
#include "core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h"
55

6+
#include "core/common/string_utils.h"
67
#include "core/framework/allocatormgr.h"
78
#include "core/framework/compute_capability.h"
89
#include "core/graph/graph_viewer.h"
10+
#include "core/platform/env.h"
911
#include "core/providers/common.h"
1012
#include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
1113
#include "core/providers/nnapi/nnapi_builtin/builders/op_support_checker.h"
@@ -20,17 +22,31 @@
2022

2123
namespace onnxruntime {
2224

25+
namespace {
26+
2327
constexpr const char* NNAPI = "Nnapi";
2428

2529
constexpr std::array kDefaultPartitioningStopOps{
2630
"NonMaxSuppression",
2731
};
2832

29-
NnapiExecutionProvider::NnapiExecutionProvider(uint32_t nnapi_flags)
33+
std::unordered_set<std::string> GetPartitioningStopOps(const optional<std::unordered_set<std::string>>& partitioning_stop_ops) {
34+
if (!partitioning_stop_ops.has_value()) {
35+
LOGS_DEFAULT(VERBOSE) << "Using default partitioning stop ops list.";
36+
return std::unordered_set<std::string>(kDefaultPartitioningStopOps.begin(), kDefaultPartitioningStopOps.end());
37+
}
38+
39+
LOGS_DEFAULT(INFO) << "Using partitioning stop ops list from configuration.";
40+
return partitioning_stop_ops.value();
41+
}
42+
43+
} // namespace
44+
45+
NnapiExecutionProvider::NnapiExecutionProvider(uint32_t nnapi_flags,
46+
const optional<std::unordered_set<std::string>>& partitioning_stop_ops)
3047
: IExecutionProvider{onnxruntime::kNnapiExecutionProvider, true},
3148
nnapi_flags_(nnapi_flags),
32-
// TODO make this configurable
33-
partitioning_stop_ops_(kDefaultPartitioningStopOps.begin(), kDefaultPartitioningStopOps.end()) {
49+
partitioning_stop_ops_(GetPartitioningStopOps(partitioning_stop_ops)) {
3450
AllocatorCreationInfo device_info(
3551
[](int) {
3652
return std::make_unique<CPUAllocator>(OrtMemoryInfo(NNAPI, OrtAllocatorType::OrtDeviceAllocator));

onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#pragma once
55

6+
#include "core/common/optional.h"
67
#include "core/framework/execution_provider.h"
78
#include "core/providers/nnapi/nnapi_provider_factory.h"
89

@@ -13,7 +14,9 @@ class Model;
1314

1415
class NnapiExecutionProvider : public IExecutionProvider {
1516
public:
16-
NnapiExecutionProvider(uint32_t nnapi_flags);
17+
explicit NnapiExecutionProvider(uint32_t nnapi_flags,
18+
const optional<std::unordered_set<std::string>>& partitioning_stop_ops = nullopt);
19+
1720
virtual ~NnapiExecutionProvider();
1821

1922
std::vector<std::unique_ptr<ComputeCapability>>
Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,59 @@
11
// Copyright 2019 JD.com Inc. JD AI
22

33
#include "core/providers/nnapi/nnapi_provider_factory.h"
4-
#include "core/session/abi_session_options_impl.h"
5-
#include "nnapi_builtin/nnapi_execution_provider.h"
64

7-
using namespace onnxruntime;
5+
#include "core/common/optional.h"
6+
#include "core/common/string_utils.h"
7+
#include "core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h"
8+
#include "core/session/abi_session_options_impl.h"
9+
#include "core/session/onnxruntime_session_options_config_keys.h"
810

911
namespace onnxruntime {
12+
13+
namespace {
1014
struct NnapiProviderFactory : IExecutionProviderFactory {
11-
NnapiProviderFactory(uint32_t nnapi_flags)
12-
: nnapi_flags_(nnapi_flags) {}
15+
NnapiProviderFactory(uint32_t nnapi_flags,
16+
const optional<std::unordered_set<std::string>>& partitioning_stop_ops)
17+
: nnapi_flags_(nnapi_flags),
18+
partitioning_stop_ops_(partitioning_stop_ops) {}
19+
1320
~NnapiProviderFactory() override {}
1421

1522
std::unique_ptr<IExecutionProvider> CreateProvider() override;
16-
uint32_t nnapi_flags_;
23+
24+
private:
25+
const uint32_t nnapi_flags_;
26+
const optional<std::unordered_set<std::string>> partitioning_stop_ops_;
1727
};
1828

1929
std::unique_ptr<IExecutionProvider> NnapiProviderFactory::CreateProvider() {
20-
return std::make_unique<NnapiExecutionProvider>(nnapi_flags_);
30+
return std::make_unique<NnapiExecutionProvider>(nnapi_flags_, partitioning_stop_ops_);
2131
}
2232

33+
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi_Internal(
34+
uint32_t nnapi_flags, const optional<std::unordered_set<std::string>>& partitioning_stop_ops) {
35+
return std::make_shared<NnapiProviderFactory>(nnapi_flags, partitioning_stop_ops);
36+
}
37+
} // namespace
38+
2339
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(uint32_t nnapi_flags) {
24-
return std::make_shared<onnxruntime::NnapiProviderFactory>(nnapi_flags);
40+
return CreateExecutionProviderFactory_Nnapi_Internal(nnapi_flags, nullopt);
2541
}
42+
2643
} // namespace onnxruntime
2744

2845
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options, uint32_t nnapi_flags) {
29-
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_Nnapi(nnapi_flags));
46+
const auto partitioning_stop_ops = [&]() -> onnxruntime::optional<std::unordered_set<std::string>> {
47+
if (std::string partitioning_stop_ops_value{};
48+
options->value.config_options.TryGetConfigEntry(kOrtSessionOptionsConfigNnapiEpPartitioningStopOps,
49+
partitioning_stop_ops_value)) {
50+
const auto partitioning_stop_ops_list = onnxruntime::utils::SplitString(partitioning_stop_ops_value, ",");
51+
return std::unordered_set<std::string>(partitioning_stop_ops_list.begin(), partitioning_stop_ops_list.end());
52+
}
53+
return onnxruntime::nullopt;
54+
}();
55+
56+
options->provider_factories.push_back(
57+
onnxruntime::CreateExecutionProviderFactory_Nnapi_Internal(nnapi_flags, partitioning_stop_ops));
3058
return nullptr;
3159
}

onnxruntime/test/common/string_utils_test.cc

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
#include "core/common/make_string.h"
55
#include "core/common/parse_string.h"
6+
#include "core/common/string_utils.h"
7+
8+
#include <algorithm>
69

710
#include "gtest/gtest.h"
811

@@ -89,5 +92,49 @@ TEST(StringUtilsTest, MakeStringAndTryParseStringWithCustomType) {
8992
ASSERT_EQ(parsed_s, s);
9093
}
9194

95+
TEST(StringUtilsTest, SplitString) {
96+
auto run_test = [](const std::string& string_to_split, const std::string& delimiter,
97+
const std::vector<std::string>& expected_substrings_with_empty) {
98+
SCOPED_TRACE(MakeString("string_to_split: \"", string_to_split, "\", delimiter: \"", delimiter, "\""));
99+
100+
auto test_split = [&](const std::vector<std::string>& expected_substrings, bool keep_empty) {
101+
SCOPED_TRACE(MakeString("keep_empty: ", keep_empty));
102+
103+
const auto actual_substrings = utils::SplitString(string_to_split, delimiter, keep_empty);
104+
ASSERT_EQ(actual_substrings.size(), expected_substrings.size());
105+
for (size_t i = 0; i < actual_substrings.size(); ++i) {
106+
EXPECT_EQ(actual_substrings[i], expected_substrings[i]) << "i=" << i;
107+
}
108+
};
109+
110+
test_split(expected_substrings_with_empty, true);
111+
112+
const std::vector<std::string> expected_substrings_without_empty = [&]() {
113+
std::vector<std::string> result = expected_substrings_with_empty;
114+
result.erase(std::remove_if(result.begin(), result.end(),
115+
[](const std::string& value) { return value.empty(); }),
116+
result.end());
117+
return result;
118+
}();
119+
test_split(expected_substrings_without_empty, false);
120+
};
121+
122+
run_test("a,b,c", ",", {"a", "b", "c"});
123+
run_test(",a,,b,,,c,", ",", {"", "a", "", "b", "", "", "c", ""});
124+
run_test("one_delimiter_two_delimiter_", "_delimiter_", {"one", "two", ""});
125+
run_test("aaaaaaa", "aa", {"", "", "", "a"});
126+
run_test("abcabaabc", "abc", {"", "aba", ""});
127+
run_test("leading,", ",", {"leading", ""});
128+
run_test(",trailing", ",", {"", "trailing"});
129+
run_test("", ",", {""});
130+
run_test(",", ",", {"", ""});
131+
}
132+
133+
#ifndef ORT_NO_EXCEPTIONS
134+
TEST(StringUtilsTest, SplitStringWithEmptyDelimiter) {
135+
EXPECT_THROW(utils::SplitString("a", ""), OnnxRuntimeException);
136+
}
137+
#endif
138+
92139
} // namespace test
93140
} // namespace onnxruntime

0 commit comments

Comments
 (0)