@@ -639,10 +639,10 @@ std::vector<double> E2E_Q_DATA_8 = {
639639ov::Shape E2E_K_SHAPE_8 = {2 , 8 , 2 };
640640std::vector<double > E2E_K_DATA_8 = {
641641 // clang-format off
642- - 1.2870 , - 1.2179 , 0.0316 , 0.0080 , - 0.6171 , 1.0622 , 0.3085 , -0.7751 ,
643- -1.3612 , 0.9485 , -0.0803 , 0.5752 , 0.1925 , - 0.1113 , 1.4693 , 0.0673 ,
644- 0.7422 , 0.7149 , -1.7684 , -0.0651 , - 0.1925 , - 1.4169 , 1.0030 , - 0.8091 ,
645- - 0.7934 , 0.5160 , -0.2543 , 0.1729 , -0.0687 , -1.4245 , 0.0758 , 1.1613
642+ 0.2980 , 0.4959 , - 0.0834 , 0.7015 , 1.2516 , 0.6656 , -2.7873 , 1.9731 ,
643+ -0.4817 , 1.1117 , -0.8096 , - 0.5397 , - 1.0528 , 0.2869 , - 1.1274 , 1.4849 ,
644+ - 0.2468 , - 1.0449 , -1.0085 , -0.3389 , 0.6750 , 0.9095 , 0.4674 , 2.2321 ,
645+ 1.3183 , - 0.3513 , -0.3717 , 0.0176 , -0.2545 , -0.6729 , - 1.1547 , 0.0279
646646 // clang-format on
647647};
648648
@@ -746,8 +746,89 @@ std::vector<E2EBlockSelectTestData> E2E_BLOCK_SELECT_TEST_CASES = {
746746 {{0 , 0 }, {0 , 2 }, {0 , 4 }, {1 , 0 }, {1 , 1 }, {1 , 3 }, {1 , 5 }, {2 , 0 }, {2 , 1 }, {2 , 2 }, {2 , 3 }, {2 , 4 }, {2 , 6 }, {3 , 0 }, {3 , 1 }, {3 , 4 }, {3 , 5 }, {3 , 6 }, {3 , 7 }}
747747 }
748748 // clang-format on
749+ },
750+ {
751+ E2E_Q_SHAPE_8,
752+ E2E_Q_DATA_8,
753+ E2E_K_SHAPE_16,
754+ E2E_K_DATA_16,
755+ /* threshold = */ 0.45 ,
756+ /* block_size = */ 2 ,
757+ /* stride = */ 2 ,
758+
759+ // clang-format off
760+ {
761+ {{0 , 0 }, {0 , 4 }, {1 , 0 }, {1 , 5 }, {2 , 0 }, {2 , 1 }, {2 , 3 }, {2 , 6 }, {3 , 0 }, {3 , 2 }, {3 , 5 }, {3 , 7 }},
762+ {{0 , 0 }, {0 , 2 }, {0 , 4 }, {1 , 0 }, {1 , 5 }, {2 , 0 }, {2 , 4 }, {2 , 6 }, {3 , 0 }, {3 , 5 }, {3 , 7 }}
763+ }
764+ // clang-format on
765+ },
766+ {
767+ E2E_Q_SHAPE_8,
768+ E2E_Q_DATA_8,
769+ E2E_K_SHAPE_16,
770+ E2E_K_DATA_16,
771+ /* threshold = */ 0.45 ,
772+ /* block_size = */ 4 ,
773+ /* stride = */ 2 ,
774+
775+ // clang-format off
776+ {
777+ {{0 , 0 }, {0 , 2 }, {1 , 0 }, {1 , 1 }, {1 , 3 }},
778+ {{0 , 0 }, {0 , 2 }, {1 , 0 }, {1 , 3 }}
779+ }
780+ // clang-format on
781+ },
782+ {
783+ E2E_Q_SHAPE_8,
784+ E2E_Q_DATA_8,
785+ E2E_K_SHAPE_16,
786+ E2E_K_DATA_16,
787+ /* threshold = */ 0.45 ,
788+ /* block_size = */ 4 ,
789+ /* stride = */ 4 ,
790+
791+ // clang-format off
792+ {
793+ {{0 , 0 }, {0 , 2 }, {1 , 0 }, {1 , 3 }},
794+ {{0 , 0 }, {0 , 2 }, {1 , 0 }, {1 , 3 }}
795+ }
796+ // clang-format on
797+ },
798+ {
799+ E2E_Q_SHAPE_8,
800+ E2E_Q_DATA_8,
801+ E2E_K_SHAPE_8,
802+ E2E_K_DATA_8,
803+ /* threshold = */ 0.5 ,
804+ /* block_size = */ 2 ,
805+ /* stride = */ 2 ,
806+
807+ // clang-format off
808+ {
809+ {{0 , 0 }, {1 , 0 }, {1 , 1 }, {2 , 0 }, {2 , 1 }, {2 , 2 }, {3 , 0 }, {3 , 1 }, {3 , 3 }},
810+ {{0 , 0 }, {1 , 0 }, {1 , 1 }, {2 , 0 }, {2 , 2 }, {3 , 0 }, {3 , 3 }}
811+ }
812+ // clang-format on
813+ },
814+ {
815+ E2E_Q_SHAPE_8,
816+ E2E_Q_DATA_8,
817+ E2E_K_SHAPE_8,
818+ E2E_K_DATA_8,
819+ /* threshold = */ 0.2 ,
820+ /* block_size = */ 2 ,
821+ /* stride = */ 2 ,
822+
823+ // clang-format off
824+ {
825+ {{0 , 0 }, {1 , 0 }, {1 , 1 }, {2 , 0 }, {2 , 2 }, {3 , 0 }, {3 , 3 }},
826+ {{0 , 0 }, {1 , 0 }, {1 , 1 }, {2 , 0 }, {2 , 2 }, {3 , 0 }, {3 , 3 }}
827+ }
828+ // clang-format on
749829 }};
750830
831+
751832TEST_P (XAttentionE2EBlockSelectTest, SelectsBlocksCorrectlyFromQKData) {
752833 auto test_struct = GetParam ();
753834 ov::reference::XAttentionBlockSelector<double > selector (test_struct.threshold ,
@@ -762,8 +843,8 @@ TEST_P(XAttentionE2EBlockSelectTest, SelectsBlocksCorrectlyFromQKData) {
762843 ASSERT_EQ (test_result.size (), test_struct.ref_retained_block_indices .size ());
763844 EXPECT_EQ (test_result, test_struct.ref_retained_block_indices );
764845 for (size_t head_idx = 0 ; head_idx < test_result.size (); head_idx++) {
765- std::cout << " Head " << head_idx << std::endl;
766846 if (test_result != test_struct.ref_retained_block_indices ) {
847+ std::cout << " Head " << head_idx << std::endl;
767848 const auto & ref_set = test_struct.ref_retained_block_indices [head_idx];
768849 const auto & test_set = test_result[head_idx];
769850 std::cout << " ref has " << ref_set.size () << " elements, test has " << test_set.size () << std::endl;
0 commit comments