
Conversation

@yael-works
Contributor

New Attention Mechanism: SparseK Attention (CPU Backend)

This PR introduces a new attention mechanism called SparseK Attention, implemented from scratch as a new operator within the GGML framework, currently with CPU backend support.


Overview

SparseK Attention is a selective, efficient attention mechanism inspired by Flash Attention that introduces additional sparsity through the following mechanisms (a small sketch of the resulting pattern follows the list):

  • Top-K filtering – keeps only the strongest attention weights.
  • Local windowing – limits attention to a configurable local context.
  • Global stride – adds periodic global connections between tokens.
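For intuition, here is a minimal sketch in plain C++ (not the PR's code) of the structural part of such a pattern for a causal decoder: a key position j stays visible to query position i if it falls inside the local window or lands on the global stride, while the Top-K filter is score-dependent and is applied separately. The function name and parameters are illustrative.

#include <cstdint>

static bool sparsek_structural_keep(int64_t i, int64_t j,
                                    int64_t win_local, int64_t stride_global) {
    if (j > i) {
        return false;                       // causal: never attend to future tokens
    }
    if (i - j < win_local) {
        return true;                        // inside the local window around the query
    }
    if (stride_global > 0 && j % stride_global == 0) {
        return true;                        // periodic global connection
    }
    return false;                           // not structurally kept; may still survive the score-based Top-K filter
}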

Implementation Details

  • Added new operator: GGML_OP_SPARSEK_ATTN defined in ggml.h and ggml.c.
  • Implemented construction function ggml_sparsek_attn() that creates a computation node with parameters (k_top, win_local, stride_global).
  • Added full CPU backend implementation in:
    • ggml-cpu/ops.h
    • ggml-cpu/ops.cpp
    • ggml-cpu.c

The CPU version includes the following steps (a rough sketch in plain C++ follows the list):

  • Scaled dot-product computation QKᵀ / √d
  • Dynamic Top-K filtering
  • Softmax normalization
  • Multiplication with V
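As a rough sketch of these steps in plain C++ (this is not the ggml-cpu kernel; the names, memory layout, and threshold-based Top-K are simplifications for illustration):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// One query row of sparse attention: scores = q·K^T / sqrt(d), keep the k_top
// largest scores (ties at the threshold may keep a few extra), softmax over the
// survivors, then a weighted sum of the corresponding V rows.
// Assumes row-major K and V of shape [n_kv, d] and 1 <= k_top <= n_kv.
static void sparsek_attn_row(const float * q, const float * K, const float * V,
                             float * out, size_t n_kv, size_t d, size_t k_top) {
    std::vector<float> s(n_kv);
    const float scale = 1.0f/std::sqrt((float) d);
    for (size_t j = 0; j < n_kv; ++j) {
        float dot = 0.0f;
        for (size_t c = 0; c < d; ++c) {
            dot += q[c]*K[j*d + c];
        }
        s[j] = dot*scale;                       // scaled dot product
    }

    // Top-K threshold: the k_top-th largest score
    std::vector<float> tmp(s);
    std::nth_element(tmp.begin(), tmp.begin() + (k_top - 1), tmp.end(), std::greater<float>());
    const float thresh = tmp[k_top - 1];

    // softmax over the surviving scores only
    float maxv = -INFINITY;
    for (size_t j = 0; j < n_kv; ++j) {
        if (s[j] >= thresh) {
            maxv = std::max(maxv, s[j]);
        }
    }
    std::vector<float> w(n_kv, 0.0f);
    float sum = 0.0f;
    for (size_t j = 0; j < n_kv; ++j) {
        if (s[j] >= thresh) {
            w[j]  = std::exp(s[j] - maxv);
            sum  += w[j];
        }
    }

    // weighted sum of V rows
    for (size_t c = 0; c < d; ++c) {
        out[c] = 0.0f;
    }
    for (size_t j = 0; j < n_kv; ++j) {
        if (w[j] == 0.0f) {
            continue;
        }
        const float wj = w[j]/sum;
        for (size_t c = 0; c < d; ++c) {
            out[c] += wj*V[j*d + c];
        }
    }
}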

Next Steps

Our next goal is to extend SparseK Attention to the SYCL (GPU) backend in order to:

  • Measure and compare performance between CPU and GPU implementations.
  • Optimize kernel execution for sparse attention patterns.
  • Validate correctness and scaling on Intel GPUs.

We are submitting this initial CPU implementation first to ensure review, integration, and baseline correctness before introducing GPU acceleration.


Co-Authors

Co-authored-by: Yael Shuker ([email protected])
Co-authored-by: Gitty Burstein ([email protected])

@GittyBurstein
Contributor

GittyBurstein commented Oct 28, 2025

Hi @CISC and @NeoZhangJianyu,

We’d appreciate it if you could review our PR implementing the new SPARSEK Attention operator.
We ran internal validation tests we created ourselves, and all passed successfully.

This contribution was developed jointly by both of us (@yael-works and @GittyBurstein ).
Please make sure the PR reflects both contributors — if needed, we can adjust the commit authors accordingly.

Thanks in advance for your time and feedback!

@CISC
Collaborator

CISC commented Oct 28, 2025

We are talking about this SparseK, right?

@yael-works
Contributor Author

yael-works commented Oct 28, 2025

yes! @CISC

@github-actions github-actions bot added the testing (Everything test related) and ggml (changes relating to the ggml tensor library for machine learning) labels on Oct 28, 2025
@CISC
Collaborator

CISC commented Oct 30, 2025

You need to rebase to fix the Server CI failures, and please also fix the whitespace issues:
https://github.com/ggml-org/llama.cpp/actions/runs/18935125175/job/54060021809

@GittyBurstein
Contributor

Hi @CISC,
Just to clarify — the failing tests are unrelated to my changes.
This PR only introduces the new SPARSEK Attention operator within GGML and doesn’t modify any existing server or inference logic.

I’d really appreciate it if you could review the code itself so we can move forward with the merge —
all SPARSEK-related tests are passing successfully.

Thanks!

@CISC
Collaborator

CISC commented Oct 31, 2025

> Hi @CISC, Just to clarify — the failing tests are unrelated to my changes. This PR only introduces the new SPARSEK Attention operator within GGML and doesn’t modify any existing server or inference logic.

Yes, as mentioned, will be resolved if you rebase, it's ok. :)

> I’d really appreciate it if you could review the code itself so we can move forward with the merge — all SPARSEK-related tests are passing successfully.

So, my main challenge is where/what/when will SparseK be used? I can't recall seeing any actual implementation being used in the wild. This also means we don't really have any reference to test it against...

@GittyBurstein
Contributor

GittyBurstein commented Oct 31, 2025

@CISC
The current PR focuses solely on adding the SparseK Attention operator at the GGML level (CPU backend).
At this stage, it isn’t directly integrated into the model’s runtime pipeline — it’s designed as a standalone operator for experimentation and future extensions.

Once this PR is merged, the operator can be connected to higher-level use cases such as:

  • selective attention mechanisms for long-context models,

  • experimental low-latency or memory-efficient inference,

  • or research benchmarking against variants like Flash Attention or block-sparse implementations.

Do you have any other ideas that could demonstrate or validate this even better?

Thank you!!

@CISC
Collaborator

CISC commented Oct 31, 2025

I think @ggerganov will have to weigh in on this.

@ggerganov
Member

Sparse attention implementations such as DSA and SparseK should leverage the existing FA implementations and mask filtering logic. No need to introduce new operators and duplicate all the existing work that already went into optimizing FA.

@yael-works yael-works force-pushed the feature/sparsek-attn-sycl branch from 77f4088 to 22c063e on November 2, 2025 09:53
@yael-works
Contributor Author

Hi @ggerganov and @CISC,
The branch has been successfully rebased on the latest master.
All SparseK Attention tests are passing, and the PR is ready for final review and merge.
Thanks for the feedback and support!
— Yael & Gitty

@yael-works yael-works force-pushed the feature/sparsek-attn-sycl branch from 16d7eee to 556ab36 on November 3, 2025 09:21
@yael-works
Contributor Author

Hi @ggerganov and @CISC,
Following @ggerganov’s feedback, we refactored SparseK to reuse the existing FlashAttention logic rather than maintaining a separate operator.
The new design integrates SparseK’s sparsity mechanism (Top-K + local + stride) within the FlashAttention extension path.
This keeps the optimization benefits of FlashAttention while allowing selective sparse attention behavior — all tested and validated on CPU backend.

Member

@ggerganov ggerganov left a comment


My idea was more along the following lines:

  • Sparse attention implementations should somehow compute a sparse KQ mask. Depending on the specifics (e.g. local windows, top-k product, deepseek lightning stuff, etc.) this can be done in different ways, but generally it should require some extra logic when constructing the compute graph
  • Then we pass the sparse KQ mask (i.e. a normal mask but with extra -INF values where we don't have to compute the attention) to ggml_flash_attn_ext and we delegate the filtering logic to the backend implementation. For example, the Metal backend will already skip large amount of the filtered values depending on the KQ mask contents (#16372). Similar or better logic can be added to the other backend implementations.

I think at most, the only change to the existing ggml_flash_attn_ext API would be to provide a "mask hint" that would inform the backend what kind of mask to expect (causal, sparse, etc.). The rest of the changes should be at the compute graph level and at the backend implementation for filtering the -INF values. Let me know if this makes sense.
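As an illustration of the second bullet (a sketch only; the function name and layout are assumptions, not existing llama.cpp code): given a host-side mask row that the causal/SWA logic has already populated with 0 and -INF values, SparseK-style filtering would simply write extra -INF entries, and ggml_flash_attn_ext would receive the result unchanged.

#include <cmath>
#include <cstdint>

// mask_row: n_kv float entries for one query at position pos_q, already 0 / -INF.
// Adds extra -INF entries outside the local window and off the global stride.
static void sparsek_filter_mask_row(float * mask_row, int64_t n_kv, int64_t pos_q,
                                    int64_t win_local, int64_t stride_global) {
    for (int64_t j = 0; j < n_kv; ++j) {
        if (mask_row[j] == -INFINITY) {
            continue;                           // already masked (causal, SWA, other sequence)
        }
        const bool in_window = j <= pos_q && pos_q - j < win_local;
        const bool on_stride = stride_global > 0 && j % stride_global == 0;
        if (!in_window && !on_stride) {
            mask_row[j] = -INFINITY;            // extra sparsity on top of the base mask
        }
    }
}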

@GittyBurstein
Contributor

@ggerganov
Before we start implementing, we want to make sure we understand correctly —
We’re not creating a separate operator for SparseK at all, but instead just adding a mask that integrates with ggml_flash_attn_ext, right?

And if that’s the case, where exactly should the mask implementation be added — inside the compute graph logic, or only for testing (e.g., in test-backend-ops)?
thanks!
Yael & Gitty

@ggerganov
Member

> We’re not creating a separate operator for SparseK at all, but instead just adding a mask that integrates with ggml_flash_attn_ext, right?

In llama.cpp, the mask is already being created and passed to ggml_flash_attn_ext. Currently, we populate the mask outside of the compute graph because it is static - i.e. depends only on the token positions in the sequences:

void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
    const uint32_t n_tokens = ubatch->n_tokens;

    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
    float * data = (float *) dst->data;

    const int64_t n_kv     = dst->ne[0];
    const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch

    GGML_ASSERT(n_tokens%n_stream == 0);

    // n_tps == n_tokens_per_stream
    const int64_t n_tps     = n_tokens/n_stream;
    const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD);

    std::fill(data, data + ggml_nelements(dst), -INFINITY);

    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
    //   Causal mask:
    //      xxx-------
    //      xxxx------
    //      xxxxx-----
    //   Non-causal mask:
    //      xxxxx-----
    //      xxxxx-----
    //      xxxxx-----
    // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
    // TODO: optimize this section
    for (uint32_t h = 0; h < 1; ++h) {
        for (uint32_t s = 0; s < n_stream; ++s) {
            for (uint32_t ii = 0; ii < n_tps; ++ii) {
                const uint32_t i = s*n_tps + ii;

                const llama_seq_id seq_id = ubatch->seq_id[i][0];

                const auto & cells = v_cells[seq_to_stream[seq_id]];

                const llama_pos p1 = ubatch->pos[i];

                // for M-RoPE
                const bool is_2d = ubatch->is_pos_2d();

                const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
                const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens]   : 0;

                const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii);

                for (uint32_t j = 0; j < n_kv; ++j) {
                    if (cells.is_empty(j)) {
                        continue;
                    }

                    // mask the token if not the same sequence
                    if (!cells.seq_has(j, seq_id)) {
                        continue;
                    }

                    const llama_pos p0 = cells.pos_get(j);

                    // mask future tokens
                    if (causal_attn && p0 > p1) {
                        continue;
                    }

                    // M-RoPE causal mask
                    if (causal_attn && is_2d && p0 == p1) {
                        const auto & p0_ext = cells.ext_get(j);
                        if (p0_ext.is_2d_gt(p1_x, p1_y)) {
                            continue;
                        }
                    }

                    // apply SWA if any
                    if (is_masked_swa(p0, p1)) {
                        continue;
                    }

                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
                }
            }
        }
    }
}

I think that the sparse attention implementations should augment this static mask through some extra logic. This extra logic should be implemented for example in the llm_graph_context::build_attn methods. This specific logic could potentially require some new ggml operators, but in general it boils down to setting certain elements of the kq_mask tensor to -INF in some way.

From there, the FA implementations will deal with the provided mask in their own way (i.e. by skipping computations when possible).
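For reference, the kind of backend-side skipping mentioned above can be as simple as checking whether a whole KV block of the mask is -INF before doing any work for it (a sketch under the assumption of a blocked kernel; the helper is illustrative, not an existing backend function):

#include <cmath>
#include <cstdint>

// Returns true if mask_row[j0..j1) is entirely -INF, i.e. the block contributes
// nothing to the softmax and its Q*K, softmax and V accumulation can be skipped.
static bool kq_mask_block_all_inf(const float * mask_row, int64_t j0, int64_t j1) {
    for (int64_t j = j0; j < j1; ++j) {
        if (mask_row[j] != -INFINITY) {
            return false;
        }
    }
    return true;
}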

> And if that’s the case, where exactly should the mask implementation be added — inside the compute graph logic, or only for testing (e.g., in test-backend-ops)?

For testing, you can already take a look how we create KQ masks with blocks of -INF values here:

// generate an F16 mask where certain blocks are randomly masked with -INF value
static void init_tensor_kq_mask(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
    GGML_ASSERT(tensor->type == GGML_TYPE_F16);

    GGML_TENSOR_LOCALS( int32_t, ne, tensor, ne);

    std::vector<float>       data_f32(ne0*ne1*ne2*ne3);
    std::vector<ggml_fp16_t> data_f16(ne0*ne1*ne2*ne3);

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> dis(min, max);

    for (size_t i = 0; i < data_f32.size(); i++) {
        data_f32[i] = dis(gen);
    }

    // block size
    const int blck0 = 128;
    const int blck1 = 64;

    // number of INF blocks
    const int n_inf_blocks = 0.1*(ne0*ne1*ne2*ne3)/(blck0*blck1);

    for (int b = 0; b < n_inf_blocks; b++) {
        const int p3 = (rd() % ne3);
        const int p2 = (rd() % ne2);
        const int p1 = (rd() % ne1);
        const int p0 = (rd() % ne0);

        for (int i1 = 0; i1 < blck1 && p1 + i1 < ne1; i1++) {
            const int idx = p3*ne2*ne1*ne0 + p2*ne1*ne0 + (p1 + i1)*ne0 + p0;

            for (int i0 = 0; i0 < blck0 && p0 + i0 < ne0; i0++) {
                data_f32[idx + i0] = -INFINITY;
            }
        }
    }

    ggml_fp32_to_fp16_row(data_f32.data(), data_f16.data(), ne0*ne1*ne2*ne3);

    ggml_backend_tensor_set(tensor, data_f16.data(), 0, data_f16.size()*sizeof(ggml_fp16_t));
}

I imagine that we would need tests that create various sorts of sparse masks and simply run ggml_flash_attn_ext as we do now. And also additional tests as needed, depending on what new operators for constructing these sparse masks are introduced.
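A possible companion to init_tensor_kq_mask above, sketched here rather than taken from the PR (the helper name and the win/stride defaults are assumptions, it reuses the same includes as the function above, and it treats i1 as the query position and i0 as the key position as a simplification): instead of random -INF blocks, it writes a window-plus-stride pattern so ggml_flash_attn_ext can be exercised with a SparseK-shaped mask.

static void init_tensor_kq_mask_sparsek(ggml_tensor * tensor, int win_local = 64, int stride_global = 16) {
    GGML_ASSERT(tensor->type == GGML_TYPE_F16);

    GGML_TENSOR_LOCALS(int32_t, ne, tensor, ne);    // ne0 = n_kv, ne1 = padded number of queries

    std::vector<float>       data_f32(ne0*ne1*ne2*ne3, 0.0f);
    std::vector<ggml_fp16_t> data_f16(ne0*ne1*ne2*ne3);

    for (int i3 = 0; i3 < ne3; i3++) {
        for (int i2 = 0; i2 < ne2; i2++) {
            for (int i1 = 0; i1 < ne1; i1++) {          // query position
                for (int i0 = 0; i0 < ne0; i0++) {      // key position
                    const bool in_window = i0 <= i1 && i1 - i0 < win_local;
                    const bool on_stride = stride_global > 0 && i0 % stride_global == 0;
                    if (!in_window && !on_stride) {
                        data_f32[i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0] = -INFINITY;
                    }
                }
            }
        }
    }

    ggml_fp32_to_fp16_row(data_f32.data(), data_f16.data(), ne0*ne1*ne2*ne3);

    ggml_backend_tensor_set(tensor, data_f16.data(), 0, data_f16.size()*sizeof(ggml_fp16_t));
}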

yael-works and others added 2 commits November 13, 2025 17:40
…oader support

Includes only implementation files:
- llama-graph: dynamic SparseK mask builder + integration point
- llama-model: GGUF key loading for SparseK parameters
- llama-model-loader: template instantiations for bool keys
- llama-hparams: new SparseK fields
- convert_hf_to_gguf.py: emit SparseK keys in GGUF
Co-authored-by: Gitty Burstein <[email protected]>

Co-authored-by: Yael Shuker <[email protected]>
Comment on lines 798 to 802
if (n_head := self.find_hparam(["num_attention_heads", "n_head", "n_heads"], optional=True)) is not None:
self.gguf_writer.add_head_count(n_head)
logger.info(f"gguf: head count = {n_head}")

Collaborator


Restore this please.

Comment on lines 802 to 807
# === SparseK dynamic attention metadata ===
self.gguf_writer.add_key("llama.sparsek.enable", int(self.hparams.get("sparsek_enable", 0)))
self.gguf_writer.add_key("llama.sparsek.top_k", int(self.hparams.get("sparsek_topk", 0)))
self.gguf_writer.add_key("llama.sparsek.window", int(self.hparams.get("sparsek_window", 0)))
self.gguf_writer.add_key("llama.sparsek.stride", int(self.hparams.get("sparsek_stride", 0)))
# ============================================
Collaborator


This was not what @ggerganov meant; I don't believe these are actual config values from a real model.

Additionally, don't use add_key directly and certainly not the llama namespace like that.

Contributor Author


Thanks! I fixed it!

@github-actions github-actions bot added the python (python script changes) label on Nov 13, 2025
Gitty Burstein and others added 4 commits November 14, 2025 02:00
Co-authored-by: Gitty Burstein <[email protected]>

Co-authored-by: Yael Shuker <[email protected]>
Co-authored-by: Gitty Burstein <[email protected]>

Co-authored-by: Yael Shuker <[email protected]>
Co-authored-by: Gitty Burstein <[email protected]>

Co-authored-by: Yael Shuker <[email protected]>
@yael-works yael-works force-pushed the feature/sparsek-attn-sycl branch from 5c3c65c to 5798c33 on November 16, 2025 11:25
@@ -0,0 +1,282 @@
// tests/test-sparsek_kq_mask.cpp
Collaborator


As SparseK is a special case of attention, there is no need to create a new test case cpp file.

  1. Move it into test-backend-ops.cpp.
  2. Follow the existing framework of test-backend-ops.cpp,
    like test_cases.emplace_back(new test_flash_attn_ext( hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV)).
    This helps to create more cases for OPs and makes it easier to extend GPU support in the future.

Contributor


we did it!


buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
// increase meta buffer slightly to accommodate extra nodes from SparseK
int64_t max_nodes_ex = max_nodes + 128; // safety headroom
Collaborator


No need to define max_nodes_ex here.
If you feel max_nodes is not big enough, increase it directly.
Having both max_nodes and max_nodes_ex will confuse the next developer.

+128 is strange too.
If max_nodes is not correct in some case, we should fix it directly.
+128 will hide the issue, and if it's not enough for a new model, the resulting crash won't be easy to fix.

Unless you have a strong reason to show that max_nodes is wrong, I suggest not changing it.

Contributor


After I removed the extra +128 headroom, the model started crashing for me.
I’m trying to understand where the actual memory pools / buffer sizes are defined and managed in the code.
Where is the correct place to adjust the memory allocations, instead of relying on a hard-coded offset?

Co-authored-by: Yael Shuker <[email protected]>
Co-authored-by: Gitty Burstein <[email protected]>
@yael-works yael-works force-pushed the feature/sparsek-attn-sycl branch from a9d2015 to 060ee50 on November 17, 2025 11:54
@yael-works
Contributor Author

yael-works commented Nov 17, 2025

Hi @ggerganov @NeoZhangJianyu @CISC

Summary of Updates

  • Removed the dedicated operator GGML_OP_SPARSEK_ATTN — Sparse-K is now integrated exclusively through the dynamic mask inside ggml_flash_attn_ext.
  • All Sparse-K parameters are now sourced strictly from GGUF metadata (no environment variables).
  • Cleaned legacy code paths, removed unnecessary reshapes, and reduced the overall graph-node count.
  • Added a dedicated KQ-mask test in test-backend-ops.cpp.

Performance (CODEGEMMA — Sparse-K vs Baseline)

  n_ctx | Prompt Eval  | Total Time
  1024  | 2.3× faster  | ~16% improvement

Additional Notes

  • No memory footprint change.
  • No eval-per-token regression (expected for batch=1).
  • Sparse-K provides the largest gains during prompt ingestion, where attention dominates runtime.
  • Reverting the graph-node reduction to its original size causes runtime crashes — the optimization is required for stability.
  • Benchmarked with and without Sparse-K on CODEGEMMA under identical conditions.

Benchmark Command Used

./build/bin/llama-cli \
  -m /home/gitty/models/codegemma-sparsek-Q4_K_M.gguf \
  -ngl 0 \
  -c 256 \
  -n 400 \
  -t 8

@NeoZhangJianyu
Collaborator

> Added full CPU backend implementation in:
>
>   • ggml-cpu/ops.h
>   • ggml-cpu/ops.cpp
>   • ggml-cpu.c

It's good to see the improvement in this PR and the shared results.

Could you update this info in the description of the PR too?
Please replace the now-outdated info, like:

> Added full CPU backend implementation in:
> ggml-cpu/ops.h
> ggml-cpu/ops.cpp
> ggml-cpu.c

@GittyBurstein
Contributor

Thank you for the feedback!
@NeoZhangJianyu — we will update the comment accordingly.
Is there anything else you would like us to adjust, and what would you suggest as the next steps for us?
