@@ -153,7 +153,8 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM) {
153153 std::vector<ov::element::Type>({ov::element::f32 , ov::element::f32 , ov::element::f32 , ov::element::f32 }),
154154 std::vector<Shape>{{2 , 64 , 12 , 64 }, {12 , 1 , 64 , 128 }, {12 , 2 , 64 , 128 }, {1 , 128 , 12 , 64 }, {128 , 12 , 64 }},
155155 false );
156- common_config = ov::snippets::pass::CommonOptimizations::Config (24 , true );
156+ common_config = get_default_common_optimizations_config ();
157+ common_config.set_concurrency (24 );
157158 execute_and_validate_function (*this , f);
158159}
159160
@@ -162,7 +163,8 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) {
162163 std::vector<ov::element::Type>({ov::element::f32 , ov::element::f32 , ov::element::f32 , ov::element::f32 }),
163164 std::vector<Shape>{{4 , 32 , 12 , 64 }, {12 , 1 , 64 , 128 }, {12 , 4 , 32 , 128 }, {1 , 128 , 12 , 64 }, {128 , 12 , 64 }},
164165 true );
165- common_config = ov::snippets::pass::CommonOptimizations::Config (16 , true );
166+ common_config = get_default_common_optimizations_config ();
167+ common_config.set_concurrency (16 );
166168 execute_and_validate_function (*this , f);
167169}
168170
@@ -171,7 +173,8 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) {
171173 std::vector<ov::element::Type>({ov::element::f32 , ov::element::f32 , ov::element::f32 , ov::element::f32 }),
172174 std::vector<Shape>{{1 , 12 , 32 , 16 , 64 }, {1 , 16 , 1 , 64 , 384 }, {1 , 1 , 1 , 1 , 384 }, {1 , 1 , 384 , 16 , 64 }, {1 , 384 , 16 , 64 }},
173175 false );
174- common_config = ov::snippets::pass::CommonOptimizations::Config (60 , true );
176+ common_config = get_default_common_optimizations_config ();
177+ common_config.set_concurrency (60 );
175178 execute_and_validate_function (*this , f);
176179}
177180
@@ -180,46 +183,52 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) {
180183 std::vector<ov::element::Type>({ov::element::f32 , ov::element::f32 , ov::element::f32 , ov::element::f32 }),
181184 std::vector<Shape>{{1 , 12 , 32 , 16 , 64 }, {1 , 16 , 1 , 64 , 384 }, {1 , 1 , 1 , 1 , 384 }, {1 , 1 , 384 , 16 , 64 }, {1 , 384 , 16 , 64 }},
182185 true );
183- common_config = ov::snippets::pass::CommonOptimizations::Config (60 , true );
186+ common_config = get_default_common_optimizations_config ();
187+ common_config.set_concurrency (60 );
184188 execute_and_validate_function (*this , f);
185189}
186190
187191TEST_F (TokenizeMHASnippetsTests, smoke_Snippets_MHAWOTranspose_SplitM) {
188192 const auto & f = MHAWOTransposeSplitMFunction (std::vector<PartialShape>{{10 , 9216 , 128 }, {10 , 128 , 9216 }, {10 , 9216 , 128 }},
189193 std::vector<ov::element::Type>({ov::element::f32 , ov::element::f32 , ov::element::f32 }),
190194 std::vector<Shape>{{10 , 18 , 512 , 128 }, {10 , 1 , 128 , 9216 }, {10 , 1 , 9216 , 128 }, {10 , 9216 , 128 }});
191- common_config = ov::snippets::pass::CommonOptimizations::Config (18 , true );
195+ common_config = get_default_common_optimizations_config ();
196+ common_config.set_concurrency (18 );
192197 execute_and_validate_function (*this , f);
193198}
194199
195200TEST_F (TokenizeMHASnippetsTests, smoke_Snippets_MHA_SplitM_AlmostAllThreads) {
196201 const auto & f = MHAWOTransposeSplitMFunction (std::vector<PartialShape>{{5 , 30 , 32 }, {5 , 32 , 30 }, {5 , 30 , 32 }},
197202 std::vector<ov::element::Type>({ov::element::f32 , ov::element::f32 , ov::element::f32 }),
198203 std::vector<Shape>{{5 , 10 , 3 , 32 }, {5 , 1 , 32 , 30 }, {5 , 1 , 30 , 32 }, {5 , 30 , 32 }});
199- common_config = ov::snippets::pass::CommonOptimizations::Config (32 , true );
204+ common_config = get_default_common_optimizations_config ();
205+ common_config.set_concurrency (32 );
200206 execute_and_validate_function (*this , f);
201207}
202208
203209TEST_F (TokenizeMHASnippetsTests, smoke_Snippets_MHA_4D_SplitM_DynamicParameter) {
204210 const auto &f = MHAFunction (std::vector<PartialShape>{{1 , 128 , 16 , 64 }, {1 , 128 , 16 , 64 }, {1 , 16 , 128 , -1 }, {1 , 128 , 16 , 64 }},
205211 std::vector<ov::element::Type>({ov::element::f32 , ov::element::f32 , ov::element::f32 , ov::element::f32 }), false , false );
206- common_config = ov::snippets::pass::CommonOptimizations::Config (32 , true );
212+ common_config = get_default_common_optimizations_config ();
213+ common_config.set_concurrency (32 );
207214 execute_and_validate_function (*this , f);
208215}
209216
210217TEST_F (TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM) {
211218 const auto & f = MHASelectSplitMFunction (std::vector<PartialShape>{{8 , 512 , 18 }, {8 , 18 , 64 }, {1 , 512 , 64 }, {1 , 1 , 64 }, {8 , 64 , 512 }},
212219 std::vector<Shape>{{8 , 2 , 256 , 18 }, {8 , 1 , 18 , 64 }, {1 , 2 , 256 , 64 }, {1 , 1 , 1 , 64 },
213220 {8 , 1 , 64 , 512 }, {8 , 512 , 512 }});
214- common_config = ov::snippets::pass::CommonOptimizations::Config (16 , true );
221+ common_config = get_default_common_optimizations_config ();
222+ common_config.set_concurrency (16 );
215223 execute_and_validate_function (*this , f);
216224}
217225
218226TEST_F (TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM_ScalarParams) {
219227 const auto & f = MHASelectSplitMFunction (std::vector<PartialShape>{{8 , 512 , 18 }, {8 , 18 , 64 }, {1 }, {64 }, {8 , 64 , 512 }},
220228 std::vector<Shape>{{8 , 2 , 256 , 18 }, {8 , 1 , 18 , 64 }, {}, {},
221229 {8 , 1 , 64 , 512 }, {8 , 512 , 512 }});
222- common_config = ov::snippets::pass::CommonOptimizations::Config (16 , true );
230+ common_config = get_default_common_optimizations_config ();
231+ common_config.set_concurrency (16 );
223232 execute_and_validate_function (*this , f);
224233}
225234
0 commit comments