
Commit f6278e3

Add load endpoint and update --detach to use it
Update documentation for new load endpoint and --detach mode

Signed-off-by: Eric Curtin <[email protected]>
1 parent 5435e65 commit f6278e3

File tree

README.md
cmd/cli/commands/run.go
cmd/cli/desktop/desktop.go
cmd/cli/desktop/desktop_test.go
pkg/inference/scheduling/api.go
pkg/inference/scheduling/loader.go
pkg/inference/scheduling/scheduler.go
pkg/inference/scheduling/scheduler_test.go

8 files changed: +217 −10 lines changed

README.md

Lines changed: 8 additions & 0 deletions
@@ -93,6 +93,9 @@ MODEL_RUNNER_HOST=http://localhost:13434 ./model-cli list

 # Pull and run a model
 MODEL_RUNNER_HOST=http://localhost:13434 ./model-cli run ai/smollm2 "Hello, how are you?"
+
+# Load a model into memory without interaction (detached mode)
+MODEL_RUNNER_HOST=http://localhost:13434 ./model-cli run --detach ai/smollm2
 ```

 #### Option 2: Using Docker
@@ -195,6 +198,11 @@ curl http://localhost:8080/engines/llama.cpp/v1/chat/completions -X POST -d '{
   ]
 }'

+# Load a model into memory (without inference)
+curl http://localhost:8080/engines/llama.cpp/load -X POST -d '{
+  "model": "ai/smollm2"
+}'
+
 # Delete a model
 curl http://localhost:8080/models/ai/smollm2 -X DELETE
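
The same request can be made without curl. The short Go program below is a minimal sketch of a client for the documented load endpoint, assuming the model runner is reachable at localhost:8080 as in the README examples; the response fields mirror the LoadResponse type added in pkg/inference/scheduling/api.go. It is illustrative only and not part of this commit.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Same payload as the documented curl command.
	payload, _ := json.Marshal(map[string]string{"model": "ai/smollm2"})

	// POST to the new load endpoint for the llama.cpp backend.
	resp, err := http.Post(
		"http://localhost:8080/engines/llama.cpp/load",
		"application/json",
		bytes.NewReader(payload),
	)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The scheduler replies with {"status":"loaded","message":"..."} on success.
	var result struct {
		Status  string `json:"status"`
		Message string `json:"message,omitempty"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		panic(err)
	}
	fmt.Println(resp.StatusCode, result.Status, result.Message)
}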

cmd/cli/commands/run.go

Lines changed: 6 additions & 8 deletions
@@ -208,7 +208,7 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.
 	// Create a cancellable context for the chat request
 	// This allows us to cancel the request if the user presses Ctrl+C during response generation
 	chatCtx, cancelChat := context.WithCancel(cmd.Context())
-
+
 	// Set up signal handler to cancel the context on Ctrl+C
 	sigChan := make(chan os.Signal, 1)
 	signal.Notify(sigChan, syscall.SIGINT)
@@ -222,7 +222,7 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.
 	}()

 	err := chatWithMarkdownContext(chatCtx, cmd, desktopClient, backend, model, userInput, apiKey)
-
+
 	// Clean up signal handler
 	signal.Stop(sigChan)
 	// Do not close sigChan to avoid race condition
@@ -268,7 +268,7 @@ func generateInteractiveBasic(cmd *cobra.Command, desktopClient *desktop.Client,
 	// Create a cancellable context for the chat request
 	// This allows us to cancel the request if the user presses Ctrl+C during response generation
 	chatCtx, cancelChat := context.WithCancel(cmd.Context())
-
+
 	// Set up signal handler to cancel the context on Ctrl+C
 	sigChan := make(chan os.Signal, 1)
 	signal.Notify(sigChan, syscall.SIGINT)
@@ -283,7 +283,7 @@ func generateInteractiveBasic(cmd *cobra.Command, desktopClient *desktop.Client,
 	}()

 	err = chatWithMarkdownContext(chatCtx, cmd, desktopClient, backend, model, userInput, apiKey)
-
+
 	cancelChat()
 	signal.Stop(sigChan)
 	cancelChat()
@@ -615,10 +615,8 @@ func newRunCmd() *cobra.Command {

 	// Handle --detach flag: just load the model without interaction
 	if detach {
-		// Make a minimal request to load the model into memory
-		err := desktopClient.Chat(backend, model, "", apiKey, func(content string) {
-			// Silently discard output in detach mode
-		}, false)
+		// Load the model into memory using the new load endpoint
+		err := desktopClient.WarmupModel(cmd.Context(), backend, model)
 		if err != nil {
 			return handleClientError(err, "Failed to load model")
 		}
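
Several of the hunks above touch the same Ctrl+C handling pattern: a cancellable context wraps the chat call, a SIGINT handler cancels it, and the handler is detached afterwards with signal.Stop while the channel is deliberately left open. The standalone sketch below reproduces that pattern for illustration; doWork is a hypothetical stand-in for chatWithMarkdownContext and is not part of the commit.

package main

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"syscall"
	"time"
)

// doWork stands in for chatWithMarkdownContext: it blocks until the
// context is cancelled or the (fake) response finishes streaming.
func doWork(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-time.After(5 * time.Second):
		return nil
	}
}

func main() {
	// Create a cancellable context for the request, mirroring run.go:
	// Ctrl+C should abort only the in-flight generation.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Route SIGINT to a channel and cancel the context when it fires.
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT)
	go func() {
		<-sigChan
		cancel()
	}()

	err := doWork(ctx)

	// Clean up the signal handler; as in run.go, the channel is not
	// closed, to avoid a send-on-closed-channel race.
	signal.Stop(sigChan)
	fmt.Println("result:", err)
}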

cmd/cli/desktop/desktop.go

Lines changed: 50 additions & 0 deletions
@@ -799,6 +799,56 @@ func (c *Client) handleQueryError(err error, path string) error {
 	return fmt.Errorf("error querying %s: %w", path, err)
 }

+// WarmupModel loads a model into memory without performing inference.
+// This is useful for warming up models in detached mode.
+func (c *Client) WarmupModel(ctx context.Context, backend, model string) error {
+	model = dmrm.NormalizeModelName(model)
+	if !strings.Contains(strings.Trim(model, "/"), "/") {
+		// Do an extra API call to check if the model parameter isn't a model ID.
+		if expanded, err := c.fullModelID(model); err == nil {
+			model = expanded
+		}
+	}
+
+	reqBody := struct {
+		Model string `json:"model"`
+	}{
+		Model: model,
+	}
+
+	jsonData, err := json.Marshal(reqBody)
+	if err != nil {
+		return fmt.Errorf("error marshaling request: %w", err)
+	}
+
+	var loadPath string
+	if backend != "" {
+		loadPath = inference.InferencePrefix + "/" + backend + "/load"
+	} else {
+		loadPath = inference.InferencePrefix + "/load"
+	}
+
+	resp, err := c.doRequestWithAuthContext(
+		ctx,
+		http.MethodPost,
+		loadPath,
+		bytes.NewReader(jsonData),
+		backend,
+		"", // no API key needed for local load
+	)
+	if err != nil {
+		return c.handleQueryError(err, loadPath)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf("load failed with status %d: %s", resp.StatusCode, body)
+	}
+
+	return nil
+}
+
 func (c *Client) Tag(source, targetRepo, targetTag string) error {
 	source = dmrm.NormalizeModelName(source)
 	// Check if the source is a model ID, and expand it if necessary
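
For callers, WarmupModel is a single blocking call; run.go above invokes it as desktopClient.WarmupModel(cmd.Context(), backend, model). The sketch below shows that usage in isolation, behind a small interface so it compiles without the package internals; the warmUp helper and its timeout are illustrative assumptions, not code from this commit.

package warmup

import (
	"context"
	"fmt"
	"time"
)

// modelWarmer captures the single method of desktop.Client that the
// --detach path needs (signature as added in this commit).
type modelWarmer interface {
	WarmupModel(ctx context.Context, backend, model string) error
}

// warmUp loads a model and returns once it is resident in the backend;
// no inference is performed. The timeout is an illustrative choice; the
// real CLI simply passes cmd.Context() through.
func warmUp(parent context.Context, client modelWarmer, backend, model string) error {
	ctx, cancel := context.WithTimeout(parent, 2*time.Minute)
	defer cancel()

	if err := client.WarmupModel(ctx, backend, model); err != nil {
		return fmt.Errorf("failed to load model %q: %w", model, err)
	}
	return nil
}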

cmd/cli/desktop/desktop_test.go

Lines changed: 55 additions & 0 deletions
@@ -2,6 +2,7 @@ package desktop

 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"io"
 	"net/http"
@@ -225,3 +226,57 @@ func TestInspectOpenAIHuggingFaceModel(t *testing.T) {
 	assert.NoError(t, err)
 	assert.Equal(t, expectedLowercase, model.ID)
 }
+
+func TestWarmupModel(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	modelName := "ai/smollm2"
+	expectedModelName := "ai/smollm2:latest" // normalized with tag
+	backend := "llama.cpp"
+
+	mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
+	mockContext := NewContextForMock(mockClient)
+	client := New(mockContext)
+
+	mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) {
+		// Verify the request path contains the backend
+		assert.Contains(t, req.URL.Path, backend)
+		assert.Contains(t, req.URL.Path, "/load")
+
+		// Verify the request body contains the model name
+		var reqBody struct {
+			Model string `json:"model"`
+		}
+		err := json.NewDecoder(req.Body).Decode(&reqBody)
+		require.NoError(t, err)
+		assert.Equal(t, expectedModelName, reqBody.Model)
+	}).Return(&http.Response{
+		StatusCode: http.StatusOK,
+		Body:       io.NopCloser(bytes.NewBufferString(`{"status":"loaded","message":"Model ai/smollm2 loaded successfully"}`)),
+	}, nil)
+
+	err := client.WarmupModel(context.Background(), backend, modelName)
+	assert.NoError(t, err)
+}
+
+func TestWarmupModelWithError(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	modelName := "ai/smollm2"
+	backend := "llama.cpp"
+
+	mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
+	mockContext := NewContextForMock(mockClient)
+	client := New(mockContext)
+
+	mockClient.EXPECT().Do(gomock.Any()).Return(&http.Response{
+		StatusCode: http.StatusInternalServerError,
+		Body:       io.NopCloser(bytes.NewBufferString("failed to load model")),
+	}, nil)
+
+	err := client.WarmupModel(context.Background(), backend, modelName)
+	assert.Error(t, err)
+	assert.Contains(t, err.Error(), "load failed")
+}
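
The two tests above cover the backend-qualified path and the error path. A further case one might add (hypothetical, not part of this commit) is the default-backend route, where WarmupModel is called with an empty backend and should POST to the prefix-only /load path; the sketch assumes the strings package is imported alongside the existing test imports.

func TestWarmupModelDefaultBackend(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	mockClient := mockdesktop.NewMockDockerHttpClient(ctrl)
	mockContext := NewContextForMock(mockClient)
	client := New(mockContext)

	mockClient.EXPECT().Do(gomock.Any()).Do(func(req *http.Request) {
		// With no backend, the client should use the prefix-only route.
		assert.True(t, strings.HasSuffix(req.URL.Path, "/load"))
		assert.NotContains(t, req.URL.Path, "llama.cpp")
	}).Return(&http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewBufferString(`{"status":"loaded"}`)),
	}, nil)

	err := client.WarmupModel(context.Background(), "", "ai/smollm2")
	assert.NoError(t, err)
}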

pkg/inference/scheduling/api.go

Lines changed: 11 additions & 0 deletions
@@ -91,3 +91,14 @@ type ConfigureRequest struct {
 	RuntimeFlags    []string `json:"runtime-flags,omitempty"`
 	RawRuntimeFlags string   `json:"raw-runtime-flags,omitempty"`
 }
+
+// LoadRequest specifies the model to load into memory.
+type LoadRequest struct {
+	Model string `json:"model"`
+}
+
+// LoadResponse indicates whether the model was loaded successfully.
+type LoadResponse struct {
+	Status  string `json:"status"`
+	Message string `json:"message,omitempty"`
+}
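
These two structs define the wire format of the new endpoint. The short program below, using local copies of the types, shows the JSON they marshal to; it is a sketch for illustration, not code from the repository.

package main

import (
	"encoding/json"
	"fmt"
)

// Local copies of the wire types from pkg/inference/scheduling/api.go,
// reproduced only to show the JSON they encode to.
type LoadRequest struct {
	Model string `json:"model"`
}

type LoadResponse struct {
	Status  string `json:"status"`
	Message string `json:"message,omitempty"`
}

func main() {
	req, _ := json.Marshal(LoadRequest{Model: "ai/smollm2"})
	fmt.Println(string(req)) // {"model":"ai/smollm2"}

	resp, _ := json.Marshal(LoadResponse{Status: "loaded", Message: "Model ai/smollm2 loaded successfully"})
	fmt.Println(string(resp)) // {"status":"loaded","message":"Model ai/smollm2 loaded successfully"}
}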

pkg/inference/scheduling/loader.go

Lines changed: 1 addition & 1 deletion
@@ -469,7 +469,7 @@ func (l *loader) load(ctx context.Context, backendName, modelID, modelRef string
 			return l.slots[existing.slot], nil
 		}
 	}
-
+
 	if runtime.GOOS == "windows" {
 		// On Windows, we can use up to half of the total system RAM as shared GPU memory,
 		// limited by the currently available RAM.

pkg/inference/scheduling/scheduler.go

Lines changed: 85 additions & 0 deletions
@@ -125,6 +125,8 @@ func (s *Scheduler) routeHandlers() map[string]http.HandlerFunc {
 	m["POST "+inference.InferencePrefix+"/unload"] = s.Unload
 	m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = s.Configure
 	m["POST "+inference.InferencePrefix+"/_configure"] = s.Configure
+	m["POST "+inference.InferencePrefix+"/{backend}/load"] = s.Load
+	m["POST "+inference.InferencePrefix+"/load"] = s.Load
 	m["GET "+inference.InferencePrefix+"/requests"] = s.openAIRecorder.GetRecordsHandler()
 	return m
 }
@@ -431,6 +433,89 @@ func (s *Scheduler) Configure(w http.ResponseWriter, r *http.Request) {
 	w.WriteHeader(http.StatusAccepted)
 }

+// Load handles loading a model into memory without performing inference.
+// This is useful for warming up models in detached mode.
+func (s *Scheduler) Load(w http.ResponseWriter, r *http.Request) {
+	// Determine the requested backend and ensure that it's valid.
+	var backend inference.Backend
+	if b := r.PathValue("backend"); b == "" {
+		backend = s.defaultBackend
+	} else {
+		backend = s.backends[b]
+	}
+	if backend == nil {
+		http.Error(w, ErrBackendNotFound.Error(), http.StatusNotFound)
+		return
+	}
+
+	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
+	if err != nil {
+		if _, ok := err.(*http.MaxBytesError); ok {
+			http.Error(w, "request too large", http.StatusBadRequest)
+		} else {
+			http.Error(w, "unknown error", http.StatusInternalServerError)
+		}
+		return
+	}
+
+	var loadRequest LoadRequest
+	if err := json.Unmarshal(body, &loadRequest); err != nil {
+		http.Error(w, "invalid request", http.StatusBadRequest)
+		return
+	}
+
+	if loadRequest.Model == "" {
+		http.Error(w, "model name is required", http.StatusBadRequest)
+		return
+	}
+
+	// Wait for the backend installation to complete
+	if err := s.installer.wait(r.Context(), backend.Name()); err != nil {
+		if errors.Is(err, ErrBackendNotFound) {
+			http.Error(w, err.Error(), http.StatusNotFound)
+		} else if errors.Is(err, errInstallerNotStarted) {
+			http.Error(w, err.Error(), http.StatusServiceUnavailable)
+		} else if errors.Is(err, context.Canceled) {
+			http.Error(w, "service unavailable", http.StatusServiceUnavailable)
+		} else {
+			http.Error(w, fmt.Errorf("backend installation failed: %w", err).Error(), http.StatusServiceUnavailable)
+		}
+		return
+	}
+
+	// Resolve the model ID
+	modelID := s.modelManager.ResolveModelID(loadRequest.Model)
+
+	// Load the model using the loader (default to completion mode)
+	mode := inference.BackendModeCompletion
+	runner, err := s.loader.load(r.Context(), backend.Name(), modelID, loadRequest.Model, mode)
+	if err != nil {
+		s.log.Warnf("Failed to load model %s (%s): %v", loadRequest.Model, modelID, err)
+		if errors.Is(err, errModelTooBig) {
+			http.Error(w, "model too big for available memory", http.StatusInsufficientStorage)
+		} else if errors.Is(err, context.Canceled) {
+			http.Error(w, "request canceled", http.StatusRequestTimeout)
+		} else {
+			http.Error(w, fmt.Sprintf("failed to load model: %v", err), http.StatusInternalServerError)
+		}
+		return
+	}
+
+	// Release the runner immediately since we're just loading it, not using it
+	defer s.loader.release(runner)
+
+	// Return success response
+	response := LoadResponse{
+		Status:  "loaded",
+		Message: fmt.Sprintf("Model %s loaded successfully", loadRequest.Model),
+	}
+	w.Header().Set("Content-Type", "application/json")
+	w.WriteHeader(http.StatusOK)
+	if err := json.NewEncoder(w).Encode(response); err != nil {
+		s.log.Warnf("Failed to encode load response: %v", err)
+	}
+}
+
 // GetAllActiveRunners returns information about all active runners
 func (s *Scheduler) GetAllActiveRunners() []metrics.ActiveRunner {
 	runningBackends := s.getLoaderStatus(context.Background())
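
The new routes registered in routeHandlers() use Go 1.22+ http.ServeMux method-and-wildcard patterns, and Scheduler.Load reads the optional backend via r.PathValue("backend"), falling back to the default backend when it is empty. The sketch below mirrors that routing shape with a stub handler; the "/engines" prefix is taken from the README's curl examples and the port is arbitrary, so treat it as an illustration rather than the scheduler's actual wiring.

package main

import (
	"fmt"
	"log"
	"net/http"
)

func main() {
	mux := http.NewServeMux()

	// Stub standing in for Scheduler.Load.
	load := func(w http.ResponseWriter, r *http.Request) {
		// r.PathValue returns "" on the prefix-only route, which is when
		// the real handler falls back to the default backend.
		backend := r.PathValue("backend")
		if backend == "" {
			backend = "default"
		}
		fmt.Fprintf(w, "would load model on backend %q\n", backend)
	}

	// Same shape as the routes added in routeHandlers(): the literal route
	// and the {backend} wildcard route can coexist because the literal one
	// is more specific.
	mux.HandleFunc("POST /engines/{backend}/load", load)
	mux.HandleFunc("POST /engines/load", load)

	log.Fatal(http.ListenAndServe(":8081", mux))
}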

pkg/inference/scheduling/scheduler_test.go

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ import (
 type systemMemoryInfo struct{}

 func (i systemMemoryInfo) HaveSufficientMemory(req inference.RequiredMemory) (bool, error) {
-	return true, nil
+	return true, nil
 }

 func (i systemMemoryInfo) GetTotalMemory() inference.RequiredMemory {
