Skip to content

Commit 997c43b

Browse files
authored
[Snippets][CPU] Make ExtractUnsupportedTransposes HW dependent (#32676)
### Details: Introduce a callback function for the `ExtractUnsupportedTransposes` pass as part of `CommonOptimizations::Config` to customize the pass behavior depending on Transpose support. For example, the ARM64 platform supports transpose decomposition, but MatMul with transposed A/B inputs is not supported so far. The remaining (potential) platforms mark Transpose as completely unsupported. This is an alternative approach to #32592. ### Tickets: - 176061
1 parent af827d8 commit 997c43b

File tree

16 files changed

+183
-36
lines changed

16 files changed

+183
-36
lines changed

src/common/snippets/include/snippets/pass/common_optimizations.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
#pragma once
66

7+
#include <functional>
8+
#include <memory>
9+
#include <utility>
10+
711
#include "openvino/pass/matcher_pass.hpp"
812

913
namespace ov::snippets::pass {
@@ -24,12 +28,19 @@ class CommonOptimizations : public ov::pass::MatcherPass {
2428
* @ingroup snippets
2529
*/
2630
struct Config {
31+
using TransposeSupportCallback = std::function<bool(const std::shared_ptr<const ov::Node>&)>;
32+
2733
Config(size_t concurrency, bool split_m_dimension)
2834
: m_concurrency(concurrency),
2935
m_split_m_dimension(split_m_dimension) {
3036
OPENVINO_ASSERT(concurrency > 0, "Concurrency should be greater than 0");
3137
}
3238

39+
void set_concurrency(size_t concurrency) {
40+
OPENVINO_ASSERT(concurrency > 0, "Concurrency should be greater than 0");
41+
m_concurrency = concurrency;
42+
}
43+
3344
[[nodiscard]] size_t get_concurrency() const {
3445
return m_concurrency;
3546
}
@@ -38,10 +49,23 @@ class CommonOptimizations : public ov::pass::MatcherPass {
3849
return m_split_m_dimension;
3950
}
4051

52+
void set_transpose_support_callback(TransposeSupportCallback cb) {
53+
m_transpose_support_cb = std::move(cb);
54+
}
55+
56+
[[nodiscard]] const TransposeSupportCallback& get_transpose_support_callback() const {
57+
return m_transpose_support_cb;
58+
}
59+
4160
private:
4261
size_t m_concurrency = 0;
4362
// True if "SplitDimensionM" optimization is enabled.
4463
bool m_split_m_dimension = true;
64+
// Callback to determine whether a given Transpose is supported inside Subgraph.
65+
// If empty, all Transposes are treated as unsupported.
66+
TransposeSupportCallback m_transpose_support_cb = [](const std::shared_ptr<const ov::Node>&) {
67+
return false;
68+
};
4569
};
4670

4771
explicit CommonOptimizations(const Config& config);

src/common/snippets/include/snippets/pass/extract_unsupported_transposes.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,14 @@ namespace ov::snippets::pass {
2121
class ExtractUnsupportedTransposes : public CommonOptimizations::SubgraphPass {
2222
public:
2323
OPENVINO_RTTI("ExtractUnsupportedTransposes", "0");
24-
ExtractUnsupportedTransposes() : SubgraphPass("ExtractUnsupportedTransposes") {}
24+
explicit ExtractUnsupportedTransposes(CommonOptimizations::Config::TransposeSupportCallback transpose_support_cb)
25+
: SubgraphPass("ExtractUnsupportedTransposes"),
26+
m_transpose_support_cb(std::move(transpose_support_cb)) {}
2527

2628
bool run_on_subgraph(const std::shared_ptr<op::Subgraph>& subgraph) override;
29+
30+
private:
31+
CommonOptimizations::Config::TransposeSupportCallback m_transpose_support_cb;
2732
};
2833

2934
} // namespace ov::snippets::pass

src/common/snippets/include/snippets/pass/tokenization.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class SnippetsTokenization : public ov::pass::ModelPass {
7979
TokenizeMHASnippets::Config mha_config,
8080
TokenizeMLPSeqSnippets::Config mlp_seq_config)
8181
: m_tokenization_config(config),
82-
m_common_optimizations_config(common_config),
82+
m_common_optimizations_config(std::move(common_config)),
8383
m_mha_config(std::move(mha_config)),
8484
m_mlp_seq_config(std::move(mlp_seq_config)) {}
8585
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;

src/common/snippets/include/snippets/utils/tokenization_utils.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
*/
99
#pragma once
1010

11+
#include <functional>
1112
#include <memory>
1213

1314
#include "openvino/core/node.hpp"
@@ -50,4 +51,16 @@ std::shared_ptr<ov::snippets::op::Subgraph> tokenize_ordered_nodes(const ov::Nod
5051
* @return The estimated number of body parameters needed for this operation
5152
*/
5253
size_t get_potential_body_params(const std::shared_ptr<ov::Node>& op);
54+
55+
/**
56+
* @brief Builds a transpose support callback suitable for CommonOptimizations configuration.
57+
* The callback returns true for Transpose nodes that are considered supported by Snippets.
58+
* If `include_brgemm_case` is true, the callback additionally allows the specific
59+
* MHA fusion-related transpose order when the Transpose feeds MatMul (Brgemm case).
60+
* Independently of the flag, the decomposed transpose order accepted by MHA tokenization is allowed.
61+
*
62+
* @param include_brgemm_case if true, apply extra MatMul(Brgemm)-related order check
63+
* @return std::function predicate that can be passed to set_transpose_support_callback
64+
*/
65+
std::function<bool(const std::shared_ptr<const ov::Node>&)> make_transpose_support_callback(bool include_brgemm_case);
5366
} // namespace ov::snippets::utils

src/common/snippets/src/pass/common_optimizations.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ CommonOptimizations::CommonOptimizations(const CommonOptimizations::Config& conf
5656
// At the moment only non-scalar Constants of FakeQuantize can be inside Subgraph
5757
// so we can enable ExtractConstants pass for quantized models
5858
REGISTER_SNIPPETS_PASS(subgraph_manager, ov::snippets::pass::ExtractConstants, is_quantized);
59-
REGISTER_SNIPPETS_PASS(subgraph_manager, ov::snippets::pass::ExtractUnsupportedTransposes, is_domain_sensitive);
59+
REGISTER_SNIPPETS_PASS(subgraph_manager,
60+
ov::snippets::pass::ExtractUnsupportedTransposes,
61+
is_domain_sensitive,
62+
config.get_transpose_support_callback());
6063
REGISTER_SNIPPETS_PASS(subgraph_manager,
6164
ov::snippets::pass::SplitDimensionM,
6265
is_domain_sensitive && config.get_split_m_dimension(),

src/common/snippets/src/pass/extract_unsupported_transposes.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "openvino/opsets/opset1.hpp"
1313
#include "snippets/itt.hpp"
1414
#include "snippets/op/subgraph.hpp"
15-
#include "snippets/pass/mha_tokenization.hpp"
1615

1716
bool ov::snippets::pass::ExtractUnsupportedTransposes::run_on_subgraph(const std::shared_ptr<op::Subgraph>& subgraph) {
1817
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ExtractUnsupportedTransposes");
@@ -37,23 +36,19 @@ bool ov::snippets::pass::ExtractUnsupportedTransposes::run_on_subgraph(const std
3736
continue;
3837
}
3938

40-
const auto& order = ov::as_type_ptr<opset1::Constant>(transpose->get_input_node_shared_ptr(1));
41-
OPENVINO_ASSERT(order, "ExtractUnsupportedTransposes expects Transposes with constant order");
42-
43-
const auto order_value = order->cast_vector<int>();
44-
const auto transpose_child = *(transpose->get_output_target_inputs(0).begin());
45-
const auto is_brgemm_case = ov::is_type<opset1::MatMul>(transpose_child.get_node()->shared_from_this());
46-
// If Transpose is supported (can be decomposed or fused into Brgemm), skip
47-
// [116568]: It should be covered by TransposeDecomposition::is_supported or FuseTransposeBrgemm::is_supported
48-
if (order_value.size() > 2 &&
49-
((is_brgemm_case && TokenizeMHASnippets::get_fusion_transpose_order(order_value.size()) == order_value) ||
50-
(TokenizeMHASnippets::get_decomposed_transpose_order(order_value.size()) == order_value))) {
39+
OPENVINO_ASSERT(m_transpose_support_cb,
40+
"Transpose support callback is not set in ExtractUnsupportedTransposes pass");
41+
bool is_supported = m_transpose_support_cb(transpose);
42+
if (is_supported) {
5143
continue;
5244
}
5345

5446
// If the transpose isn't supported - we have to extract it from Subgraph
5547
transpose->set_argument(0, subgraph->input_value(i));
5648
subgraph->set_argument(i, transpose);
49+
OPENVINO_ASSERT(!transpose->get_output_target_inputs(0).empty(),
50+
"ExtractUnsupportedTransposes pass supports only Transpose nodes with at least one consumer");
51+
const auto transpose_child = *(transpose->get_output_target_inputs(0).begin());
5752
transpose_child.replace_source_output(parameter);
5853
parameter->set_partial_shape(transpose->get_output_partial_shape(0));
5954
updated = true;

src/common/snippets/src/utils/tokenization_utils.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <climits>
1010
#include <cstddef>
1111
#include <cstdint>
12+
#include <functional>
1213
#include <iterator>
1314
#include <map>
1415
#include <memory>
@@ -32,11 +33,14 @@
3233
#include "openvino/core/type.hpp"
3334
#include "openvino/op/constant.hpp"
3435
#include "openvino/op/fake_quantize.hpp"
36+
#include "openvino/op/matmul.hpp"
3537
#include "openvino/op/parameter.hpp"
3638
#include "openvino/op/result.hpp"
39+
#include "openvino/op/transpose.hpp"
3740
#include "openvino/op/util/attr_types.hpp"
3841
#include "openvino/opsets/opset1.hpp"
3942
#include "snippets/op/subgraph.hpp"
43+
#include "snippets/pass/mha_tokenization.hpp"
4044
#include "snippets/pass/tokenization.hpp"
4145
#include "snippets/pass/tokenization_config.hpp"
4246
#include "snippets/remarks.hpp"
@@ -81,6 +85,37 @@ auto outputs_are_not_broadcastable(const std::shared_ptr<const Node>& node) -> b
8185
}
8286
} // namespace
8387

88+
std::function<bool(const std::shared_ptr<const ov::Node>&)> make_transpose_support_callback(bool include_brgemm_case) {
89+
using ov::op::v0::Constant;
90+
using ov::op::v0::MatMul;
91+
using ov::op::v1::Transpose;
92+
93+
return [include_brgemm_case](const std::shared_ptr<const ov::Node>& node) -> bool {
94+
const auto transpose = ov::as_type_ptr<const Transpose>(node->shared_from_this());
95+
OPENVINO_ASSERT(transpose, "make_transpose_support_callback expects a Transpose node");
96+
const auto order = ov::as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
97+
OPENVINO_ASSERT(order, "make_transpose_support_callback expects a Constant order input");
98+
const auto order_value = order->cast_vector<int>();
99+
if (order_value.size() <= 2) {
100+
return false;
101+
}
102+
103+
bool allow = false;
104+
if (include_brgemm_case) {
105+
const auto& outputs = transpose->get_output_target_inputs(0);
106+
OPENVINO_ASSERT(!outputs.empty(), "Transpose should have at least one output consumer");
107+
const auto child_node = outputs.begin()->get_node()->shared_from_this();
108+
const bool is_brgemm_case = ov::is_type<MatMul>(child_node);
109+
allow = allow || (is_brgemm_case && ov::snippets::pass::TokenizeMHASnippets::get_fusion_transpose_order(
110+
order_value.size()) == order_value);
111+
}
112+
// Always allow decomposed order accepted by MHA tokenization
113+
allow = allow || (ov::snippets::pass::TokenizeMHASnippets::get_decomposed_transpose_order(order_value.size()) ==
114+
order_value);
115+
return allow;
116+
};
117+
}
118+
84119
bool tokenize_node(const std::shared_ptr<ov::Node>& node, const TokenizationConfig& config) {
85120
const auto getFusedNames = [](const std::shared_ptr<Node>& n) -> std::string {
86121
auto rt_info = n->get_rt_info();

src/common/snippets/tests/src/pass/mha_tokenization.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM) {
153153
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
154154
std::vector<Shape>{{2, 64, 12, 64}, {12, 1, 64, 128}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}},
155155
false);
156-
common_config = ov::snippets::pass::CommonOptimizations::Config(24, true);
156+
common_config = get_default_common_optimizations_config();
157+
common_config.set_concurrency(24);
157158
execute_and_validate_function(*this, f);
158159
}
159160

@@ -162,7 +163,8 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) {
162163
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
163164
std::vector<Shape>{{4, 32, 12, 64}, {12, 1, 64, 128}, {12, 4, 32, 128}, {1, 128, 12, 64}, {128, 12, 64}},
164165
true);
165-
common_config = ov::snippets::pass::CommonOptimizations::Config(16, true);
166+
common_config = get_default_common_optimizations_config();
167+
common_config.set_concurrency(16);
166168
execute_and_validate_function(*this, f);
167169
}
168170

@@ -171,7 +173,8 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) {
171173
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
172174
std::vector<Shape>{{1, 12, 32, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
173175
false);
174-
common_config = ov::snippets::pass::CommonOptimizations::Config(60, true);
176+
common_config = get_default_common_optimizations_config();
177+
common_config.set_concurrency(60);
175178
execute_and_validate_function(*this, f);
176179
}
177180

@@ -180,46 +183,52 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) {
180183
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
181184
std::vector<Shape>{{1, 12, 32, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
182185
true);
183-
common_config = ov::snippets::pass::CommonOptimizations::Config(60, true);
186+
common_config = get_default_common_optimizations_config();
187+
common_config.set_concurrency(60);
184188
execute_and_validate_function(*this, f);
185189
}
186190

187191
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHAWOTranspose_SplitM) {
188192
const auto& f = MHAWOTransposeSplitMFunction(std::vector<PartialShape>{{10, 9216, 128}, {10, 128, 9216}, {10, 9216, 128}},
189193
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32}),
190194
std::vector<Shape>{{10, 18, 512, 128}, {10, 1, 128, 9216}, {10, 1, 9216, 128}, {10, 9216, 128}});
191-
common_config = ov::snippets::pass::CommonOptimizations::Config(18, true);
195+
common_config = get_default_common_optimizations_config();
196+
common_config.set_concurrency(18);
192197
execute_and_validate_function(*this, f);
193198
}
194199

195200
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_SplitM_AlmostAllThreads) {
196201
const auto& f = MHAWOTransposeSplitMFunction(std::vector<PartialShape>{{5, 30, 32}, {5, 32, 30}, {5, 30, 32}},
197202
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32}),
198203
std::vector<Shape>{{5, 10, 3, 32}, {5, 1, 32, 30}, {5, 1, 30, 32}, {5, 30, 32}});
199-
common_config = ov::snippets::pass::CommonOptimizations::Config(32, true);
204+
common_config = get_default_common_optimizations_config();
205+
common_config.set_concurrency(32);
200206
execute_and_validate_function(*this, f);
201207
}
202208

203209
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_4D_SplitM_DynamicParameter) {
204210
const auto &f = MHAFunction(std::vector<PartialShape>{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 128, -1}, {1, 128, 16, 64}},
205211
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), false, false);
206-
common_config = ov::snippets::pass::CommonOptimizations::Config(32, true);
212+
common_config = get_default_common_optimizations_config();
213+
common_config.set_concurrency(32);
207214
execute_and_validate_function(*this, f);
208215
}
209216

210217
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM) {
211218
const auto& f = MHASelectSplitMFunction(std::vector<PartialShape>{{8, 512, 18}, {8, 18, 64}, {1, 512, 64}, {1, 1, 64}, {8, 64, 512}},
212219
std::vector<Shape>{{8, 2, 256, 18}, {8, 1, 18, 64}, {1, 2, 256, 64}, {1, 1, 1, 64},
213220
{8, 1, 64, 512}, {8, 512, 512}});
214-
common_config = ov::snippets::pass::CommonOptimizations::Config(16, true);
221+
common_config = get_default_common_optimizations_config();
222+
common_config.set_concurrency(16);
215223
execute_and_validate_function(*this, f);
216224
}
217225

218226
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM_ScalarParams) {
219227
const auto& f = MHASelectSplitMFunction(std::vector<PartialShape>{{8, 512, 18}, {8, 18, 64}, {1}, {64}, {8, 64, 512}},
220228
std::vector<Shape>{{8, 2, 256, 18}, {8, 1, 18, 64}, {}, {},
221229
{8, 1, 64, 512}, {8, 512, 512}});
222-
common_config = ov::snippets::pass::CommonOptimizations::Config(16, true);
230+
common_config = get_default_common_optimizations_config();
231+
common_config.set_concurrency(16);
223232
execute_and_validate_function(*this, f);
224233
}
225234

src/common/snippets/tests/src/utils.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "snippets/pass/mha_tokenization.hpp"
1111
#include "snippets/pass/mlp_seq_tokenization.hpp"
1212
#include "snippets/pass/tokenization_config.hpp"
13+
#include "snippets/utils/tokenization_utils.hpp"
1314

1415
namespace ov {
1516
namespace test {
@@ -22,7 +23,12 @@ TokenizationConfig get_default_tokenization_config() {
2223
}
2324

2425
CommonOptimizations::Config get_default_common_optimizations_config() {
25-
static const CommonOptimizations::Config conf(1, true);
26+
static CommonOptimizations::Config conf(1, true);
27+
static bool initialized = false;
28+
if (!initialized) {
29+
conf.set_transpose_support_callback(ov::snippets::utils::make_transpose_support_callback(true));
30+
initialized = true;
31+
}
2632
return conf;
2733
}
2834

src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@
233233
# include "openvino/op/subtract.hpp"
234234
# include "snippets/pass/common_optimizations.hpp"
235235
# include "snippets/pass/split_dimension_m.hpp"
236+
# include "snippets/utils/tokenization_utils.hpp"
236237
# include "transformations/common_optimizations/rms_fusion.hpp"
237238
# include "transformations/cpu_opset/common/op/sdpa.hpp"
238239
# include "transformations/cpu_opset/common/pass/causal_mask_preprocess_fusion.hpp"
@@ -266,6 +267,7 @@
266267
# include "low_precision/reduce_min.hpp"
267268
# include "low_precision/reduce_sum.hpp"
268269
# include "openvino/opsets/opset1_decl.hpp"
270+
# include "snippets/utils/tokenization_utils.hpp"
269271
# include "transformations/cpu_opset/arm/pass/convert_group_conv.hpp"
270272
# include "transformations/cpu_opset/arm/pass/convert_group_conv1d.hpp"
271273
# include "transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.hpp"
@@ -1226,6 +1228,17 @@ void Transformations::MainSnippets() {
12261228
// Config::SnippetsMode::IgnoreCallback
12271229
bool split_m_dimension = !ignoreCallback;
12281230
CommonOptimizations::Config common_optimizations_config(concurrency, split_m_dimension);
1231+
#if defined(OPENVINO_ARCH_X86_64)
1232+
common_optimizations_config.set_transpose_support_callback(
1233+
ov::snippets::utils::make_transpose_support_callback(true));
1234+
#elif defined(OPENVINO_ARCH_ARM64)
1235+
common_optimizations_config.set_transpose_support_callback(
1236+
ov::snippets::utils::make_transpose_support_callback(false));
1237+
#else
1238+
common_optimizations_config.set_transpose_support_callback([](const std::shared_ptr<const ov::Node>&) -> bool {
1239+
return false;
1240+
});
1241+
#endif
12291242

12301243
// [111813]: At the moment Snippets supports Transpose on output of MHA pattern only if it is an one node between
12311244
// MatMul and Result. However there may be Convert [f32->bf16] before Result since:

0 commit comments

Comments (0)