@@ -846,6 +846,8 @@ struct MHAHelper {
 [](size_t q_blk_rt, size_t k_blk_rt) {
     return std::pair<size_t, size_t>{q_blk_rt, k_blk_rt};
 };
+// Sparse attention mask pointer for current softmax kernel processing
+uint8_t* xattn_mask = nullptr;
 if (!sparse_attention_mask.empty()) {
     sparse_scale = (_sparse_mask_block_size == 0 || _sparse_mask_block_size == _block_size)
                        ? 1
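The `sparse_scale` computed above, together with the `map_to_mask_idx` lambda, exists because the block-sparse mask may have been estimated at a coarser block size than the KV-cache block size, so several runtime blocks can share a single mask entry. The sketch below only illustrates that mapping, under the assumption that `sparse_scale` is the ratio of the mask block size to the cache block size; the else branch of the expression above falls outside this hunk, so the ratio and the helper names here are illustrative, not the actual code:

```cpp
#include <cstddef>
#include <utility>

// Hypothetical sketch: map runtime (q_blk, k_blk) block indices onto the
// coordinates of a coarser block-sparse attention mask.
// Assumption: sparse_scale = mask_block_size / cache_block_size when the two
// differ, and 1 when the mask block size is 0 or equal to the cache's.
inline size_t compute_sparse_scale(size_t mask_block_size, size_t cache_block_size) {
    return (mask_block_size == 0 || mask_block_size == cache_block_size)
               ? 1
               : mask_block_size / cache_block_size;  // assumed else branch
}

inline std::pair<size_t, size_t> to_mask_idx(size_t q_blk, size_t k_blk, size_t sparse_scale) {
    return {q_blk / sparse_scale, k_blk / sparse_scale};
}
```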
@@ -912,25 +914,6 @@ struct MHAHelper {
     }
 }
 
-// Instead of writing -inf directly into scores, build a softmax mask (0/-inf) and pass it to the kernel
-DATA_TYPE* softmax_mask = nullptr;
-std::vector<DATA_TYPE> softmax_mask_storage;
-if (!sparse_attention_mask.empty()) {
-    const size_t padded_len = rnd_up(cur_kv_len, _block_size);
-    softmax_mask_storage.resize(padded_len);
-    // Initialize to -inf by default; then set positions for allowed blocks to 0
-    const DATA_TYPE neg_inf_val = static_cast<DATA_TYPE>(-std::numeric_limits<float>::infinity());
-    std::fill(softmax_mask_storage.begin(), softmax_mask_storage.end(), neg_inf_val);
-    for (size_t k = 0; k < cur_kv_len; ++k) {
-        size_t k_blk = k / _block_size;
-        auto [q_m, k_m] = map_to_mask_idx(q_blk, k_blk);
-        if (sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_m, k_m)[0]) {
-            softmax_mask_storage[k] = static_cast<DATA_TYPE>(0);
-        }
-    }
-    softmax_mask = softmax_mask_storage.data();
-}
-
 for (size_t m = q_start; m < q_end; m++) {
     // apply attention mask & softmax
     auto ncausal = (cur_kv_len - q_cnt + (m - q_start) + 1);
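For context, the block removed above materialized the block-level boolean mask into a dense, padded per-position vector of 0/-inf values on every call and handed that vector to the softmax kernel; the new code passes the block-level mask to the kernel directly. A standalone sketch of the removed expansion, using plain `float` and illustrative names in place of `DATA_TYPE` and the tensor accessors:

```cpp
#include <cstddef>
#include <limits>
#include <vector>

// Sketch of the removed approach: turn a per-block keep flag into a dense
// per-position additive softmax mask (0 = keep, -inf = drop), padded to a
// whole number of blocks. Names and types are illustrative.
std::vector<float> expand_block_mask(const std::vector<bool>& keep_block,
                                     size_t kv_len,
                                     size_t block_size) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    const size_t padded_len = (kv_len + block_size - 1) / block_size * block_size;
    std::vector<float> mask(padded_len, neg_inf);  // default: everything masked
    for (size_t k = 0; k < kv_len; ++k) {
        if (keep_block[k / block_size]) {
            mask[k] = 0.0f;  // allowed key contributes normally to the softmax
        }
    }
    return mask;
}
```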
@@ -948,17 +931,30 @@ struct MHAHelper {
         start_idx = ncausal - _sliding_window;
         new_causal = _sliding_window;
     }
+
+    // Handle sparse attention mask for sliding window
+    if (!sparse_attention_mask.empty()) {
+        // Get the original xattn_mask and calculate offset
+        auto* original_mask = reinterpret_cast<uint8_t*>(
+            sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk / sparse_scale));
+        size_t mask_start_offset = start_idx / _sparse_mask_block_size;
+        xattn_mask = original_mask + mask_start_offset;
+    }
+
     attn_softmax_kernel<float>(score + start_idx,
                                reinterpret_cast<DATA_TYPE*>(score) + start_idx,
                                revised_d_scale,
                                alibi_lookup,
-                               reinterpret_cast<void*>(softmax_mask + start_idx),
+                               nullptr,
                                nullptr,
                                false,
                                new_causal,
                                rnd_up(cur_kv_len, _block_size) - start_idx,
                                precision_of<DATA_TYPE>::value,
-                               precision_of<DATA_TYPE>::value);
+                               precision_of<DATA_TYPE>::value,
+                               0.f,
+                               xattn_mask,
+                               _sparse_mask_block_size);
 
     memset(score, 0, sizeof(DATA_TYPE) * start_idx);
 } else {
@@ -969,18 +965,24 @@ struct MHAHelper {
         alibi_slope = alibi_slopes.ptr<float>()[h];
         alibi_lookup = _alibi_lookup.ptr<float>() + _alibi_lookup.m_dims[0] - ncausal;
     }
+    xattn_mask = sparse_attention_mask.empty()
+                     ? nullptr
+                     : reinterpret_cast<uint8_t*>(
+                           sparse_attention_mask[batch_in_seq].ptr<bool>(h, q_blk / sparse_scale));
     attn_softmax_kernel<float>(score,
                                reinterpret_cast<DATA_TYPE*>(score),
                                revised_d_scale,
                                alibi_lookup,
-                               reinterpret_cast<void*>(softmax_mask),
+                               nullptr,
                                nullptr,
                                false,
                                ncausal,
                                rnd_up(cur_kv_len, _block_size),
                                precision_of<DATA_TYPE>::value,
                                precision_of<DATA_TYPE>::value,
-                               alibi_slope);
+                               alibi_slope,
+                               xattn_mask,
+                               _sparse_mask_block_size);
 }
 if (score_output && m >= q_start_idx_score) {
     auto* score_block_ptr =
@@ -2238,15 +2240,15 @@ struct AttentionExecutor : public PagedAttentionExecutor {
 // TODO: support multiple batches
 for (size_t seq_idx = 0; seq_idx < 1; seq_idx++) {
     if (q.size(0) > 1) {
-# if defined(OPENVINO_ARCH_X86_64)
+# if defined(OPENVINO_ARCH_X86_64)
         masks[seq_idx] = xattn_estimate(q,
                                         k,
                                         x_attention_block_size,
                                         x_attention_stride,
                                         1,
                                         threshold.ptr<float>()[seq_idx],
                                         true);
-# endif
+# endif
     }
 }
 return masks;