diff --git a/src/common/snippets/src/pass/align_element_types.cpp b/src/common/snippets/src/pass/align_element_types.cpp
index b7f4932e2a88fc..483a8a3ac83d5e 100644
--- a/src/common/snippets/src/pass/align_element_types.cpp
+++ b/src/common/snippets/src/pass/align_element_types.cpp
@@ -47,17 +47,6 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)
         std::shared_ptr<ov::Node> consumer = shape_infer_leaf ? shape_infer_leaf : results[i];
         auto parent_output = consumer->get_input_source_output(0);
 
-        // Snippets supports Transpose only after Parameter or before Result nodes
-        // So we have to insert Convert before Transpose (if there is) on Subgraph outputs
-        const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent_output.get_node_shared_ptr());
-        if (transpose) {
-            OPENVINO_ASSERT(
-                parent_output.get_target_inputs().size() == 1,
-                "If Result has Transpose on input, this Result must be single consumer of the Transpose");
-            parent_output = transpose->get_input_source_output(0);
-            consumer = transpose;
-        }
-
         // If there is already Convert[needed_in_type->original_type] and this node has only one consumer, we can
         // remove the Convert, since the sequence existing Convert[needed_in_type->original_type] -> new
         // Convert[original_type->needed_in_type] is redundant
@@ -81,9 +70,6 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)
             consumer->set_argument(0, convert);
             consumer->validate_and_infer_types();
 
-            if (transpose) {
-                results[i]->validate_and_infer_types();
-            }
             is_modified = true;
         }
     }
diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp
index e02b003a37237f..3a5675debffcd1 100644
--- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp
+++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp
@@ -46,8 +46,9 @@
 #    include "emitters/snippets/x64/cpu_generator.hpp"
 #    include "executors/x64/subgraph.hpp"
+#    include "snippets/lowered/port_descriptor.hpp"
 #    include "snippets/op/brgemm.hpp"
-#    include "snippets/pass/matmul_to_brgemm.hpp"
+#    include "snippets/utils/utils.hpp"
 #    include "transformations/snippets/x64/op/brgemm_utils.hpp"
 #elif defined(OPENVINO_ARCH_ARM64)
 #    include
@@ -86,6 +87,7 @@
 #    include "snippets/lowered/pass/init_loops.hpp"
 #    include "snippets/lowered/pass/insert_buffers.hpp"
 #    include "snippets/lowered/pass/insert_loops.hpp"
+#    include "snippets/pass/fuse_transpose_brgemm.hpp"
 #    include "transformations/snippets/common/pass/enforce_precision.hpp"
 #    include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp"
 #    include "transformations/snippets/x64/pass/eliminate_brgemm_copy_b.hpp"
@@ -552,34 +554,38 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
     if (any_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
         subgraph_attrs->snippet->has_domain_sensitive_ops()) {
-        // enforce BF16 precisions to supported operations
-        // MatMul has to be decomposed to Brgemm operations before enforcement
-        // Notes:
-        //  - MatMul decomposition will be run later again for case if BF16 enforcement is not happened
-        //  - `MatMulToBrgemm` pass fuse `transpose_a` and `transpose_b` from MatMul to inputs of Brgemm as layouts.
-        //    These layouts are resized to ranks of input shapes. But since `Canonicalization` might
-        //    reshape shapes, the pass `MatMulToBrgemm` should be after the pass `Canonicalization` to
-        //    fuse layouts with ranks aligned with updated shapes after RankNormalization insertions.
-        SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
-                                               ov::snippets::pass::Canonicalization,
-                                               ov::snippets::pass::MatMulToBrgemm);
+        SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(
+            Place::After,
+            ov::snippets::pass::FuseTransposeBrgemm,
+            pass::EnforcePrecision,
+            element::f32,
+            context->getConfig().inferencePrecision,
+            [](const std::shared_ptr<ov::Node>& op) {
+                std::set<std::vector<element::Type>> types;
+                if (ov::is_type<ov::snippets::op::Brgemm>(op)) {
+                    const auto& a_port =
+                        ov::snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(op->input(0));
+                    // WA: We can't perform precision enforcement in case of strided access to A matrix:
+                    // snippets eltwise loops for precision conversion are generated by last 2 dims,
+                    // which are not [M, K] in case of strided access in brgemm A
+                    // There are no limitations for B matrix, since precision conversion is fused in BrgemmCopyB
+                    // Ticket: 177121
+                    if (ov::snippets::utils::is_planar_layout(a_port->get_layout())) {
+                        if (ov::intel_cpu::brgemm_utils::is_fp16_supported()) {
+                            types.insert({ov::element::f16, ov::element::f16});
+                        }
+                        if (ov::intel_cpu::brgemm_utils::is_bf16_supported()) {
+                            types.insert({ov::element::bf16, ov::element::bf16});
+                        }
+                    }
+                }
+                return types;
+            });
+        // Note: EnforcePrecision might also eliminate Convert pairs (e.g. bf16->f32->bf16),
+        // so FuseTransposeBrgemm has to be run after it as well
         SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
-                                               ov::snippets::pass::MatMulToBrgemm,
                                                pass::EnforcePrecision,
-                                               element::f32,
-                                               context->getConfig().inferencePrecision,
-                                               [](const std::shared_ptr<ov::Node>& op) {
-                                                   std::set<std::vector<element::Type>> types;
-                                                   if (ov::is_type<ov::snippets::op::Brgemm>(op)) {
-                                                       if (ov::intel_cpu::brgemm_utils::is_fp16_supported()) {
-                                                           types.insert({ov::element::f16, ov::element::f16});
-                                                       }
-                                                       if (ov::intel_cpu::brgemm_utils::is_bf16_supported()) {
-                                                           types.insert({ov::element::bf16, ov::element::bf16});
-                                                       }
-                                                   }
-                                                   return types;
-                                               });
+                                               ov::snippets::pass::FuseTransposeBrgemm);
     }
 
     SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/eliminate_brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/eliminate_brgemm_copy_b.cpp
index 7dc32ae3621472..005b9005f9a899 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/eliminate_brgemm_copy_b.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/eliminate_brgemm_copy_b.cpp
@@ -66,6 +66,8 @@ bool pass::EliminateBrgemmCopyB::run_on_model(const std::shared_ptr<ov::Model>&
         // Since repacking is moved out of Subgraph body,
         // the rest weights subgraph must be updated with precision after repacking
         param->set_element_type(copy_b_node->get_config().wei_dt());
+        // Note: validation is called manually since set_element_type doesn't update output element type
+        param->validate_and_infer_types();
         if (pattern_map.count(m_rank_norm)) {
             pattern_map.at(m_rank_norm).get_node_shared_ptr()->validate_and_infer_types();
         }
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp
index a7b034ab1d7aa5..1694577604582c 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp
@@ -160,11 +160,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
                          MHA,
                          ::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
                                             ::testing::ValuesIn(precision_bf16_if_supported(4)),
-                                            ::testing::Values(ov::element::f32),
+                                            ::testing::Values(ov::element::bf16),
                                             ::testing::Values(false),
                                             ::testing::Values(MHA::default_thread_count),
-                                            ::testing::Values(8),  // decomposed Transpose + MHA + 5 Converts + 1 Transpose on output
-                                            ::testing::Values(6),  // MHA + 5 Converts on inputs and output
+                                            ::testing::Values(3),  // decomposed Transpose + MHA + 1 Transpose on output
+                                            ::testing::Values(2),  // decomposed Transpose + MHA
                                             ::testing::Values(ov::test::utils::DEVICE_CPU),
                                             ::testing::Values(CPUTestUtils::empty_plugin_config)),
                          MHA::getTestCaseName);
@@ -182,6 +182,19 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
                                             ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
                          MHA::getTestCaseName);
 
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16_f32_in_prc,
+                         MHA,
+                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
+                                            ::testing::ValuesIn(precision_f32(4)),
+                                            ::testing::Values(ov::element::f32),
+                                            ::testing::ValuesIn({false}),
+                                            ::testing::Values(MHA::default_thread_count),
+                                            ::testing::Values(3),  // decomposed Transpose + MHA + 1 Transpose on output
+                                            ::testing::Values(2),  // decomposed Transpose + MHA
+                                            ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                            ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
+                         MHA::getTestCaseName);
+
 INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_Without_Multiply,
                          MHA,
                          ::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
diff --git a/src/tests/functional/plugin/shared/src/snippets/mha.cpp b/src/tests/functional/plugin/shared/src/snippets/mha.cpp
index 946267107ba20d..a811ee6cddb871 100644
--- a/src/tests/functional/plugin/shared/src/snippets/mha.cpp
+++ b/src/tests/functional/plugin/shared/src/snippets/mha.cpp
@@ -28,8 +28,11 @@ void MHABase::generate_inputs(const std::vector<ov::Shape>& targetInputStaticSha
         const auto& model_input = model_inputs[i];
         ov::Tensor tensor;
         ov::test::utils::InputGenerateData in_data;
+        const bool bf16_precision =
+            configuration.at(ov::hint::inference_precision.name()).as<ov::element::Type>() == ov::element::bf16 ||
+            model_input.get_element_type() == ov::element::bf16;
         // To avoid big relative errors in the vicinity of zero, only positive values are generated for bf16 precision
-        in_data.start_from = model_input.get_element_type() == ov::element::bf16 ? 0 : -1;
+        in_data.start_from = bf16_precision ? 0 : -1;
         in_data.range = 2;
         in_data.resolution = 256;
         tensor =
@@ -55,16 +58,17 @@ void MHABase::SetUp() {
     setInferenceType(prc);
 }
 
- void MHABase::init_thresholds() {
+void MHABase::init_thresholds() {
     // Note: Libxsmm calculates Exp in a slightly different way, so the abs values might differ a bit. Ticket: 130699
 #ifdef SNIPPETS_LIBXSMM_TPP
     abs_threshold = 1e-6;
 #endif
-    if (inType == ov::element::bf16)
+    auto infer_precision = configuration.at(ov::hint::inference_precision.name()).as<ov::element::Type>();
+    if (infer_precision == ov::element::bf16)
         rel_threshold = 0.05f;
-    if (inType == ov::element::f16)
+    if (infer_precision == ov::element::f16)
         abs_threshold = 2e-2;
- }
+}
 
 std::string MHA::getTestCaseName(const testing::TestParamInfo<ov::test::snippets::MHAParams>& obj) {
     const auto& [input_shapes,
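
For context on the subgraph.cpp change above: the new EnforcePrecision callback only allows bf16/f16 enforcement when the Brgemm A-input layout is planar. Below is a minimal standalone sketch of that planar-layout idea; the helper name is_identity_order, the main driver, and the example layouts are illustrative assumptions for this note, not part of the patch (the production check is ov::snippets::utils::is_planar_layout applied to the port descriptor's layout).

// Illustrative sketch (not project code): a layout is treated as "planar" when it is empty
// (default order) or the identity permutation {0, 1, ..., rank-1}; any other order implies
// strided access to the Brgemm A matrix, in which case the callback above returns an empty
// set of enforced precisions.
#include <cstddef>
#include <vector>

static bool is_identity_order(const std::vector<size_t>& layout) {  // hypothetical helper
    for (size_t i = 0; i < layout.size(); ++i) {
        if (layout[i] != i) {
            return false;
        }
    }
    return true;
}

int main() {
    const std::vector<size_t> planar{0, 1, 2, 3};    // plain row-major access: enforcement allowed
    const std::vector<size_t> strided{0, 2, 1, 3};   // transposed/strided A access: enforcement skipped
    return (is_identity_order(planar) && !is_identity_order(strided)) ? 0 : 1;
}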