Skip to content

Commit e208a93

Browse files
0xfedcafe authored and MirceaDan99 committed
[Snippets] [ARM]: Fixed bug in PReLU emitter, enabled PReLU, Sqrt, Round tokenization (openvinotoolkit#28223)
### Details:
- Fixed a bug in PReLU emitter
- Enabled PReLU, Sqrt, Round tokenization
- All local tests pass

### Tickets:
- openvinotoolkit#28161
1 parent 291c9e9 commit e208a93

File tree

4 files changed

+80
-26
lines changed

4 files changed

+80
-26
lines changed

src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "common/utils.hpp"
1010
#include "emitters/utils.hpp"
11+
#include "openvino/core/type/element_type.hpp"
1112
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
1213

1314
namespace ov {
@@ -2128,7 +2129,7 @@ size_t jit_prelu_emitter::get_aux_vecs_count() const {
21282129

21292130
std::set<std::vector<element::Type>> jit_prelu_emitter::get_supported_precisions(
21302131
const std::shared_ptr<ov::Node>& node) {
2131-
return {{element::f32}};
2132+
return {{element::f32, element::f32}};
21322133
}
21332134

21342135
void jit_prelu_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,

src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
#include "emitters/snippets/cpu_runtime_configurator.hpp"
1414
#include "emitters/utils.hpp"
1515
#include "jit_snippets_emitters.hpp"
16+
#include "openvino/core/type.hpp"
17+
#include "openvino/op/prelu.hpp"
18+
#include "openvino/op/round.hpp"
19+
#include "openvino/op/sqrt.hpp"
1620
#include "openvino/opsets/opset13.hpp"
21+
#include "snippets/emitter.hpp"
22+
#include "snippets/lowered/expression.hpp"
1723
#include "snippets/snippets_isa.hpp"
1824
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
1925
#include "transformations/snippets/common/op/fused_mul_add.hpp"
@@ -44,7 +50,7 @@ namespace ov {
4450
{ \
4551
[this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
4652
const auto& n = expr->get_node(); \
47-
const auto& gelu = std::dynamic_pointer_cast<ov::op::v7::Gelu>(n); \
53+
const auto& gelu = ov::as_type_ptr<ov::op::v7::Gelu>(n); \
4854
if (gelu == nullptr) { \
4955
OPENVINO_THROW("Can't cast to ov::op::v7::Gelu"); \
5056
} \
@@ -73,6 +79,37 @@ namespace ov {
7379
} \
7480
}
7581

82+
#define CREATE_ROUND_V5_EMITTER(e_type_from_zero, e_type_even) \
83+
{ \
84+
[this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
85+
const auto& n = expr->get_node(); \
86+
const auto& round = ov::as_type_ptr<ov::op::v5::Round>(n); \
87+
if (round == nullptr) { \
88+
OPENVINO_THROW("Can't cast to ov::op::v5::Round"); \
89+
} \
90+
const auto roundingMode = round->get_mode(); \
91+
if (roundingMode == ov::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO) { \
92+
return std::make_shared<e_type_from_zero>(h.get(), isa, n); \
93+
} else if (roundingMode == ov::op::v5::Round::RoundMode::HALF_TO_EVEN) { \
94+
return std::make_shared<e_type_even>(h.get(), isa, n); \
95+
} else { \
96+
OPENVINO_THROW("Unsupported Round mode"); \
97+
} \
98+
}, \
99+
[](const std::shared_ptr<ov::Node>& n) -> std::set<std::vector<element::Type>> { \
100+
const auto& round = std::dynamic_pointer_cast<ov::op::v5::Round>(n); \
101+
if (round == nullptr) { \
102+
OPENVINO_THROW("Can't cast to ov::op::v5::Round"); \
103+
} \
104+
if (round->get_mode() == ov::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO) { \
105+
return e_type_from_zero::get_supported_precisions(n); \
106+
} else if (round->get_mode() == ov::op::v5::Round::RoundMode::HALF_TO_EVEN) { \
107+
return e_type_even::get_supported_precisions(n); \
108+
} \
109+
OPENVINO_THROW("Unsupported Round mode"); \
110+
} \
111+
}
112+
76113
class jit_snippet : public dnnl::impl::cpu::aarch64::jit_generator {
77114
public:
78115
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_snippet)
@@ -155,8 +192,12 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa)
155192
CREATE_GELU_V7_EMITTER(jit_gelu_erf_emitter, jit_gelu_tanh_emitter);
156193
jitters[ov::op::v4::HSwish::get_type_info_static()] = CREATE_CPU_EMITTER(jit_hswish_emitter);
157194
jitters[ov::op::v4::Mish::get_type_info_static()] = CREATE_CPU_EMITTER(jit_mish_emitter);
195+
jitters[ov::op::v0::PRelu::get_type_info_static()] = CREATE_CPU_EMITTER(jit_prelu_emitter);
158196
jitters[ov::op::v0::Relu::get_type_info_static()] = CREATE_CPU_EMITTER(jit_relu_emitter);
197+
jitters[ov::op::v5::Round::get_type_info_static()] =
198+
CREATE_ROUND_V5_EMITTER(jit_round_half_away_from_zero_emitter, jit_round_half_to_even_emitter);
159199
jitters[ov::op::v0::Sigmoid::get_type_info_static()] = CREATE_CPU_EMITTER(jit_sigmoid_emitter);
200+
jitters[ov::op::v0::Sqrt::get_type_info_static()] = CREATE_CPU_EMITTER(jit_sqrt_emitter);
160201
jitters[ov::intel_cpu::SwishNode::get_type_info_static()] = CREATE_CPU_EMITTER(jit_swish_emitter);
161202
jitters[ov::op::v0::Tanh::get_type_info_static()] = CREATE_CPU_EMITTER(jit_tanh_emitter);
162203

src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
#include <ov_ops/gather_compressed.hpp>
1313

1414
#include "openvino/op/paged_attention.hpp"
15+
#include "openvino/op/prelu.hpp"
16+
#include "openvino/op/round.hpp"
17+
#include "openvino/op/sqrt.hpp"
1518
#include "openvino/opsets/opset1.hpp"
1619
#include "openvino/opsets/opset10.hpp"
1720
#include "openvino/opsets/opset2.hpp"
@@ -1123,16 +1126,17 @@ void Transformations::MainSnippets(void) {
11231126
ov::is_type<ov::op::v0::Clamp>(n) || ov::is_type<ov::op::v0::Ceiling>(n) ||
11241127
ov::is_type<ov::op::v0::Convert>(n) || ov::is_type<ov::op::v1::Divide>(n) ||
11251128
ov::is_type<ov::op::v0::Elu>(n) || ov::is_type<ov::op::v0::Exp>(n) ||
1126-
ov::is_type<ov::op::v0::Floor>(n) || ov::is_type<ov::op::v1::FloorMod>(n) ||
1127-
ov::is_type<ov::op::v0::Gelu>(n) || ov::is_type<ov::op::v7::Gelu>(n) ||
1128-
ov::is_type<ov::op::v4::HSwish>(n) || ov::is_type<ov::op::v1::Maximum>(n) ||
1129+
ov::is_type<ov::op::v1::Equal>(n) || ov::is_type<ov::op::v0::Floor>(n) ||
1130+
ov::is_type<ov::op::v1::FloorMod>(n) || ov::is_type<ov::op::v0::Gelu>(n) ||
1131+
ov::is_type<ov::op::v7::Gelu>(n) || ov::is_type<ov::op::v1::Greater>(n) ||
1132+
ov::is_type<ov::op::v1::GreaterEqual>(n) || ov::is_type<ov::op::v4::HSwish>(n) ||
1133+
ov::is_type<ov::op::v1::LessEqual>(n) || ov::is_type<ov::op::v1::Maximum>(n) ||
11291134
ov::is_type<ov::op::v1::Minimum>(n) || ov::is_type<ov::op::v4::Mish>(n) ||
11301135
ov::is_type<ov::op::v1::Mod>(n) || ov::is_type<ov::op::v1::Multiply>(n) ||
1131-
ov::is_type<ov::op::v0::Relu>(n) || ov::is_type<ov::op::v0::Sigmoid>(n) ||
1132-
ov::is_type<ov::op::v1::Subtract>(n) || ov::is_type<ov::op::v4::Swish>(n) ||
1133-
ov::is_type<ov::op::v1::Equal>(n) || ov::is_type<ov::op::v1::Greater>(n) ||
1134-
ov::is_type<ov::op::v1::GreaterEqual>(n) || ov::is_type<ov::op::v1::LessEqual>(n) ||
1135-
ov::is_type<ov::op::v0::Tanh>(n));
1136+
ov::is_type<ov::op::v0::PRelu>(n) || ov::is_type<ov::op::v0::Relu>(n) ||
1137+
ov::is_type<ov::op::v5::Round>(n) || ov::is_type<ov::op::v0::Sigmoid>(n) ||
1138+
ov::is_type<ov::op::v0::Sqrt>(n) || ov::is_type<ov::op::v1::Subtract>(n) ||
1139+
ov::is_type<ov::op::v4::Swish>(n) || ov::is_type<ov::op::v0::Tanh>(n));
11361140
#else
11371141
// CPU Plugin support Swish in Subgraph via conversion to SwichCPU which assumes second input to be constant,
11381142
// and CPU Plugin does not support Mish for x64

src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/classes/activation.cpp

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,9 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
197197
(activation_type == utils::ActivationTypes::Sqrt) ||
198198
(activation_type == utils::ActivationTypes::Swish) ||
199199
(activation_type == utils::ActivationTypes::LogicalNot) ||
200-
(activation_type == utils::ActivationTypes::Tanh))) {
200+
(activation_type == utils::ActivationTypes::Tanh) ||
201+
(activation_type == utils::ActivationTypes::RoundHalfAwayFromZero) ||
202+
(activation_type == utils::ActivationTypes::RoundHalfToEven))) {
201203
return "jit";
202204
}
203205

@@ -209,7 +211,9 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
209211
if ((activation_type == utils::ActivationTypes::Floor) ||
210212
(activation_type == utils::ActivationTypes::Ceiling) ||
211213
(activation_type == utils::ActivationTypes::IsNaN) ||
212-
(activation_type == utils::ActivationTypes::IsFinite)) {
214+
(activation_type == utils::ActivationTypes::IsFinite) ||
215+
(activation_type == utils::ActivationTypes::RoundHalfAwayFromZero) ||
216+
(activation_type == utils::ActivationTypes::RoundHalfToEven)) {
213217
return "ref";
214218
}
215219
return "acl";
@@ -265,22 +269,26 @@ const std::map<utils::ActivationTypes, std::vector<std::vector<float>>>& activat
265269

266270
const std::map<utils::ActivationTypes, std::vector<std::vector<float>>>& activationTypesSnippets() {
267271
static const std::map<utils::ActivationTypes, std::vector<std::vector<float>>> activationTypes {
268-
{Abs, {{}}},
269-
{Exp, {{}}},
270-
{Ceiling, {{}}},
271-
{Clamp, {{-2.0f, 2.0f}}},
272-
{Elu, {{0.1f}}},
273-
{Floor, {{}}},
274-
{GeluErf, {{}}},
275-
{GeluTanh, {{}}},
276-
{Relu, {{}}},
277-
{HSwish, {{}}},
272+
{Abs, {{}}},
273+
{Exp, {{}}},
274+
{Ceiling, {{}}},
275+
{Clamp, {{-2.0f, 2.0f}}},
276+
{Elu, {{0.1f}}},
277+
{Floor, {{}}},
278+
{GeluErf, {{}}},
279+
{GeluTanh, {{}}},
280+
{Relu, {{}}},
281+
{HSwish, {{}}},
282+
{PReLu, {{-0.01f}}},
283+
{Sqrt, {{}}},
284+
{RoundHalfToEven, {{}}},
285+
{RoundHalfAwayFromZero, {{}}},
278286
#if defined(OPENVINO_ARCH_ARM64)
279-
{Mish, {{}}},
287+
{Mish, {{}}},
280288
#endif
281-
{Sigmoid, {{}}},
282-
{Swish, {{0.1f}}},
283-
{Tanh, {{}}},
289+
{Sigmoid, {{}}},
290+
{Swish, {{0.1f}}},
291+
{Tanh, {{}}},
284292
};
285293

286294
return activationTypes;

0 commit comments

Comments (0)