
Commit 51840cb

optimize sparse softmax mask in sparse mask attention: extend attn_softmax_kernel(float) to support sparse mask

1 parent: 48dc5a3

2 files changed: +187 additions, -66 deletions

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp (27 additions, 25 deletions)
@@ -846,6 +846,8 @@ struct MHAHelper {
              [](size_t q_blk_rt, size_t k_blk_rt) {
                  return std::pair<size_t, size_t>{q_blk_rt, k_blk_rt};
              };
+        // Sparse attention mask pointer for current softmax kernel processing
+        uint8_t* xattn_mask = nullptr;
         if (!sparse_attention_mask.empty()) {
             sparse_scale = (_sparse_mask_block_size == 0 || _sparse_mask_block_size == _block_size)
                                ? 1
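Editor's note on sparse_scale (not part of the commit): both call sites below index the mask row with q_blk / sparse_scale, which suggests sparse_scale is the ratio between the mask's block granularity and the runtime _block_size, falling back to 1 when the two coincide. A minimal sketch of that mapping, with the helper name map_runtime_to_mask_blk invented here for illustration:

// Hypothetical illustration of the sparse_scale index mapping (not the commit's code).
// Assumes the mask block size is a multiple of the runtime block size.
#include <cassert>
#include <cstddef>

std::size_t map_runtime_to_mask_blk(std::size_t q_blk_rt, std::size_t sparse_scale) {
    // One mask block covers `sparse_scale` runtime blocks, so integer
    // division collapses runtime block indices onto mask block indices.
    return q_blk_rt / sparse_scale;
}

int main() {
    // e.g. mask block size 128, runtime block size 32 -> sparse_scale = 4:
    // runtime blocks 0..3 all read mask block 0, blocks 4..7 read block 1.
    assert(map_runtime_to_mask_blk(0, 4) == 0);
    assert(map_runtime_to_mask_blk(3, 4) == 0);
    assert(map_runtime_to_mask_blk(4, 4) == 1);
    return 0;
}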
@@ -912,25 +914,6 @@ struct MHAHelper {
             }
         }

-        // Instead of writing -inf directly into scores, build a softmax mask (0/-inf) and pass it to the kernel
-        DATA_TYPE* softmax_mask = nullptr;
-        std::vector<DATA_TYPE> softmax_mask_storage;
-        if (!sparse_attention_mask.empty()) {
-            const size_t padded_len = rnd_up(cur_kv_len, _block_size);
-            softmax_mask_storage.resize(padded_len);
-            // Initialize to -inf by default; then set positions for allowed blocks to 0
-            const DATA_TYPE neg_inf_val = static_cast<DATA_TYPE>(-std::numeric_limits<float>::infinity());
-            std::fill(softmax_mask_storage.begin(), softmax_mask_storage.end(), neg_inf_val);
-            for (size_t k = 0; k < cur_kv_len; ++k) {
-                size_t k_blk = k / _block_size;
-                auto [q_m, k_m] = map_to_mask_idx(q_blk, k_blk);
-                if (sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_m, k_m)[0]) {
-                    softmax_mask_storage[k] = static_cast<DATA_TYPE>(0);
-                }
-            }
-            softmax_mask = softmax_mask_storage.data();
-        }
-
         for (size_t m = q_start; m < q_end; m++) {
             // apply attention mask & sofmax
             auto ncausal = (cur_kv_len - q_cnt + (m - q_start) + 1);
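The hunk above deletes the per-block mask materialization this commit optimizes away: the old path filled a dense float buffer of rnd_up(cur_kv_len, _block_size) entries with -inf and then zeroed the positions belonging to allowed blocks. A standalone sketch of that pattern (simplified names; an illustration, not the deleted code verbatim):

// Sketch of the pre-commit approach: expand a block-level boolean mask
// into a dense per-key additive mask (0 for allowed keys, -inf for masked).
#include <cstddef>
#include <limits>
#include <vector>

std::vector<float> expand_block_mask(const std::vector<bool>& block_mask,
                                     std::size_t kv_len,
                                     std::size_t block_size) {
    // O(kv_len) writes per query block; the commit replaces this with
    // O(kv_len / block_size) byte lookups inside the softmax kernel.
    std::vector<float> dense(kv_len, -std::numeric_limits<float>::infinity());
    for (std::size_t k = 0; k < kv_len; ++k) {
        if (block_mask[k / block_size]) {
            dense[k] = 0.0f;  // allowed key: additive mask contributes nothing
        }
    }
    return dense;
}

With the commit applied, the kernel receives the block-level uint8_t mask directly, so both the expansion loop and the temporary buffer disappear from the hot path.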
@@ -948,17 +931,30 @@ struct MHAHelper {
                     start_idx = ncausal - _sliding_window;
                     new_causal = _sliding_window;
                 }
+
+                // Handle sparse attention mask for sliding window
+                if (!sparse_attention_mask.empty()) {
+                    // Get the original xattn_mask and calculate offset
+                    auto* original_mask = reinterpret_cast<uint8_t*>(
+                        sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk / sparse_scale));
+                    size_t mask_start_offset = start_idx / _sparse_mask_block_size;
+                    xattn_mask = original_mask + mask_start_offset;
+                }
+
                 attn_softmax_kernel<float>(score + start_idx,
                                            reinterpret_cast<DATA_TYPE*>(score) + start_idx,
                                            revised_d_scale,
                                            alibi_lookup,
-                                           reinterpret_cast<void*>(softmax_mask + start_idx),
+                                           nullptr,
                                            nullptr,
                                            false,
                                            new_causal,
                                            rnd_up(cur_kv_len, _block_size) - start_idx,
                                            precision_of<DATA_TYPE>::value,
-                                           precision_of<DATA_TYPE>::value);
+                                           precision_of<DATA_TYPE>::value,
+                                           0.f,
+                                           xattn_mask,
+                                           _sparse_mask_block_size);

                 memset(score, 0, sizeof(DATA_TYPE) * start_idx);
             } else {
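The new trailing arguments (0.f, xattn_mask, _sparse_mask_block_size) hand the kernel a raw block mask; this diff does not show how attn_softmax_kernel consumes it. A plausible reading, as a hedged sketch (apply_block_mask is a hypothetical name, assuming one mask byte covers _sparse_mask_block_size keys and nonzero means "attend"):

// Hypothetical consumption of the block mask inside a softmax kernel:
// one byte per sparse_mask_block_size keys; zero bytes suppress their keys.
#include <cstddef>
#include <cstdint>
#include <limits>

void apply_block_mask(float* scores,
                      std::size_t len,
                      const std::uint8_t* xattn_mask,
                      std::size_t sparse_mask_block_size) {
    if (xattn_mask == nullptr || sparse_mask_block_size == 0) {
        return;  // dense path: no sparse mask supplied
    }
    for (std::size_t k = 0; k < len; ++k) {
        if (!xattn_mask[k / sparse_mask_block_size]) {
            scores[k] = -std::numeric_limits<float>::infinity();
        }
    }
}

Note that the sliding-window caller pre-advances the mask pointer by start_idx / _sparse_mask_block_size, so element 0 of the sliced score buffer lines up with the mask byte for key start_idx; the alignment is exact when start_idx is a multiple of _sparse_mask_block_size.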
@@ -969,18 +965,24 @@ struct MHAHelper {
                     alibi_slope = alibi_slopes.ptr<float>()[h];
                     alibi_lookup = _alibi_lookup.ptr<float>() + _alibi_lookup.m_dims[0] - ncausal;
                 }
+                xattn_mask = sparse_attention_mask.empty()
+                                 ? nullptr
+                                 : reinterpret_cast<uint8_t*>(
+                                       sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk / sparse_scale));
                 attn_softmax_kernel<float>(score,
                                            reinterpret_cast<DATA_TYPE*>(score),
                                            revised_d_scale,
                                            alibi_lookup,
-                                           reinterpret_cast<void*>(softmax_mask),
+                                           nullptr,
                                            nullptr,
                                            false,
                                            ncausal,
                                            rnd_up(cur_kv_len, _block_size),
                                            precision_of<DATA_TYPE>::value,
                                            precision_of<DATA_TYPE>::value,
-                                           alibi_slope);
+                                           alibi_slope,
+                                           xattn_mask,
+                                           _sparse_mask_block_size);
             }
             if (score_output && m >= q_start_idx_score) {
                 auto* score_block_ptr =
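A side note on the reinterpret_cast<uint8_t*> used at both call sites: the mask tensor stores bool, and viewing it as bytes relies on sizeof(bool) == 1, which holds on the x86-64 targets this code path guards for. A minimal sketch making that assumption explicit (as_byte_mask is a name invented here):

// Viewing a 1-byte bool row as the uint8_t layout the kernel expects.
#include <cstdint>

static_assert(sizeof(bool) == 1, "byte-mask view relies on a 1-byte bool");

inline const std::uint8_t* as_byte_mask(const bool* row) {
    // Each mask element occupies one byte holding 0 or 1.
    return reinterpret_cast<const std::uint8_t*>(row);
}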
@@ -2238,15 +2240,15 @@ struct AttentionExecutor : public PagedAttentionExecutor {
         // TODO: support multiple batches
         for (size_t seq_idx = 0; seq_idx < 1; seq_idx++) {
             if (q.size(0) > 1) {
-#if defined(OPENVINO_ARCH_X86_64)
+#    if defined(OPENVINO_ARCH_X86_64)
                 masks[seq_idx] = xattn_estimate(q,
                                                 k,
                                                 x_attention_block_size,
                                                 x_attention_stride,
                                                 1,
                                                 threshold.ptr<float>()[seq_idx],
                                                 true);
-#endif
+#    endif
             }
         }
         return masks;
