Commit e2f222c (parent cd5e3b5)

llama : add option to skip the compute of a batch

4 files changed: 22 additions, 0 deletions

include/llama.h (2 additions, 0 deletions)

@@ -907,6 +907,8 @@ extern "C" {
     // Set abort callback
     LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);

+    LLAMA_API void llama_set_skip_compute(struct llama_context * ctx, bool val);
+
     // Wait until all computations are finished
     // This is automatically done when using one of the functions below to obtain the computation results
     // and is not necessary to call it explicitly in most cases

src/llama-context.cpp (14 additions, 0 deletions)

@@ -691,6 +691,12 @@ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void
     }
 }

+void llama_context::set_skip_compute(bool val) {
+    LLAMA_LOG_DEBUG("%s: val = %d\n", __func__, val);
+
+    skip_compute = val;
+}
+
 void llama_context::set_embeddings(bool value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);


@@ -799,6 +805,10 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
         //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
     }

+    if (skip_compute) {
+        return res;
+    }
+
     const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);

@@ -2447,6 +2457,10 @@ void llama_set_abort_callback(llama_context * ctx, bool (*abort_callback)(void *
     ctx->set_abort_callback(abort_callback, abort_callback_data);
 }

+void llama_set_skip_compute(llama_context * ctx, bool val) {
+    ctx->set_skip_compute(val);
+}
+
 void llama_set_embeddings(llama_context * ctx, bool embeddings) {
     ctx->set_embeddings(embeddings);
 }

src/llama-context.h (4 additions, 0 deletions)

@@ -76,6 +76,8 @@ struct llama_context {

     void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);

+    void set_skip_compute(bool val);
+
     void set_embeddings (bool value);
     void set_causal_attn(bool value);
     void set_warmup(bool value);

@@ -279,6 +281,8 @@ struct llama_context {
     ggml_abort_callback abort_callback = nullptr;
     void * abort_callback_data = nullptr;

+    bool skip_compute = false; // skip the actual computation of the model (useful for benchmarking)
+
     std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;

     // buffer types used for the compute buffer of each backend

tools/llama-bench/llama-bench.cpp (2 additions, 0 deletions)

@@ -2138,11 +2138,13 @@ int main(int argc, char ** argv) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
                         i + 1, params.reps);
             }
+            llama_set_skip_compute(ctx, true);
             bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
             if (!res) {
                 fprintf(stderr, "%s: error: failed to run depth\n", __func__);
                 exit(1);
             }
+            llama_set_skip_compute(ctx, false);
         }

         uint64_t t_start = get_time_ns();