|
13 | 13 | #include "emitters/snippets/cpu_runtime_configurator.hpp" |
14 | 14 | #include "emitters/utils.hpp" |
15 | 15 | #include "jit_snippets_emitters.hpp" |
| 16 | +#include "openvino/core/type.hpp" |
| 17 | +#include "openvino/op/prelu.hpp" |
| 18 | +#include "openvino/op/round.hpp" |
| 19 | +#include "openvino/op/sqrt.hpp" |
16 | 20 | #include "openvino/opsets/opset13.hpp" |
| 21 | +#include "snippets/emitter.hpp" |
| 22 | +#include "snippets/lowered/expression.hpp" |
17 | 23 | #include "snippets/snippets_isa.hpp" |
18 | 24 | #include "transformations/cpu_opset/common/op/swish_cpu.hpp" |
19 | 25 | #include "transformations/snippets/common/op/fused_mul_add.hpp" |
@@ -44,7 +50,7 @@ namespace ov { |
44 | 50 | { \ |
45 | 51 | [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \ |
46 | 52 | const auto& n = expr->get_node(); \ |
47 | | - const auto& gelu = std::dynamic_pointer_cast<ov::op::v7::Gelu>(n); \ |
| 53 | + const auto& gelu = ov::as_type_ptr<ov::op::v7::Gelu>(n); \ |
48 | 54 | if (gelu == nullptr) { \ |
49 | 55 | OPENVINO_THROW("Can't cast to ov::op::v7::Gelu"); \ |
50 | 56 | } \ |
@@ -73,6 +79,37 @@ namespace ov { |
73 | 79 | } \ |
74 | 80 | } |
75 | 81 |
|
| 82 | +#define CREATE_ROUND_V5_EMITTER(e_type_from_zero, e_type_even) \ |
| 83 | + { \ |
| 84 | + [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \ |
| 85 | + const auto& n = expr->get_node(); \ |
| 86 | + const auto& round = ov::as_type_ptr<ov::op::v5::Round>(n); \ |
| 87 | + if (round == nullptr) { \ |
| 88 | + OPENVINO_THROW("Can't cast to ov::op::v5::Round"); \ |
| 89 | + } \ |
| 90 | + const auto roundingMode = round->get_mode(); \ |
| 91 | + if (roundingMode == ov::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO) { \ |
| 92 | + return std::make_shared<e_type_from_zero>(h.get(), isa, n); \ |
| 93 | + } else if (roundingMode == ov::op::v5::Round::RoundMode::HALF_TO_EVEN) { \ |
| 94 | + return std::make_shared<e_type_even>(h.get(), isa, n); \ |
| 95 | + } else { \ |
| 96 | + OPENVINO_THROW("Unsupported Round mode"); \ |
| 97 | + } \ |
| 98 | + }, \ |
| 99 | + [](const std::shared_ptr<ov::Node>& n) -> std::set<std::vector<element::Type>> { \ |
| 100 | + const auto& round = std::dynamic_pointer_cast<ov::op::v5::Round>(n); \ |
| 101 | + if (round == nullptr) { \ |
| 102 | + OPENVINO_THROW("Can't cast to ov::op::v5::Round"); \ |
| 103 | + } \ |
| 104 | + if (round->get_mode() == ov::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO) { \ |
| 105 | + return e_type_from_zero::get_supported_precisions(n); \ |
| 106 | + } else if (round->get_mode() == ov::op::v5::Round::RoundMode::HALF_TO_EVEN) { \ |
| 107 | + return e_type_even::get_supported_precisions(n); \ |
| 108 | + } \ |
| 109 | + OPENVINO_THROW("Unsupported Round mode"); \ |
| 110 | + } \ |
| 111 | + } |
| 112 | + |
76 | 113 | class jit_snippet : public dnnl::impl::cpu::aarch64::jit_generator { |
77 | 114 | public: |
78 | 115 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_snippet) |
@@ -155,8 +192,12 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa) |
155 | 192 | CREATE_GELU_V7_EMITTER(jit_gelu_erf_emitter, jit_gelu_tanh_emitter); |
156 | 193 | jitters[ov::op::v4::HSwish::get_type_info_static()] = CREATE_CPU_EMITTER(jit_hswish_emitter); |
157 | 194 | jitters[ov::op::v4::Mish::get_type_info_static()] = CREATE_CPU_EMITTER(jit_mish_emitter); |
| 195 | + jitters[ov::op::v0::PRelu::get_type_info_static()] = CREATE_CPU_EMITTER(jit_prelu_emitter); |
158 | 196 | jitters[ov::op::v0::Relu::get_type_info_static()] = CREATE_CPU_EMITTER(jit_relu_emitter); |
| 197 | + jitters[ov::op::v5::Round::get_type_info_static()] = |
| 198 | + CREATE_ROUND_V5_EMITTER(jit_round_half_away_from_zero_emitter, jit_round_half_to_even_emitter); |
159 | 199 | jitters[ov::op::v0::Sigmoid::get_type_info_static()] = CREATE_CPU_EMITTER(jit_sigmoid_emitter); |
| 200 | + jitters[ov::op::v0::Sqrt::get_type_info_static()] = CREATE_CPU_EMITTER(jit_sqrt_emitter); |
160 | 201 | jitters[ov::intel_cpu::SwishNode::get_type_info_static()] = CREATE_CPU_EMITTER(jit_swish_emitter); |
161 | 202 | jitters[ov::op::v0::Tanh::get_type_info_static()] = CREATE_CPU_EMITTER(jit_tanh_emitter); |
162 | 203 |
|
|
0 commit comments