Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions src/common/snippets/src/pass/align_element_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,6 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)
std::shared_ptr<ov::Node> consumer = shape_infer_leaf ? shape_infer_leaf : results[i];
auto parent_output = consumer->get_input_source_output(0);

// Snippets supports Transpose only after Parameter or before Result nodes,
// so we have to insert Convert before Transpose (if there is one) on Subgraph outputs
const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent_output.get_node_shared_ptr());
if (transpose) {
OPENVINO_ASSERT(
parent_output.get_target_inputs().size() == 1,
"If Result has Transpose on input, this Result must be single consumer of the Transpose");
parent_output = transpose->get_input_source_output(0);
consumer = transpose;
}

// If there is already Convert[needed_in_type->original_type] and this node has only one consumer, we can
// remove the Convert, since the sequence existing Convert[needed_in_type->original_type] -> new
// Convert[original_type->needed_in_type] is redundant
Expand All @@ -81,9 +70,6 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)

consumer->set_argument(0, convert);
consumer->validate_and_infer_types();
if (transpose) {
results[i]->validate_and_infer_types();
}
is_modified = true;
}
}
Expand Down
59 changes: 32 additions & 27 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@

# include "emitters/snippets/x64/cpu_generator.hpp"
# include "executors/x64/subgraph.hpp"
# include "snippets/lowered/port_descriptor.hpp"
# include "snippets/op/brgemm.hpp"
# include "snippets/pass/matmul_to_brgemm.hpp"
# include "snippets/utils/utils.hpp"
# include "transformations/snippets/x64/op/brgemm_utils.hpp"
#elif defined(OPENVINO_ARCH_ARM64)
# include <cpu/aarch64/cpu_isa_traits.hpp>
Expand Down Expand Up @@ -86,6 +87,7 @@
# include "snippets/lowered/pass/init_loops.hpp"
# include "snippets/lowered/pass/insert_buffers.hpp"
# include "snippets/lowered/pass/insert_loops.hpp"
# include "snippets/pass/fuse_transpose_brgemm.hpp"
# include "transformations/snippets/common/pass/enforce_precision.hpp"
# include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp"
# include "transformations/snippets/x64/pass/eliminate_brgemm_copy_b.hpp"
Expand Down Expand Up @@ -552,34 +554,37 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {

if (any_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
subgraph_attrs->snippet->has_domain_sensitive_ops()) {
// enforce BF16 precisions to supported operations
// MatMul has to be decomposed to Brgemm operations before enforcement
// Notes:
// - MatMul decomposition will be run again later in case BF16 enforcement has not happened
// - `MatMulToBrgemm` pass fuses `transpose_a` and `transpose_b` from MatMul to inputs of Brgemm as layouts.
// These layouts are resized to ranks of input shapes. But since `Canonicalization` might
// reshape shapes, the pass `MatMulToBrgemm` should be after the pass `Canonicalization` to
// fuse layouts with ranks aligned with updated shapes after RankNormalization insertions.
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
ov::snippets::pass::Canonicalization,
ov::snippets::pass::MatMulToBrgemm);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(
Place::After,
ov::snippets::pass::FuseTransposeBrgemm,
pass::EnforcePrecision,
element::f32,
context->getConfig().inferencePrecision,
[](const std::shared_ptr<ov::Node>& op) {
std::set<std::vector<ov::element::Type>> types;
if (ov::is_type<ov::snippets::op::Brgemm>(op)) {
const auto& a_port =
ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(op->input(0));
// WA: We can't perform precision enforcement in case of strided access to A matrix:
// snippets eltwise loops for precision conversion are generated by last 2 dims,
// which are not [M, K] in case of strided access in brgemm A
// There are no limitations for B matrix, since precision conversion is fused in BrgemmCopyB
if (ov::snippets::utils::is_planar_layout(a_port->get_layout())) {
if (ov::intel_cpu::brgemm_utils::is_fp16_supported()) {
types.insert({ov::element::f16, ov::element::f16});
}
if (ov::intel_cpu::brgemm_utils::is_bf16_supported()) {
types.insert({ov::element::bf16, ov::element::bf16});
}
}
}
return types;
});
// Note: EnforcePrecision might also eliminate Convert pairs (e.g. bf16->f32->bf16),
// so FuseTransposeBrgemm has to be run after it as well
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
ov::snippets::pass::MatMulToBrgemm,
pass::EnforcePrecision,
element::f32,
context->getConfig().inferencePrecision,
[](const std::shared_ptr<ov::Node>& op) {
std::set<std::vector<ov::element::Type>> types;
if (ov::is_type<ov::snippets::op::Brgemm>(op)) {
if (ov::intel_cpu::brgemm_utils::is_fp16_supported()) {
types.insert({ov::element::f16, ov::element::f16});
}
if (ov::intel_cpu::brgemm_utils::is_bf16_supported()) {
types.insert({ov::element::bf16, ov::element::bf16});
}
}
return types;
});
ov::snippets::pass::FuseTransposeBrgemm);
}

SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ bool pass::EliminateBrgemmCopyB::run_on_model(const std::shared_ptr<ov::Model>&
// Since repacking is moved out of Subgraph body,
// the rest weights subgraph must be updated with precision after repacking
param->set_element_type(copy_b_node->get_config().wei_dt());
// Note: validation is called manually since set_element_type doesn't update output element type
param->validate_and_infer_types();
if (pattern_map.count(m_rank_norm)) {
pattern_map.at(m_rank_norm).get_node_shared_ptr()->validate_and_infer_types();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
::testing::ValuesIn(precision_bf16_if_supported(4)),
::testing::Values(ov::element::f32),
::testing::Values(ov::element::bf16),
::testing::Values(false),
::testing::Values(MHA::default_thread_count),
::testing::Values(8), // decomposed Transpose + MHA + 5 Converts + 1 Transpose on output
::testing::Values(6), // MHA + 5 Converts on inputs and output
::testing::Values(3), // decomposed Transpose + MHA + 1 Transpose on output
::testing::Values(2), // decomposed Transpose + MHA
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::empty_plugin_config)),
MHA::getTestCaseName);
Expand All @@ -182,6 +182,19 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
MHA::getTestCaseName);

// Instantiates the MHA suite over transposedShape_4D() with precision_f32(4) inputs and an
// ov::element::f32 type argument, executed on DEVICE_CPU with CPUTestUtils::cpu_bf16_plugin_config.
// NOTE(review): judging by the suite name, this presumably checks bf16 enforcement when the
// model inputs themselves stay in f32 — confirm against the MHAParams definition.
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16_f32_in_prc,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({false}),
::testing::Values(MHA::default_thread_count),
::testing::Values(3), // decomposed Transpose + MHA + 1 Transpose on output
::testing::Values(2), // decomposed Transpose + MHA
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_Without_Multiply,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
Expand Down
14 changes: 9 additions & 5 deletions src/tests/functional/plugin/shared/src/snippets/mha.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@ void MHABase::generate_inputs(const std::vector<ov::Shape>& targetInputStaticSha
const auto& model_input = model_inputs[i];
ov::Tensor tensor;
ov::test::utils::InputGenerateData in_data;
const bool bf16_precision =
configuration.at(ov::hint::inference_precision.name()).as<ov::element::Type>() == ov::element::bf16 ||
model_input.get_element_type() == ov::element::bf16;
// To avoid big relative errors in the vicinity of zero, only positive values are generated for bf16 precision
in_data.start_from = model_input.get_element_type() == ov::element::bf16 ? 0 : -1;
in_data.start_from = bf16_precision ? 0 : -1;
in_data.range = 2;
in_data.resolution = 256;
tensor =
Expand All @@ -55,16 +58,17 @@ void MHABase::SetUp() {
setInferenceType(prc);
}

void MHABase::init_thresholds() {
void MHABase::init_thresholds() {
// Note: Libxsmm calculates Exp in a slightly different way, so the abs values might differ a bit. Ticket: 130699
#ifdef SNIPPETS_LIBXSMM_TPP
abs_threshold = 1e-6;
#endif
if (inType == ov::element::bf16)
auto infer_precision = configuration.at(ov::hint::inference_precision.name()).as<ov::element::Type>();
if (infer_precision == ov::element::bf16)
rel_threshold = 0.05f;
if (inType == ov::element::f16)
if (infer_precision == ov::element::f16)
abs_threshold = 2e-2;
}
}

std::string MHA::getTestCaseName(const testing::TestParamInfo<ov::test::snippets::MHAParams>& obj) {
const auto& [input_shapes,
Expand Down
Loading