Skip to content

Commit 9fa3c12

Browse files
committed
Fix the rest of the issues
1 parent e5f1f7b commit 9fa3c12

File tree

2 files changed

+108
-18
lines changed

2 files changed

+108
-18
lines changed

src/core/reference/include/openvino/reference/xattention.hpp

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ class XAttentionBlockSelector {
144144
}
145145
}
146146

147+
/** Applies the softmax causal mask along the last two dimensions of the rank-3 input tensor in-place.
148+
* @param in_out_data Pointer to the softmax input values (logits).
149+
* @param in_out_shape Shape of the input tensor. Expected shape is [num_heads, num_query_tokens /
150+
* stride, num_key_tokens / stride].
151+
*/
147152
void apply_causal_mask_(T* in_out_data, const Shape& in_out_shape) {
148153
OPENVINO_ASSERT(in_out_shape.size() == 3);
149154
OPENVINO_ASSERT(in_out_shape[1] <= in_out_shape[2]);
@@ -153,7 +158,8 @@ class XAttentionBlockSelector {
153158
size_t head_offset = head_idx * in_out_shape[1] * in_out_shape[2];
154159
for (size_t query_dim_idx = 0; query_dim_idx < in_out_shape[1]; query_dim_idx++) {
155160
size_t query_dim_offset = query_dim_idx * in_out_shape[2];
156-
for (size_t key_dim_idx = key_dim - query_dim + query_dim_idx + 1; key_dim_idx < key_dim; key_dim_idx++) {
161+
for (size_t key_dim_idx = key_dim - query_dim + query_dim_idx + 1; key_dim_idx < key_dim;
162+
key_dim_idx++) {
157163
in_out_data[head_offset + query_dim_offset + key_dim_idx] = -INFINITY;
158164
}
159165
}
@@ -222,11 +228,11 @@ class XAttentionBlockSelector {
222228
}
223229

224230
/** Selects the elements of the input tensor along the last dimension, independently along the first two dimensions,
225-
* so that the elements constitute a smallest subset amounting to a sum portion no less than `threshold` of the
226-
* element sum. The last two dimensions are treated as the query-block and key-block dimensions in the context
227-
* of attention matrix scores, and the first-in-row, the "diagonal" and "non-causal" elements are
228-
* disregarded when calculating the sum. "Non-causal" elements are never preserved, while "diagonal" and
229-
* first-in-row elements are always preserved.
231+
* so that the selected elements constitute a smallest subset amounting to a sum portion no less than `threshold`
232+
* of the total "causal" element sum. "Causal" is understood in the sense of the last two dimensions being
233+
* treated as the query-block and key-block dimensions in the context of attention matrix scores. The
234+
* first-in-row, the "diagonal" and "non-causal" elements are disregarded when calculating the sum. "Non-causal"
235+
* elements are never preserved, while "diagonal" and first-in-row elements are always preserved.
230236
* @param blocked_scores_data Pointer to the blocked score input.
231237
* @param blocked_attention_scores_shape Shape of the blocked score input tensor. Expected shape is [num_heads,
232238
* num_query_tokens / block_size, num_key_tokens / block_size]
@@ -256,27 +262,30 @@ class XAttentionBlockSelector {
256262
for (size_t q_block_idx = 0; q_block_idx < blocked_attention_scores_shape[1]; q_block_idx++) {
257263
std::priority_queue<IndexAndScore> indices_and_scores_queue;
258264
double total_sum = 0.0;
265+
double cumsum = 0.0;
259266
for (size_t k_block_idx = 0; k_block_idx < blocked_attention_scores_shape[2]; k_block_idx++) {
260267
if (k_block_idx >
261268
(blocked_attention_scores_shape[2] - blocked_attention_scores_shape[1] + q_block_idx)) {
262269
// Disregard non-causal blocks entirely
263270
continue;
264271
}
272+
size_t target_offset = head_offset + blocked_attention_scores_shape[2] * q_block_idx + k_block_idx;
273+
T current_score = *(blocked_attention_scores_data + target_offset);
274+
total_sum += current_score;
275+
265276
if ((k_block_idx ==
266277
(blocked_attention_scores_shape[2] - blocked_attention_scores_shape[1] + q_block_idx)) ||
267278
k_block_idx == 0) {
268-
// We preserve first-in-row and diagonal blocks always, and do not include their score in the
269-
// cumulative sum, i.e. we only preserve the fraction of the non-diagonal blocks' attention mass
279+
// We preserve first-in-row and diagonal blocks always, and include their score in the
280+
// cumulative sum. The target for the rest of the blocks in row is to fill up the
281+
// rest of the attention mass fraction so that with the diagonal and first blocks they
282+
// comprise the `threshold` portion of the entire causal attention mass in this row
270283
retval[head_idx].insert({q_block_idx, k_block_idx});
284+
cumsum += current_score;
271285
} else {
272-
size_t target_offset =
273-
head_offset + blocked_attention_scores_shape[2] * q_block_idx + k_block_idx;
274-
T current_score = *(blocked_attention_scores_data + target_offset);
275-
total_sum += current_score;
276286
indices_and_scores_queue.push({{q_block_idx, k_block_idx}, current_score});
277287
}
278288
}
279-
double cumsum = 0.0;
280289
double required_sum = m_threshold * total_sum;
281290
while (cumsum < required_sum && !indices_and_scores_queue.empty()) {
282291
auto index_and_largest_score = indices_and_scores_queue.top();

src/core/tests/reference/xattention.cpp

Lines changed: 86 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -639,10 +639,10 @@ std::vector<double> E2E_Q_DATA_8 = {
639639
ov::Shape E2E_K_SHAPE_8 = {2, 8, 2};
640640
std::vector<double> E2E_K_DATA_8 = {
641641
// clang-format off
642-
-1.2870, -1.2179, 0.0316, 0.0080, -0.6171, 1.0622, 0.3085, -0.7751,
643-
-1.3612, 0.9485, -0.0803, 0.5752, 0.1925, -0.1113, 1.4693, 0.0673,
644-
0.7422, 0.7149, -1.7684, -0.0651, -0.1925, -1.4169, 1.0030, -0.8091,
645-
-0.7934, 0.5160, -0.2543, 0.1729, -0.0687, -1.4245, 0.0758, 1.1613
642+
0.2980, 0.4959, -0.0834, 0.7015, 1.2516, 0.6656, -2.7873, 1.9731,
643+
-0.4817, 1.1117, -0.8096, -0.5397, -1.0528, 0.2869, -1.1274, 1.4849,
644+
-0.2468, -1.0449, -1.0085, -0.3389, 0.6750, 0.9095, 0.4674, 2.2321,
645+
1.3183, -0.3513, -0.3717, 0.0176, -0.2545, -0.6729, -1.1547, 0.0279
646646
// clang-format on
647647
};
648648

@@ -746,8 +746,89 @@ std::vector<E2EBlockSelectTestData> E2E_BLOCK_SELECT_TEST_CASES = {
746746
{{0, 0}, {0, 2}, {0, 4}, {1, 0}, {1, 1}, {1, 3}, {1, 5}, {2, 0}, {2, 1}, {2, 2}, {2, 3}, {2, 4}, {2, 6}, {3, 0}, {3, 1}, {3, 4}, {3, 5}, {3, 6}, {3, 7}}
747747
}
748748
// clang-format on
749+
},
750+
{
751+
E2E_Q_SHAPE_8,
752+
E2E_Q_DATA_8,
753+
E2E_K_SHAPE_16,
754+
E2E_K_DATA_16,
755+
/* threshold = */ 0.45,
756+
/* block_size = */ 2,
757+
/* stride = */ 2,
758+
759+
// clang-format off
760+
{
761+
{{0, 0}, {0, 4}, {1, 0}, {1, 5}, {2, 0}, {2, 1}, {2, 3}, {2, 6}, {3, 0}, {3, 2}, {3, 5}, {3, 7}},
762+
{{0, 0}, {0, 2}, {0, 4}, {1, 0}, {1, 5}, {2, 0}, {2, 4}, {2, 6}, {3, 0}, {3, 5}, {3, 7}}
763+
}
764+
// clang-format on
765+
},
766+
{
767+
E2E_Q_SHAPE_8,
768+
E2E_Q_DATA_8,
769+
E2E_K_SHAPE_16,
770+
E2E_K_DATA_16,
771+
/* threshold = */ 0.45,
772+
/* block_size = */ 4,
773+
/* stride = */ 2,
774+
775+
// clang-format off
776+
{
777+
{{0, 0}, {0, 2}, {1, 0}, {1, 1}, {1, 3}},
778+
{{0, 0}, {0, 2}, {1, 0}, {1, 3}}
779+
}
780+
// clang-format on
781+
},
782+
{
783+
E2E_Q_SHAPE_8,
784+
E2E_Q_DATA_8,
785+
E2E_K_SHAPE_16,
786+
E2E_K_DATA_16,
787+
/* threshold = */ 0.45,
788+
/* block_size = */ 4,
789+
/* stride = */ 4,
790+
791+
// clang-format off
792+
{
793+
{{0, 0}, {0, 2}, {1, 0}, {1, 3}},
794+
{{0, 0}, {0, 2}, {1, 0}, {1, 3}}
795+
}
796+
// clang-format on
797+
},
798+
{
799+
E2E_Q_SHAPE_8,
800+
E2E_Q_DATA_8,
801+
E2E_K_SHAPE_8,
802+
E2E_K_DATA_8,
803+
/* threshold = */ 0.5,
804+
/* block_size = */ 2,
805+
/* stride = */ 2,
806+
807+
// clang-format off
808+
{
809+
{{0, 0}, {1, 0}, {1, 1}, {2, 0}, {2, 1}, {2, 2}, {3, 0}, {3, 1}, {3, 3}},
810+
{{0, 0}, {1, 0}, {1, 1}, {2, 0}, {2, 2}, {3, 0}, {3, 3}}
811+
}
812+
// clang-format on
813+
},
814+
{
815+
E2E_Q_SHAPE_8,
816+
E2E_Q_DATA_8,
817+
E2E_K_SHAPE_8,
818+
E2E_K_DATA_8,
819+
/* threshold = */ 0.2,
820+
/* block_size = */ 2,
821+
/* stride = */ 2,
822+
823+
// clang-format off
824+
{
825+
{{0, 0}, {1, 0}, {1, 1}, {2, 0}, {2, 2}, {3, 0}, {3, 3}},
826+
{{0, 0}, {1, 0}, {1, 1}, {2, 0}, {2, 2}, {3, 0}, {3, 3}}
827+
}
828+
// clang-format on
749829
}};
750830

831+
751832
TEST_P(XAttentionE2EBlockSelectTest, SelectsBlocksCorrectlyFromQKData) {
752833
auto test_struct = GetParam();
753834
ov::reference::XAttentionBlockSelector<double> selector(test_struct.threshold,
@@ -762,8 +843,8 @@ TEST_P(XAttentionE2EBlockSelectTest, SelectsBlocksCorrectlyFromQKData) {
762843
ASSERT_EQ(test_result.size(), test_struct.ref_retained_block_indices.size());
763844
EXPECT_EQ(test_result, test_struct.ref_retained_block_indices);
764845
for (size_t head_idx = 0; head_idx < test_result.size(); head_idx++) {
765-
std::cout << "Head " << head_idx << std::endl;
766846
if (test_result != test_struct.ref_retained_block_indices) {
847+
std::cout << "Head " << head_idx << std::endl;
767848
const auto& ref_set = test_struct.ref_retained_block_indices[head_idx];
768849
const auto& test_set = test_result[head_idx];
769850
std::cout << "ref has " << ref_set.size() << " elements, test has " << test_set.size() << std::endl;

0 commit comments

Comments
 (0)