Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

#pragma once

#include <functional>
#include <memory>
#include <utility>

#include "openvino/pass/matcher_pass.hpp"

namespace ov::snippets::pass {
Expand All @@ -24,12 +28,19 @@ class CommonOptimizations : public ov::pass::MatcherPass {
* @ingroup snippets
*/
struct Config {
    using TransposeSupportCallback = std::function<bool(const std::shared_ptr<const ov::Node>&)>;

    /// @param concurrency       number of available threads; must be positive
    /// @param split_m_dimension enables the "SplitDimensionM" optimization
    Config(size_t concurrency, bool split_m_dimension)
        : m_concurrency(concurrency),
          m_split_m_dimension(split_m_dimension) {
        OPENVINO_ASSERT(concurrency > 0, "Concurrency should be greater than 0");
    }

    void set_concurrency(size_t concurrency) {
        OPENVINO_ASSERT(concurrency > 0, "Concurrency should be greater than 0");
        m_concurrency = concurrency;
    }

    [[nodiscard]] size_t get_concurrency() const {
        return m_concurrency;
    }

    [[nodiscard]] bool get_split_m_dimension() const {
        return m_split_m_dimension;
    }

    void set_transpose_support_callback(TransposeSupportCallback cb) {
        m_transpose_support_cb = std::move(cb);
    }

    [[nodiscard]] const TransposeSupportCallback& get_transpose_support_callback() const {
        return m_transpose_support_cb;
    }

private:
    size_t m_concurrency = 0;
    // True if "SplitDimensionM" optimization is enabled.
    bool m_split_m_dimension = true;
    // Predicate deciding whether a given Transpose may stay inside a Subgraph.
    // The default treats every Transpose as unsupported.
    TransposeSupportCallback m_transpose_support_cb = [](const std::shared_ptr<const ov::Node>&) {
        return false;
    };
};

explicit CommonOptimizations(const Config& config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@ namespace ov::snippets::pass {
// Subgraph pass that extracts Transposes the Subgraph cannot execute:
// each input-side Transpose rejected by the support callback is moved out of
// the Subgraph body and reconnected in the outer model (see run_on_subgraph).
class ExtractUnsupportedTransposes : public CommonOptimizations::SubgraphPass {
public:
OPENVINO_RTTI("ExtractUnsupportedTransposes", "0");
ExtractUnsupportedTransposes() : SubgraphPass("ExtractUnsupportedTransposes") {}
// Takes the predicate that decides whether a Transpose is supported inside the Subgraph.
// The callback is moved in; it must be non-empty when run_on_subgraph is invoked.
explicit ExtractUnsupportedTransposes(CommonOptimizations::Config::TransposeSupportCallback transpose_support_cb)
: SubgraphPass("ExtractUnsupportedTransposes"),
m_transpose_support_cb(std::move(transpose_support_cb)) {}

bool run_on_subgraph(const std::shared_ptr<op::Subgraph>& subgraph) override;

private:
// Support predicate consulted per Transpose; supported ones stay in the Subgraph.
CommonOptimizations::Config::TransposeSupportCallback m_transpose_support_cb;
};

} // namespace ov::snippets::pass
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class SnippetsTokenization : public ov::pass::ModelPass {
TokenizeMHASnippets::Config mha_config,
TokenizeMLPSeqSnippets::Config mlp_seq_config)
: m_tokenization_config(config),
m_common_optimizations_config(common_config),
m_common_optimizations_config(std::move(common_config)),
m_mha_config(std::move(mha_config)),
m_mlp_seq_config(std::move(mlp_seq_config)) {}
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
#pragma once

#include <functional>
#include <memory>

#include "openvino/core/node.hpp"
Expand Down Expand Up @@ -50,4 +51,16 @@ std::shared_ptr<ov::snippets::op::Subgraph> tokenize_ordered_nodes(const ov::Nod
* @return The estimated number of body parameters needed for this operation
*/
size_t get_potential_body_params(const std::shared_ptr<ov::Node>& op);

/**
* @brief Builds a transpose support callback suitable for CommonOptimizations configuration.
* The callback returns true for Transpose nodes that are considered supported by Snippets.
* If `include_brgemm_case` is true, the callback additionally allows the specific
* MHA fusion-related transpose order when the Transpose feeds MatMul (Brgemm case).
* Independently of the flag, the decomposed transpose order accepted by MHA tokenization is allowed.
*
* @param include_brgemm_case if true, apply extra MatMul(Brgemm)-related order check
* @return std::function predicate that can be passed to set_transpose_support_callback
*/
std::function<bool(const std::shared_ptr<const ov::Node>&)> make_transpose_support_callback(bool include_brgemm_case);
} // namespace ov::snippets::utils
5 changes: 4 additions & 1 deletion src/common/snippets/src/pass/common_optimizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ CommonOptimizations::CommonOptimizations(const CommonOptimizations::Config& conf
// At the moment only non-scalar Constants of FakeQuantize can be inside Subgraph
// so we can enable ExtractConstants pass for quantized models
REGISTER_SNIPPETS_PASS(subgraph_manager, ov::snippets::pass::ExtractConstants, is_quantized);
REGISTER_SNIPPETS_PASS(subgraph_manager, ov::snippets::pass::ExtractUnsupportedTransposes, is_domain_sensitive);
REGISTER_SNIPPETS_PASS(subgraph_manager,
ov::snippets::pass::ExtractUnsupportedTransposes,
is_domain_sensitive,
config.get_transpose_support_callback());
REGISTER_SNIPPETS_PASS(subgraph_manager,
ov::snippets::pass::SplitDimensionM,
is_domain_sensitive && config.get_split_m_dimension(),
Expand Down
19 changes: 7 additions & 12 deletions src/common/snippets/src/pass/extract_unsupported_transposes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "openvino/opsets/opset1.hpp"
#include "snippets/itt.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/pass/mha_tokenization.hpp"

bool ov::snippets::pass::ExtractUnsupportedTransposes::run_on_subgraph(const std::shared_ptr<op::Subgraph>& subgraph) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ExtractUnsupportedTransposes");
Expand All @@ -37,23 +36,19 @@ bool ov::snippets::pass::ExtractUnsupportedTransposes::run_on_subgraph(const std
continue;
}

const auto& order = ov::as_type_ptr<opset1::Constant>(transpose->get_input_node_shared_ptr(1));
OPENVINO_ASSERT(order, "ExtractUnsupportedTransposes expects Transposes with constant order");

const auto order_value = order->cast_vector<int>();
const auto transpose_child = *(transpose->get_output_target_inputs(0).begin());
const auto is_brgemm_case = ov::is_type<opset1::MatMul>(transpose_child.get_node()->shared_from_this());
// If Transpose is supported (can be decomposed or fused into Brgemm), skip
// [116568]: It should be covered by TransposeDecomposition::is_supported or FuseTransposeBrgemm::is_supported
if (order_value.size() > 2 &&
((is_brgemm_case && TokenizeMHASnippets::get_fusion_transpose_order(order_value.size()) == order_value) ||
(TokenizeMHASnippets::get_decomposed_transpose_order(order_value.size()) == order_value))) {
OPENVINO_ASSERT(m_transpose_support_cb,
"Transpose support callback is not set in ExtractUnsupportedTransposes pass");
bool is_supported = m_transpose_support_cb(transpose);
if (is_supported) {
continue;
}

// If the transpose isn't supported - we have to extract it from Subgraph
transpose->set_argument(0, subgraph->input_value(i));
subgraph->set_argument(i, transpose);
OPENVINO_ASSERT(!transpose->get_output_target_inputs(0).empty(),
"ExtractUnsupportedTransposes pass supports only Transpose nodes with at least one consumer");
const auto transpose_child = *(transpose->get_output_target_inputs(0).begin());
transpose_child.replace_source_output(parameter);
parameter->set_partial_shape(transpose->get_output_partial_shape(0));
updated = true;
Expand Down
40 changes: 40 additions & 0 deletions src/common/snippets/src/utils/tokenization_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <climits>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
Expand All @@ -32,11 +33,14 @@
#include "openvino/core/type.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/result.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/attr_types.hpp"
#include "openvino/opsets/opset1.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/tokenization_config.hpp"
#include "snippets/remarks.hpp"
Expand Down Expand Up @@ -81,6 +85,42 @@ auto outputs_are_not_broadcastable(const std::shared_ptr<const Node>& node) -> b
}
} // namespace

/**
 * @brief Builds the Transpose support predicate used by CommonOptimizations.
 * The returned callback accepts a node and reports whether it is a Transpose
 * with a constant, rank > 2 order that Snippets can keep inside a Subgraph.
 * @param include_brgemm_case additionally allow the MHA fusion order when the
 *        Transpose feeds a MatMul (Brgemm case)
 * @return predicate suitable for Config::set_transpose_support_callback
 */
std::function<bool(const std::shared_ptr<const ov::Node>&)> make_transpose_support_callback(bool include_brgemm_case) {
    using ov::op::v0::Constant;
    using ov::op::v0::MatMul;
    using ov::op::v1::Transpose;

    return [include_brgemm_case](const std::shared_ptr<const ov::Node>& node) -> bool {
        // `node` is already a shared_ptr to the object, so downcast it directly:
        // the previous `node->shared_from_this()` round-trip was redundant.
        const auto transpose = ov::as_type_ptr<const Transpose>(node);
        if (!transpose) {
            return false;
        }
        // The order must be a Constant to be analyzable.
        const auto order = ov::as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
        if (!order) {
            return false;
        }
        const auto order_value = order->cast_vector<int>();
        // 1D/2D orders are never treated as supported.
        if (order_value.size() <= 2) {
            return false;
        }

        if (include_brgemm_case) {
            // Brgemm case: a Transpose feeding MatMul with the MHA fusion order is supported.
            const auto& consumers = transpose->get_output_target_inputs(0);
            if (!consumers.empty()) {
                const auto child = consumers.begin()->get_node()->shared_from_this();
                if (ov::is_type<MatMul>(child) &&
                    ov::snippets::pass::TokenizeMHASnippets::get_fusion_transpose_order(order_value.size()) ==
                        order_value) {
                    return true;
                }
            }
        }
        // Independently of the Brgemm case, the decomposed order accepted by MHA tokenization is allowed.
        return ov::snippets::pass::TokenizeMHASnippets::get_decomposed_transpose_order(order_value.size()) ==
               order_value;
    };
}

bool tokenize_node(const std::shared_ptr<ov::Node>& node, const TokenizationConfig& config) {
const auto getFusedNames = [](const std::shared_ptr<Node>& n) -> std::string {
auto rt_info = n->get_rt_info();
Expand Down
27 changes: 18 additions & 9 deletions src/common/snippets/tests/src/pass/mha_tokenization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM) {
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{2, 64, 12, 64}, {12, 1, 64, 128}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}},
false);
common_config = ov::snippets::pass::CommonOptimizations::Config(24, true);
common_config = get_default_common_optimizations_config();
common_config.set_concurrency(24);
execute_and_validate_function(*this, f);
}

Expand All @@ -162,7 +163,8 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) {
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{4, 32, 12, 64}, {12, 1, 64, 128}, {12, 4, 32, 128}, {1, 128, 12, 64}, {128, 12, 64}},
true);
common_config = ov::snippets::pass::CommonOptimizations::Config(16, true);
common_config = get_default_common_optimizations_config();
common_config.set_concurrency(16);
execute_and_validate_function(*this, f);
}

Expand All @@ -171,7 +173,8 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) {
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{1, 12, 32, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
false);
common_config = ov::snippets::pass::CommonOptimizations::Config(60, true);
common_config = get_default_common_optimizations_config();
common_config.set_concurrency(60);
execute_and_validate_function(*this, f);
}

Expand All @@ -180,46 +183,52 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) {
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{1, 12, 32, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
true);
common_config = ov::snippets::pass::CommonOptimizations::Config(60, true);
common_config = get_default_common_optimizations_config();
common_config.set_concurrency(60);
execute_and_validate_function(*this, f);
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHAWOTranspose_SplitM) {
    // With 18 available threads, SplitDimensionM is expected to split M=9216 into 18 x 512.
    common_config = get_default_common_optimizations_config();
    common_config.set_concurrency(18);

    const std::vector<PartialShape> input_shapes{{10, 9216, 128}, {10, 128, 9216}, {10, 9216, 128}};
    const std::vector<ov::element::Type> precisions{ov::element::f32, ov::element::f32, ov::element::f32};
    const std::vector<Shape> reference_shapes{{10, 18, 512, 128}, {10, 1, 128, 9216}, {10, 1, 9216, 128}, {10, 9216, 128}};
    const auto& f = MHAWOTransposeSplitMFunction(input_shapes, precisions, reference_shapes);
    execute_and_validate_function(*this, f);
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_SplitM_AlmostAllThreads) {
    // 32 threads against M=30: the split (10 x 3) uses almost all available threads.
    common_config = get_default_common_optimizations_config();
    common_config.set_concurrency(32);

    const std::vector<PartialShape> input_shapes{{5, 30, 32}, {5, 32, 30}, {5, 30, 32}};
    const std::vector<ov::element::Type> precisions{ov::element::f32, ov::element::f32, ov::element::f32};
    const std::vector<Shape> reference_shapes{{5, 10, 3, 32}, {5, 1, 32, 30}, {5, 1, 30, 32}, {5, 30, 32}};
    const auto& f = MHAWOTransposeSplitMFunction(input_shapes, precisions, reference_shapes);
    execute_and_validate_function(*this, f);
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_4D_SplitM_DynamicParameter) {
    // The third MHA input has a dynamic last dimension (-1); tokenization must still apply.
    common_config = get_default_common_optimizations_config();
    common_config.set_concurrency(32);

    const std::vector<PartialShape> input_shapes{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 128, -1}, {1, 128, 16, 64}};
    const std::vector<ov::element::Type> precisions{ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32};
    const auto& f = MHAFunction(input_shapes, precisions, false, false);
    execute_and_validate_function(*this, f);
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM) {
    // MHA pattern containing Select; 16 threads drive the expected M split (2 x 256).
    common_config = get_default_common_optimizations_config();
    common_config.set_concurrency(16);

    const std::vector<PartialShape> input_shapes{{8, 512, 18}, {8, 18, 64}, {1, 512, 64}, {1, 1, 64}, {8, 64, 512}};
    const std::vector<Shape> reference_shapes{{8, 2, 256, 18}, {8, 1, 18, 64}, {1, 2, 256, 64}, {1, 1, 1, 64},
                                              {8, 1, 64, 512}, {8, 512, 512}};
    const auto& f = MHASelectSplitMFunction(input_shapes, reference_shapes);
    execute_and_validate_function(*this, f);
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM_ScalarParams) {
    // Same Select pattern but with scalar/1D parameters ({1} and {64}); their
    // reference shapes are empty since scalars are not reshaped by the split.
    common_config = get_default_common_optimizations_config();
    common_config.set_concurrency(16);

    const std::vector<PartialShape> input_shapes{{8, 512, 18}, {8, 18, 64}, {1}, {64}, {8, 64, 512}};
    const std::vector<Shape> reference_shapes{{8, 2, 256, 18}, {8, 1, 18, 64}, {}, {},
                                              {8, 1, 64, 512}, {8, 512, 512}};
    const auto& f = MHASelectSplitMFunction(input_shapes, reference_shapes);
    execute_and_validate_function(*this, f);
}

Expand Down
11 changes: 10 additions & 1 deletion src/common/snippets/tests/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@

#include <limits>

#include "openvino/op/constant.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/transpose.hpp"
#include "snippets/pass/common_optimizations.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/mlp_seq_tokenization.hpp"
#include "snippets/pass/tokenization_config.hpp"
#include "snippets/utils/tokenization_utils.hpp"

namespace ov {
namespace test {
Expand All @@ -22,7 +26,12 @@ TokenizationConfig get_default_tokenization_config() {
}

// Returns the default CommonOptimizations config used by the tests:
// concurrency = 1, SplitDimensionM enabled, and the Brgemm-aware transpose
// support callback installed.
CommonOptimizations::Config get_default_common_optimizations_config() {
    // Build the fully-initialized config exactly once via thread-safe static
    // initialization ("magic statics"). The previous two-phase scheme
    // (mutable static + separate `initialized` flag) was racy on concurrent
    // first calls and left the object non-const.
    static const CommonOptimizations::Config conf = [] {
        CommonOptimizations::Config c(1, true);
        c.set_transpose_support_callback(ov::snippets::utils::make_transpose_support_callback(true));
        return c;
    }();
    return conf;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@
# include "openvino/op/subtract.hpp"
# include "snippets/pass/common_optimizations.hpp"
# include "snippets/pass/split_dimension_m.hpp"
# include "snippets/utils/tokenization_utils.hpp"
# include "transformations/common_optimizations/rms_fusion.hpp"
# include "transformations/cpu_opset/common/op/sdpa.hpp"
# include "transformations/cpu_opset/common/pass/causal_mask_preprocess_fusion.hpp"
Expand Down Expand Up @@ -266,6 +267,7 @@
# include "low_precision/reduce_min.hpp"
# include "low_precision/reduce_sum.hpp"
# include "openvino/opsets/opset1_decl.hpp"
# include "snippets/utils/tokenization_utils.hpp"
# include "transformations/cpu_opset/arm/pass/convert_group_conv.hpp"
# include "transformations/cpu_opset/arm/pass/convert_group_conv1d.hpp"
# include "transformations/cpu_opset/arm/pass/convert_reduce_multi_axis.hpp"
Expand Down Expand Up @@ -1226,6 +1228,17 @@ void Transformations::MainSnippets() {
// Config::SnippetsMode::IgnoreCallback
bool split_m_dimension = !ignoreCallback;
CommonOptimizations::Config common_optimizations_config(concurrency, split_m_dimension);
#if defined(OPENVINO_ARCH_X86_64)
common_optimizations_config.set_transpose_support_callback(
ov::snippets::utils::make_transpose_support_callback(true));
#elif defined(OPENVINO_ARCH_ARM64)
common_optimizations_config.set_transpose_support_callback(
ov::snippets::utils::make_transpose_support_callback(false));
#else
common_optimizations_config.set_transpose_support_callback([](const std::shared_ptr<const ov::Node>&) -> bool {
return false;
});
#endif

// [111813]: At the moment Snippets supports Transpose on output of MHA pattern only if it is an one node between
// MatMul and Result. However there may be Convert [f32->bf16] before Result since:
Expand Down
Loading
Loading