From 7d2daf851b2526233e7d116098a9340db5a0d0d8 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Mon, 3 Nov 2025 13:24:03 +0200
Subject: [PATCH 1/3] bench : cache llama_context state at depth

---
 tools/llama-bench/llama-bench.cpp | 36 ++++++++++++++++++++++++--------
 1 file changed, 28 insertions(+), 8 deletions(-)

diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
index 0de07b9811268..f6eb066f2bcb0 100644
--- a/tools/llama-bench/llama-bench.cpp
+++ b/tools/llama-bench/llama-bench.cpp
@@ -1919,6 +1919,12 @@ struct sql_printer : public printer {
     }
 };
 
+struct ctx_state {
+    int depth = 0; // in tokens
+
+    std::vector<uint8_t> buf; // the llama_context state buffer
+};
+
 static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
@@ -2051,6 +2057,10 @@ int main(int argc, char ** argv) {
     llama_model * lmodel = nullptr;
     const cmd_params_instance * prev_inst = nullptr;
 
+    // store the llama_context state at the previous depth that we performed a test
+    // ref: https://github.com/ggml-org/llama.cpp/pull/16944#issuecomment-3478151721
+    ctx_state cstate;
+
     int params_idx = 0;
     auto params_count = params_instances.size();
     for (const auto & inst : params_instances) {
@@ -2134,14 +2144,24 @@ int main(int argc, char ** argv) {
             llama_memory_clear(llama_get_memory(ctx), false);
 
             if (t.n_depth > 0) {
-                if (params.progress) {
-                    fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
-                            i + 1, params.reps);
-                }
-                bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
-                if (!res) {
-                    fprintf(stderr, "%s: error: failed to run depth\n", __func__);
-                    exit(1);
+                if (t.n_depth == cstate.depth) {
+                    // if previously we have computed at this depth, just restore the state
+                    llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
+                } else {
+                    if (params.progress) {
+                        fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
+                                i + 1, params.reps);
+                    }
+                    bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+                    if (!res) {
+                        fprintf(stderr, "%s: error: failed to run depth\n", __func__);
+                        exit(1);
+                    }
+
+                    // store the context state for reuse in later runs
+                    cstate.depth = t.n_depth;
+                    cstate.buf.resize(llama_state_seq_get_size(ctx, 0));
+                    llama_state_seq_get_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
                 }
             }
 

From a99f7cea19dd4d51e2b8f215b99d92275594887e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Mon, 3 Nov 2025 16:01:27 +0200
Subject: [PATCH 2/3] cont : handle failures to restore the old state

---
 tools/llama-bench/llama-bench.cpp | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
index f6eb066f2bcb0..fe299482b5c08 100644
--- a/tools/llama-bench/llama-bench.cpp
+++ b/tools/llama-bench/llama-bench.cpp
@@ -2144,10 +2144,18 @@ int main(int argc, char ** argv) {
             llama_memory_clear(llama_get_memory(ctx), false);
 
             if (t.n_depth > 0) {
-                if (t.n_depth == cstate.depth) {
+                bool is_cached = t.n_depth == cstate.depth;
+
+                if (is_cached) {
                     // if previously we have computed at this depth, just restore the state
-                    llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
-                } else {
+                    const size_t ret = llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
+                    if (ret == 0) {
+                        // if the old state is incompatible with the current context - reprocess from scratch
+                        is_cached = false;
+                    }
+                }
+
+                if (!is_cached) {
                     if (params.progress) {
                         fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
                                 i + 1, params.reps);

From 9c6bc80edad4bba913eba4c001baf79362b85a16 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Wed, 5 Nov 2025 09:19:44 +0200
Subject: [PATCH 3/3] cont : print information when the state is being reused

---
 tools/llama-bench/llama-bench.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
index fe299482b5c08..852a512451d64 100644
--- a/tools/llama-bench/llama-bench.cpp
+++ b/tools/llama-bench/llama-bench.cpp
@@ -2170,6 +2170,11 @@ int main(int argc, char ** argv) {
                     cstate.depth = t.n_depth;
                     cstate.buf.resize(llama_state_seq_get_size(ctx, 0));
                     llama_state_seq_get_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
+                } else {
+                    if (params.progress) {
+                        fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d (cached)\n", params_idx, params_count,
+                                i + 1, params.reps);
+                    }
                 }