Skip to content

Commit b56011e

Browse files
maxnickv-Golubev
and authored
[CPU] Introduce GatherMatmul operation to optimize MoE pattern (#32450)
### Details: In this PR we introduce yet another operation "GatherMatmul", which essentially does gemv operations over the current tokens and the active experts. As the first step, we perform the gemv operation using dnnl::inner_product. But obviously this solution is suboptimal, as it doesn't give fine-grained control over parallelization, and in the case of many tokens being processed by a specific expert (prefill), having a gemm operation may be more optimal, as the tokens may be batched and we can do SIMD-level parallelization by tokens as well. Also this PR contains all the essential transformations that allow enabling a few common MoE patterns. The MoE pattern matcher is based on #32183 Related oneDNN fork PR: openvinotoolkit/oneDNN#292 ### Tickets: - CVS-171910 --------- Co-authored-by: Vladislav Golubev <[email protected]>
1 parent fe33012 commit b56011e

File tree

31 files changed

+3454
-209
lines changed

31 files changed

+3454
-209
lines changed

src/common/transformations/include/transformations/op_conversions/convert_fc_to_compressed.hpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,14 @@
44

55
#pragma once
66

7+
#include <memory>
8+
#include <tuple>
9+
#include <vector>
10+
711
#include "openvino/pass/matcher_pass.hpp"
12+
#include "openvino/pass/pattern/matcher.hpp"
813
#include "ov_ops/fully_connected.hpp"
14+
#include "transformations/pattern_blocks/compressed_weights_block.hpp"
915
#include "transformations_visibility.hpp"
1016

1117
namespace ov {
@@ -27,4 +33,26 @@ class ov::pass::ConvertFullyConnectedToFullyConnectedCompressed : public ov::pas
2733
const std::vector<ov::element::Type>& supported_weights_types,
2834
SupportsPredicate supports_config = nullptr,
2935
bool convert_u4zp_to_u8 = false);
36+
37+
/**
38+
* @brief Processes compressed weights from a pattern block and prepares them for compressed operations.
39+
*
40+
* @param weights_block The CompressedWeightsBlock pattern containing the weight compression graph
41+
* @param pattern_map The pattern value map from the matcher containing matched nodes
42+
* @param convert_u4zp_to_u8 Flag indicating whether to convert u4 zero points to u8
43+
* @param has_transpose Flag indicating whether the weights require transpose operation
44+
* @param grouped Flag indicating whether the compression uses grouped quantization
45+
* @param batched_weights Flag indicating whether the weights have a batch dimension
46+
* @param result_nodes Output vector to collect intermediate nodes created during processing
47+
*
48+
* @return A tuple containing processed compressed weights, decompression scales, and decompression zero points.
49+
*/
50+
static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>>
51+
process_compressed_weights(const std::shared_ptr<ov::pass::pattern::op::CompressedWeightsBlock>& weights_block,
52+
const ov::pass::pattern::PatternValueMap& pattern_map,
53+
bool convert_u4zp_to_u8,
54+
bool has_transpose,
55+
bool grouped,
56+
bool batched_weights,
57+
std::vector<std::shared_ptr<ov::Node>>& result_nodes);
3058
};
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "openvino/pass/matcher_pass.hpp"
8+
#include "openvino/pass/pattern/op/block.hpp"
9+
#include "ov_ops/fully_connected.hpp"
10+
#include "transformations_visibility.hpp"
11+
12+
namespace ov::pass::pattern::op {
13+
14+
class TRANSFORMATIONS_API CompressedWeightsBlock;
15+
16+
} // namespace ov::pass::pattern::op
17+
18+
class ov::pass::pattern::op::CompressedWeightsBlock : public ov::pass::pattern::op::Block {
19+
public:
20+
CompressedWeightsBlock(const std::vector<ov::element::Type>& supported_weights_types,
21+
const std::set<size_t>& supported_weights_ranks);
22+
};

src/common/transformations/src/transformations/op_conversions/convert_fc_to_compressed.cpp

Lines changed: 126 additions & 122 deletions
Large diffs are not rendered by default.
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "transformations/pattern_blocks/compressed_weights_block.hpp"
6+
7+
#include <algorithm>
8+
#include <memory>
9+
10+
#include "openvino/core/graph_util.hpp"
11+
#include "openvino/core/rt_info.hpp"
12+
#include "openvino/core/type/element_type.hpp"
13+
#include "openvino/op/constant.hpp"
14+
#include "openvino/op/convert.hpp"
15+
#include "openvino/op/multiply.hpp"
16+
#include "openvino/op/reshape.hpp"
17+
#include "openvino/op/subtract.hpp"
18+
#include "openvino/op/transpose.hpp"
19+
#include "openvino/pass/pattern/op/optional.hpp"
20+
#include "openvino/pass/pattern/op/or.hpp"
21+
#include "openvino/pass/pattern/op/pattern.hpp"
22+
#include "openvino/pass/pattern/op/wrap_type.hpp"
23+
#include "ov_ops/fully_connected.hpp"
24+
#include "ov_ops/fully_connected_compressed.hpp"
25+
#include "transformations/utils/utils.hpp"
26+
27+
using namespace ov::pass::pattern;
28+
ov::pass::pattern::op::CompressedWeightsBlock::CompressedWeightsBlock(
29+
const std::vector<ov::element::Type>& supported_weights_types,
30+
const std::set<size_t>& supported_weights_ranks)
31+
: Block({}, {}, "CompressedWeightsBlock") {
32+
auto weights = wrap_type<ov::op::v0::Constant>(type_matches_any(supported_weights_types));
33+
auto convert = wrap_type<ov::op::v0::Convert>({weights});
34+
35+
auto sub_const = wrap_type<ov::op::v0::Constant>();
36+
auto sub_convert_const = wrap_type<ov::op::v0::Convert>({sub_const});
37+
auto sub_with_convert = wrap_type<ov::op::v1::Subtract>({convert, sub_convert_const});
38+
auto sub_no_convert = wrap_type<ov::op::v1::Subtract>({convert, sub_const});
39+
auto subtract = sub_with_convert | sub_no_convert;
40+
41+
auto mul_const = wrap_type<ov::op::v0::Constant>();
42+
auto mul_convert_const = wrap_type<ov::op::v0::Convert>({mul_const});
43+
auto mul_scale = mul_const | mul_convert_const;
44+
45+
auto mul_with_sub = wrap_type<ov::op::v1::Multiply>({subtract, mul_scale});
46+
auto mul_no_sub = wrap_type<ov::op::v1::Multiply>({convert, mul_scale});
47+
auto mul = mul_with_sub | mul_no_sub;
48+
49+
auto reshape_predicate = [supported_weights_ranks](const ov::Output<ov::Node>& output) {
50+
const auto& in_ps = output.get_node()->get_input_partial_shape(0);
51+
const auto& out_ps = output.get_node()->get_output_partial_shape(0);
52+
std::set<size_t> supported_weights_ranks_before_reshape;
53+
for (auto r : supported_weights_ranks) {
54+
supported_weights_ranks_before_reshape.insert(r + 1);
55+
}
56+
return in_ps.rank().is_static() && out_ps.rank().is_static() &&
57+
supported_weights_ranks_before_reshape.count(in_ps.size()) &&
58+
supported_weights_ranks.count(out_ps.size());
59+
};
60+
auto reshape_const = wrap_type<ov::op::v0::Constant>();
61+
auto reshape = wrap_type<ov::op::v1::Reshape>({mul, reshape_const}, reshape_predicate);
62+
63+
auto transpose_input = reshape | mul;
64+
auto transpose_const = wrap_type<ov::op::v0::Constant>();
65+
auto transpose = wrap_type<ov::op::v1::Transpose>({transpose_input, transpose_const});
66+
67+
auto weights_input = optional<ov::op::v0::Convert>({reshape | transpose | mul});
68+
69+
// Block initialization
70+
m_inputs = ov::OutputVector{weights};
71+
m_outputs = ov::OutputVector{weights_input};
72+
REGISTER_ANCHORS(this,
73+
weights,
74+
convert,
75+
sub_const,
76+
sub_with_convert,
77+
sub_no_convert,
78+
mul_const,
79+
transpose,
80+
transpose_const);
81+
}

src/core/include/openvino/pass/pattern/op/block.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
namespace {
1010

1111
// _MAKE_ANCHOR is an internal macro for REGISTER_ANCHORS that is not supposed to be used separately.
12-
#define _MAKE_ANCHOR(x) block->register_anchor(#x, x);
12+
#define _MAKE_ANCHOR(block, x) (block)->register_anchor(#x, x);
1313

1414
} // namespace
1515

@@ -23,9 +23,9 @@ namespace ov::pass::pattern::op {
2323
*
2424
*/
2525

26-
#define REGISTER_ANCHORS(block, ...) \
27-
do { \
28-
FOR_EACH(_MAKE_ANCHOR, __VA_ARGS__) \
26+
#define REGISTER_ANCHORS(block, ...) \
27+
do { \
28+
FOR_EACH(_MAKE_ANCHOR, block, __VA_ARGS__) \
2929
} while (0)
3030

3131
/**
@@ -95,10 +95,11 @@ class OPENVINO_API Block : public Pattern {
9595
return m_named_anchors;
9696
}
9797

98-
private:
98+
protected:
9999
OutputVector m_inputs;
100100
OutputVector m_outputs;
101101

102+
private:
102103
std::map<std::string, Output<Node>> m_named_anchors;
103104
};
104105

src/core/include/openvino/pass/pattern/op/block_util.hpp

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,49 @@
33
namespace {
44

55
// FOR_EACH macros up to 16 arguments:
6-
#define FOR_EACH_1(M, x1) M(x1)
7-
#define FOR_EACH_2(M, x1, x2) M(x1) M(x2)
8-
#define FOR_EACH_3(M, x1, x2, x3) M(x1) M(x2) M(x3)
9-
#define FOR_EACH_4(M, x1, x2, x3, x4) M(x1) M(x2) M(x3) M(x4)
10-
#define FOR_EACH_5(M, x1, x2, x3, x4, x5) M(x1) M(x2) M(x3) M(x4) M(x5)
11-
#define FOR_EACH_6(M, x1, x2, x3, x4, x5, x6) M(x1) M(x2) M(x3) M(x4) M(x5) M(x6)
12-
#define FOR_EACH_7(M, x1, x2, x3, x4, x5, x6, x7) M(x1) M(x2) M(x3) M(x4) M(x5) M(x6) M(x7)
13-
#define FOR_EACH_8(M, x1, x2, x3, x4, x5, x6, x7, x8) M(x1) M(x2) M(x3) M(x4) M(x5) M(x6) M(x7) M(x8)
14-
#define FOR_EACH_9(M, x1, x2, x3, x4, x5, x6, x7, x8, x9) M(x1) M(x2) M(x3) M(x4) M(x5) M(x6) M(x7) M(x8) M(x9)
15-
#define FOR_EACH_10(M, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) \
16-
M(x1) M(x2) M(x3) M(x4) M(x5) M(x6) M(x7) M(x8) M(x9) M(x10)
17-
#define FOR_EACH_11(M, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11) \
18-
M(x1) M(x2) M(x3) M(x4) M(x5) M(x6) M(x7) M(x8) M(x9) M(x10) M(x11)
19-
#define FOR_EACH_12(M, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12) \
20-
M(x1) M(x2) M(x3) M(x4) M(x5) M(x6) M(x7) M(x8) M(x9) M(x10) M(x11) M(x12)
21-
#define FOR_EACH_13(M, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13) \
22-
M(x1) M(x2) M(x3) M(x4) M(x5) M(x6) M(x7) M(x8) M(x9) M(x10) M(x11) M(x12) M(x13)
23-
#define FOR_EACH_14(M, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14) \
24-
M(x1) M(x2) M(x3) M(x4) M(x5) M(x6) M(x7) M(x8) M(x9) M(x10) M(x11) M(x12) M(x13) M(x14)
25-
#define FOR_EACH_15(M, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15) \
26-
M(x1) M(x2) M(x3) M(x4) M(x5) M(x6) M(x7) M(x8) M(x9) M(x10) M(x11) M(x12) M(x13) M(x14) M(x15)
27-
#define FOR_EACH_16(M, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15, x16) \
28-
M(x1) M(x2) M(x3) M(x4) M(x5) M(x6) M(x7) M(x8) M(x9) M(x10) M(x11) M(x12) M(x13) M(x14) M(x15) M(x16)
6+
#define FOR_EACH_1(M, B, x1) M(B, x1)
7+
#define FOR_EACH_2(M, B, x1, x2) M(B, x1) M(B, x2)
8+
#define FOR_EACH_3(M, B, x1, x2, x3) M(B, x1) M(B, x2) M(B, x3)
9+
#define FOR_EACH_4(M, B, x1, x2, x3, x4) M(B, x1) M(B, x2) M(B, x3) M(B, x4)
10+
#define FOR_EACH_5(M, B, x1, x2, x3, x4, x5) M(B, x1) M(B, x2) M(B, x3) M(B, x4) M(B, x5)
11+
#define FOR_EACH_6(M, B, x1, x2, x3, x4, x5, x6) M(B, x1) M(B, x2) M(B, x3) M(B, x4) M(B, x5) M(B, x6)
12+
#define FOR_EACH_7(M, B, x1, x2, x3, x4, x5, x6, x7) M(B, x1) M(B, x2) M(B, x3) M(B, x4) M(B, x5) M(B, x6) M(B, x7)
13+
#define FOR_EACH_8(M, B, x1, x2, x3, x4, x5, x6, x7, x8) \
14+
M(B, x1) M(B, x2) M(B, x3) M(B, x4) M(B, x5) M(B, x6) M(B, x7) M(B, x8)
15+
#define FOR_EACH_9(M, B, x1, x2, x3, x4, x5, x6, x7, x8, x9) \
16+
M(B, x1) M(B, x2) M(B, x3) M(B, x4) M(B, x5) M(B, x6) M(B, x7) M(B, x8) M(B, x9)
17+
#define FOR_EACH_10(M, B, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10) \
18+
M(B, x1) M(B, x2) M(B, x3) M(B, x4) M(B, x5) M(B, x6) M(B, x7) M(B, x8) M(B, x9) M(B, x10)
19+
#define FOR_EACH_11(M, B, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11) \
20+
M(B, x1) M(B, x2) M(B, x3) M(B, x4) M(B, x5) M(B, x6) M(B, x7) M(B, x8) M(B, x9) M(B, x10) M(B, x11)
21+
#define FOR_EACH_12(M, B, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12) \
22+
M(B, x1) M(B, x2) M(B, x3) M(B, x4) M(B, x5) M(B, x6) M(B, x7) M(B, x8) M(B, x9) M(B, x10) M(B, x11) M(B, x12)
23+
#define FOR_EACH_13(M, B, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13) \
24+
M(B, x1) \
25+
M(B, x2) M(B, x3) M(B, x4) M(B, x5) M(B, x6) M(B, x7) M(B, x8) M(B, x9) M(B, x10) M(B, x11) M(B, x12) M(B, x13)
26+
#define FOR_EACH_14(M, B, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14) \
27+
M(B, x1) \
28+
M(B, x2) \
29+
M(B, x3) M(B, x4) M(B, x5) M(B, x6) M(B, x7) M(B, x8) M(B, x9) M(B, x10) M(B, x11) M(B, x12) M(B, x13) M(B, x14)
30+
#define FOR_EACH_15(M, B, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15) \
31+
M(B, x1) \
32+
M(B, x2) \
33+
M(B, x3) \
34+
M(B, x4) M(B, x5) M(B, x6) M(B, x7) M(B, x8) M(B, x9) M(B, x10) M(B, x11) M(B, x12) M(B, x13) M(B, x14) M(B, x15)
35+
#define FOR_EACH_16(M, B, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15, x16) \
36+
M(B, x1) \
37+
M(B, x2) \
38+
M(B, x3) \
39+
M(B, x4) \
40+
M(B, x5) M(B, x6) M(B, x7) M(B, x8) M(B, x9) M(B, x10) M(B, x11) M(B, x12) M(B, x13) M(B, x14) M(B, x15) M(B, x16)
2941

30-
#define GET_MACRO(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, NAME, ...) NAME
42+
#define GET_MACRO(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, NAME, ...) NAME
3143

3244
#define EXPAND(x) x
3345

34-
#define FOR_EACH(M, ...) \
35-
EXPAND(GET_MACRO(__VA_ARGS__, \
46+
#define FOR_EACH(M, B, ...) \
47+
EXPAND(GET_MACRO(_0, \
48+
__VA_ARGS__, \
3649
FOR_EACH_16, \
3750
FOR_EACH_15, \
3851
FOR_EACH_14, \
@@ -48,6 +61,6 @@ namespace {
4861
FOR_EACH_4, \
4962
FOR_EACH_3, \
5063
FOR_EACH_2, \
51-
FOR_EACH_1)(M, __VA_ARGS__))
64+
FOR_EACH_1)(M, B, __VA_ARGS__))
5265

5366
} // namespace

src/plugins/intel_cpu/src/cpu_types.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,9 @@ static const TypeToNameMap& get_type_to_name_tbl() {
265265
{"QKVProjection", Type::QKVProjection},
266266
{"RMS", Type::RMS},
267267
{"SearchSorted", Type::SearchSorted},
268-
{"LoraSubgraph", Type::LoRA}};
268+
{"LoraSubgraph", Type::LoRA},
269+
{"BatchGatherMatmul", Type::GatherMatmul},
270+
{"BatchGatherMatmulCompressed", Type::GatherMatmul}};
269271
return type_to_name_tbl;
270272
}
271273

@@ -400,6 +402,7 @@ std::string NameFromType(const Type type) {
400402
CASE(SearchSorted);
401403
CASE(SegmentMax);
402404
CASE(LoRA);
405+
CASE(GatherMatmul);
403406
CASE(Unknown);
404407
}
405408
#undef CASE

src/plugins/intel_cpu/src/cpu_types.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ enum class Type : uint8_t {
137137
RMS,
138138
SearchSorted,
139139
SegmentMax,
140-
LoRA
140+
LoRA,
141+
GatherMatmul
141142
};
142143

143144
enum class Algorithm : uint8_t {

src/plugins/intel_cpu/src/graph.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2009,6 +2009,7 @@ void Graph::EnforceInferencePrecision() {
20092009
Type::Interpolate, // super resolution nets
20102010
Type::PagedAttention, // page attention
20112011
Type::QKVProjection,
2012+
Type::GatherMatmul,
20122013
Type::LLMMLP)) {
20132014
continue; // stop at significant nodes
20142015
}

src/plugins/intel_cpu/src/nodes/fullyconnected.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,6 @@ class FullyConnected : public Node {
9191
this->attrs.weightsNonTransposed = weightsNonTransposed;
9292
}
9393

94-
void fuseDecompressionMultiply(const MemoryCPtr& memory);
95-
void fuseDecompressionSubtract(const MemoryCPtr& memory);
96-
9794
protected:
9895
void toNumaNodeImpl(int numaID) override;
9996

0 commit comments

Comments
 (0)