Skip to content

Commit 9e4cbd5

Browse files
committed
bench : cache llama_context state at depth
1 parent 070ff4d commit 9e4cbd5

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

tools/llama-bench/llama-bench.cpp

Lines changed: 31 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -1919,6 +1919,12 @@ struct sql_printer : public printer {
19191919
}
19201920
};
19211921

1922+
// Cached snapshot of a llama_context sequence, taken after processing a prompt
// of `depth` tokens. Later benchmark repetitions that need the same (or a
// shallower) depth restore this snapshot via llama_state_seq_set_data instead
// of re-running the whole depth prompt.
struct ctx_state {
1923+
    int depth = 0; // number of tokens processed when the snapshot was taken (in tokens)
1924+
1925+
    std::vector<uint8_t> buf; // serialized llama_context state, filled by llama_state_seq_get_data
1926+
};
1927+
19221928
static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
19231929
llama_set_n_threads(ctx, n_threads, n_threads);
19241930

@@ -2051,6 +2057,10 @@ int main(int argc, char ** argv) {
20512057
llama_model * lmodel = nullptr;
20522058
const cmd_params_instance * prev_inst = nullptr;
20532059

2060+
// store the llama_context state at the maximum depth that we have encountered so far
2061+
// ref: https://github.com/ggml-org/llama.cpp/pull/16944#issuecomment-3478151721
2062+
ctx_state cstate;
2063+
20542064
int params_idx = 0;
20552065
auto params_count = params_instances.size();
20562066
for (const auto & inst : params_instances) {
@@ -2134,14 +2144,27 @@ int main(int argc, char ** argv) {
21342144
llama_memory_clear(llama_get_memory(ctx), false);
21352145

21362146
if (t.n_depth > 0) {
2137-
if (params.progress) {
2138-
fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
2139-
i + 1, params.reps);
2140-
}
2141-
bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
2142-
if (!res) {
2143-
fprintf(stderr, "%s: error: failed to run depth\n", __func__);
2144-
exit(1);
2147+
if (t.n_depth <= cstate.depth) {
2148+
// if previously we have computed at this depth, just restore the state
2149+
llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
2150+
2151+
// trim out any extra tokens from the old run
2152+
llama_memory_seq_rm(llama_get_memory(ctx), 0, t.n_depth, -1);
2153+
} else {
2154+
if (params.progress) {
2155+
fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
2156+
i + 1, params.reps);
2157+
}
2158+
bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
2159+
if (!res) {
2160+
fprintf(stderr, "%s: error: failed to run depth\n", __func__);
2161+
exit(1);
2162+
}
2163+
2164+
// store the context state for reuse in later runs
2165+
cstate.depth = t.n_depth;
2166+
cstate.buf.resize(llama_state_seq_get_size(ctx, 0));
2167+
llama_state_seq_get_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
21452168
}
21462169
}
21472170

0 commit comments

Comments (0)