Skip to content

Commit 167b136

Browse files
committed
Add /api/generate endpoint for model loading and unloading
So we can load and unload models. Signed-off-by: Eric Curtin <[email protected]>
1 parent d94f4f8 commit 167b136

File tree

5 files changed

+206
-0
lines changed

5 files changed

+206
-0
lines changed

cmd/cli/commands/root.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ func NewRootCmd(cli *command.DockerCli) *cobra.Command {
113113
newConfigureCmd(),
114114
newPSCmd(),
115115
newDFCmd(),
116+
newStopCmd(),
116117
newUnloadCmd(),
117118
newRequestsCmd(),
118119
)

cmd/cli/commands/stop.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package commands
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/docker/model-runner/cmd/cli/commands/completion"
7+
"github.com/docker/model-runner/cmd/cli/desktop"
8+
"github.com/docker/model-runner/pkg/inference/models"
9+
"github.com/spf13/cobra"
10+
)
11+
12+
func newStopCmd() *cobra.Command {
13+
var backend string
14+
15+
const cmdArgs = "MODEL"
16+
c := &cobra.Command{
17+
Use: "stop " + cmdArgs,
18+
Short: "Stop a running model",
19+
RunE: func(cmd *cobra.Command, args []string) error {
20+
model := models.NormalizeModelName(args[0])
21+
unloadResp, err := desktopClient.Unload(desktop.UnloadRequest{Backend: backend, Models: []string{model}})
22+
if err != nil {
23+
err = handleClientError(err, "Failed to stop model")
24+
return handleNotRunningError(err)
25+
}
26+
unloaded := unloadResp.UnloadedRunners
27+
if unloaded == 0 {
28+
cmd.Println("No such model running.")
29+
} else {
30+
cmd.Printf("Stopped %d model(s).\n", unloaded)
31+
}
32+
return nil
33+
},
34+
ValidArgsFunction: completion.NoComplete,
35+
}
36+
c.Args = func(cmd *cobra.Command, args []string) error {
37+
if len(args) < 1 {
38+
return fmt.Errorf(
39+
"'docker model stop' requires MODEL.\\n\\n" +
40+
"Usage: docker model stop " + cmdArgs + "\\n\\n" +
41+
"See 'docker model stop --help' for more information.",
42+
)
43+
}
44+
return nil
45+
}
46+
c.Flags().StringVar(&backend, "backend", "", "Optional backend to target")
47+
return c
48+
}

main.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,17 @@ func main() {
155155
router.Handle(inference.ModelsPrefix+"/", modelManager)
156156
router.Handle(inference.InferencePrefix+"/", scheduler)
157157

158+
// Add API endpoints by creating a custom handler
159+
apiHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
160+
switch r.URL.Path {
161+
case "/api/generate":
162+
scheduler.HandleGenerate(w, r)
163+
default:
164+
http.NotFound(w, r)
165+
}
166+
})
167+
router.Handle("/api/generate", apiHandler)
168+
158169
// Add metrics endpoint if enabled
159170
if os.Getenv("DISABLE_METRICS") != "1" {
160171
metricsHandler := metrics.NewAggregatedMetricsHandler(

pkg/inference/scheduling/api.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,33 @@ type ConfigureRequest struct {
9393
RuntimeFlags []string `json:"runtime-flags,omitempty"`
9494
RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"`
9595
}
96+
97+
// GenerateRequest represents the request structure for the /api/generate
// endpoint. The field set mirrors the Ollama-style generate API.
type GenerateRequest struct {
	Model    string `json:"model"`              // model to run; required (handler rejects empty)
	Prompt   string `json:"prompt"`             // empty prompt means load/unload rather than inference
	System   string `json:"system,omitempty"`   // optional system prompt
	Template string `json:"template,omitempty"` // optional prompt template override
	Context  []int  `json:"context,omitempty"`  // optional context tokens from a previous response
	Stream   *bool  `json:"stream,omitempty"`   // pointer so "unset" can be distinguished from false
	Raw      bool   `json:"raw,omitempty"`      // when true, bypass templating — TODO confirm downstream honors this
	// KeepAlive == 0 (with an empty prompt) requests an unload.
	// NOTE(review): Ollama clients may also send keep_alive as a duration
	// string (e.g. "5m") — confirm an int is sufficient here.
	KeepAlive *int                   `json:"keep_alive,omitempty"`
	Options   map[string]interface{} `json:"options,omitempty"` // backend-specific generation options, passed through opaquely
}
109+
110+
// GenerateResponse represents the response structure for the /api/generate
// endpoint. For load/unload requests only Model, CreatedAt, Done and
// DoneReason are populated.
type GenerateResponse struct {
	Model      string    `json:"model"`                 // echoes the requested model name
	CreatedAt  time.Time `json:"created_at"`            // response creation time (UTC)
	Response   string    `json:"response"`              // generated text; empty for load/unload acknowledgements
	Done       bool      `json:"done"`                  // true on the terminal message
	DoneReason string    `json:"done_reason,omitempty"` // e.g. "unload"; empty for plain load
	Context    []int     `json:"context,omitempty"`     // context tokens to pass back in a follow-up request
	// Durations are presumably nanoseconds, matching the Ollama API
	// convention — TODO confirm against the producer of these values.
	TotalDuration      int64 `json:"total_duration,omitempty"`
	LoadDuration       int64 `json:"load_duration,omitempty"`
	PromptEvalCount    int   `json:"prompt_eval_count,omitempty"`
	PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"`
	EvalCount          int   `json:"eval_count,omitempty"`
	EvalDuration       int64 `json:"eval_duration,omitempty"`
}
125+

pkg/inference/scheduling/scheduler.go

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,122 @@ func (s *Scheduler) handleModels(w http.ResponseWriter, r *http.Request) {
535535
s.modelManager.ServeHTTP(w, r)
536536
}
537537

538+
// HandleGenerate handles /api/generate requests
539+
// If prompt is empty, loads the model into memory
540+
// If prompt is empty and keep_alive is 0, unloads the model
541+
func (s *Scheduler) HandleGenerate(w http.ResponseWriter, r *http.Request) {
542+
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
543+
if err != nil {
544+
if _, ok := err.(*http.MaxBytesError); ok {
545+
http.Error(w, "request too large", http.StatusBadRequest)
546+
} else {
547+
http.Error(w, "unknown error", http.StatusInternalServerError)
548+
}
549+
return
550+
}
551+
552+
var request GenerateRequest
553+
if err := json.Unmarshal(body, &request); err != nil {
554+
http.Error(w, "invalid request", http.StatusBadRequest)
555+
return
556+
}
557+
558+
if request.Model == "" {
559+
http.Error(w, "model is required", http.StatusBadRequest)
560+
return
561+
}
562+
563+
// Check if it's a load/unload request (empty prompt)
564+
if request.Prompt == "" {
565+
// Load request - if keep_alive is 0, it's an unload request
566+
if request.KeepAlive != nil && *request.KeepAlive == 0 {
567+
// Unload the model
568+
unloadReq := UnloadRequest{
569+
Models: []string{request.Model},
570+
Backend: "", // Use default backend
571+
}
572+
_ = UnloadResponse{s.loader.Unload(r.Context(), unloadReq)}
573+
574+
// Return unload response
575+
response := GenerateResponse{
576+
Model: request.Model,
577+
CreatedAt: time.Now().UTC(),
578+
Response: "",
579+
Done: true,
580+
DoneReason: "unload",
581+
}
582+
583+
w.Header().Set("Content-Type", "application/json")
584+
json.NewEncoder(w).Encode(response)
585+
return
586+
} else {
587+
// Load the model by requesting a minimal inference
588+
// This will trigger the loading mechanism in the loader
589+
backend := s.defaultBackend
590+
if backend == nil {
591+
http.Error(w, "no default backend available", http.StatusInternalServerError)
592+
return
593+
}
594+
595+
modelID := s.modelManager.ResolveModelID(request.Model)
596+
597+
// Request a runner to load the model - we'll do a minimal operation to trigger loading
598+
runner, err := s.loader.load(r.Context(), backend.Name(), modelID, request.Model, inference.BackendModeCompletion)
599+
if err != nil {
600+
http.Error(w, fmt.Errorf("unable to load runner: %w", err).Error(), http.StatusInternalServerError)
601+
return
602+
}
603+
defer s.loader.release(runner)
604+
605+
// Return load response
606+
response := GenerateResponse{
607+
Model: request.Model,
608+
CreatedAt: time.Now().UTC(),
609+
Response: "",
610+
Done: true,
611+
}
612+
613+
w.Header().Set("Content-Type", "application/json")
614+
json.NewEncoder(w).Encode(response)
615+
return
616+
}
617+
}
618+
619+
// Regular generate request - convert to OpenAI format and reuse existing logic
620+
// Create an OpenAI-compatible request
621+
openAIRequest := map[string]interface{}{
622+
"model": request.Model,
623+
"prompt": request.Prompt,
624+
"stream": request.Stream,
625+
"system": request.System,
626+
"raw": request.Raw,
627+
"options": request.Options,
628+
}
629+
630+
// Add context if it exists
631+
if request.Context != nil {
632+
openAIRequest["context"] = request.Context
633+
}
634+
635+
// Add template if it exists
636+
if request.Template != "" {
637+
openAIRequest["template"] = request.Template
638+
}
639+
640+
openAIBody, err := json.Marshal(openAIRequest)
641+
if err != nil {
642+
http.Error(w, "failed to process request", http.StatusInternalServerError)
643+
return
644+
}
645+
646+
// Create a new request with the OpenAI body for forwarding
647+
upstreamRequest := r.Clone(r.Context())
648+
upstreamRequest.Body = io.NopCloser(bytes.NewReader(openAIBody))
649+
650+
// Call the existing OpenAI inference handler
651+
s.handleOpenAIInference(w, upstreamRequest)
652+
}
653+
538654
// ServeHTTP implements net/http.Handler.ServeHTTP.
539655
func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
540656
s.lock.RLock()

0 commit comments

Comments
 (0)