diff --git a/cmd/cli/commands/backend.go b/cmd/cli/commands/backend.go
index c13a74f9..74eeb1ce 100644
--- a/cmd/cli/commands/backend.go
+++ b/cmd/cli/commands/backend.go
@@ -15,6 +15,19 @@ var ValidBackends = map[string]bool{
 	"openai": true,
 }
 
+// ServerPreset represents a preconfigured server endpoint
+type ServerPreset struct {
+	Name string
+	URL  string
+}
+
+// ServerPresets defines the available server presets
+var ServerPresets = []ServerPreset{
+	{"llamacpp", "http://127.0.0.1:8080/v1"},
+	{"ollama", "http://127.0.0.1:11434/v1"},
+	{"openrouter", "https://openrouter.ai/api/v1"},
+}
+
 // validateBackend checks if the provided backend is valid
 func validateBackend(backend string) error {
 	if !ValidBackends[backend] {
@@ -28,14 +41,87 @@ func validateBackend(backend string) error {
 func ensureAPIKey(backend string) (string, error) {
 	if backend == "openai" {
 		apiKey := os.Getenv("OPENAI_API_KEY")
-		if apiKey == "" {
-			return "", errors.New("OPENAI_API_KEY environment variable is required when using --backend=openai")
+		if apiKey != "" {
+			return apiKey, nil
 		}
-		return apiKey, nil
 	}
 	return "", nil
 }
 
+// resolveServerURL determines the server URL from flags.
+// Returns: (url, useOpenAI, apiKey, error)
+func resolveServerURL(host, customURL, urlAlias string, port int) (string, bool, string, error) {
+	// Count how many server options are specified
+	presetCount := 0
+	if urlAlias != "" {
+		presetCount++
+	}
+	if customURL != "" {
+		presetCount++
+	}
+
+	// Check for conflicting options
+	if presetCount > 1 {
+		return "", false, "", errors.New("only one of --url or --url-alias can be specified")
+	}
+
+	// Check for conflicting host/port with URL/preset options
+	hostPortSpecified := host != "" || port != 0
+	urlPresetSpecified := customURL != "" || urlAlias != ""
+
+	if hostPortSpecified && urlPresetSpecified {
+		return "", false, "", errors.New("cannot specify both --host/--port and --url/--url-alias options")
+	}
+
+	// Resolve the URL
+	var serverURL string
+	useOpenAI := false
+	apiKey := ""
+
+	if customURL != "" {
+		serverURL = customURL
+		useOpenAI = true
+	} else if urlAlias != "" {
+		// Find the matching preset
+		found := false
+		for _, preset := range ServerPresets {
+			if preset.Name == urlAlias {
+				serverURL = preset.URL
+				useOpenAI = true
+				found = true
+				break
+			}
+		}
+		if !found {
+			return "", false, "", fmt.Errorf("invalid url-alias '%s'. Valid options are: llamacpp, ollama, openrouter", urlAlias)
+		}
+
+		apiKey = os.Getenv("OPENAI_API_KEY")
+	} else if hostPortSpecified {
+		// Use custom host/port for model-runner endpoint
+		if host == "" {
+			host = "127.0.0.1"
+		}
+		if port == 0 {
+			port = 12434
+		}
+		serverURL = fmt.Sprintf("http://%s:%d", host, port)
+		useOpenAI = false
+	}
+
+	// For OpenAI-compatible endpoints, check for API key (optional for most, required for openrouter)
+	if useOpenAI && apiKey == "" {
+		apiKey = os.Getenv("OPENAI_API_KEY")
+	}
+
+	// openrouter requires an API key; fail fast if it is missing
+	if urlAlias == "openrouter" && apiKey == "" {
+		return "", false, "", errors.New("OPENAI_API_KEY environment variable is required when using --url-alias=openrouter")
+	}
+
+	return serverURL, useOpenAI, apiKey, nil
+}
+
 func ValidBackendsKeys() string {
 	keys := slices.Collect(maps.Keys(ValidBackends))
 	slices.Sort(keys)
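Note on extending the preset table (illustrative, not part of the patch): adding a preset is a single entry in ServerPresets, but the "invalid url-alias" message in resolveServerURL hard-codes the valid names, so the two must be kept in sync. A hypothetical addition:

    // Hypothetical vllm preset; the port assumes vLLM's default of 8000.
    var ServerPresets = []ServerPreset{
        {"llamacpp", "http://127.0.0.1:8080/v1"},
        {"ollama", "http://127.0.0.1:11434/v1"},
        {"openrouter", "https://openrouter.ai/api/v1"},
        {"vllm", "http://127.0.0.1:8000/v1"}, // hypothetical
    }

One way to avoid that drift would be to build the error message from ServerPresets itself instead of a literal string.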
diff --git a/cmd/cli/commands/backend_test.go b/cmd/cli/commands/backend_test.go
new file mode 100644
index 00000000..84ed0c0f
--- /dev/null
+++ b/cmd/cli/commands/backend_test.go
@@ -0,0 +1,251 @@
+package commands
+
+import (
+	"os"
+	"testing"
+)
+
+func TestResolveServerURL(t *testing.T) {
+	tests := []struct {
+		name       string
+		host       string
+		customURL  string
+		urlAlias   string
+		port       int
+		expectURL  string
+		expectOAI  bool
+		wantErr    bool
+		setupEnv   func()
+		cleanupEnv func()
+	}{
+		{
+			name:      "no flags specified",
+			expectURL: "",
+			expectOAI: false,
+			wantErr:   false,
+		},
+		{
+			name:      "host and port specified",
+			host:      "192.168.1.1",
+			port:      8080,
+			expectURL: "http://192.168.1.1:8080",
+			expectOAI: false,
+			wantErr:   false,
+		},
+		{
+			name:      "only host specified",
+			host:      "192.168.1.1",
+			expectURL: "http://192.168.1.1:12434",
+			expectOAI: false,
+			wantErr:   false,
+		},
+		{
+			name:      "only port specified",
+			port:      8080,
+			expectURL: "http://127.0.0.1:8080",
+			expectOAI: false,
+			wantErr:   false,
+		},
+		{
+			name:      "llamacpp url-alias specified",
+			urlAlias:  "llamacpp",
+			expectURL: "http://127.0.0.1:8080/v1",
+			expectOAI: true,
+			wantErr:   false,
+		},
+		{
+			name:      "ollama url-alias specified",
+			urlAlias:  "ollama",
+			expectURL: "http://127.0.0.1:11434/v1",
+			expectOAI: true,
+			wantErr:   false,
+		},
+		{
+			name:     "openrouter url-alias without API key",
+			urlAlias: "openrouter",
+			setupEnv: func() {
+				os.Unsetenv("OPENAI_API_KEY")
+			},
+			wantErr: true,
+		},
+		{
+			name:      "openrouter url-alias with API key",
+			urlAlias:  "openrouter",
+			expectURL: "https://openrouter.ai/api/v1",
+			expectOAI: true,
+			wantErr:   false,
+			setupEnv: func() {
+				os.Setenv("OPENAI_API_KEY", "test-key")
+			},
+			cleanupEnv: func() {
+				os.Unsetenv("OPENAI_API_KEY")
+			},
+		},
+		{
+			name:      "custom URL specified",
+			customURL: "http://custom.server.com:9000/v1",
+			expectURL: "http://custom.server.com:9000/v1",
+			expectOAI: true,
+			wantErr:   false,
+		},
+		{
+			name:      "multiple preset flags (url + url-alias)",
+			customURL: "http://test.com/v1",
+			urlAlias:  "ollama",
+			wantErr:   true,
+		},
+		{
+			name:     "host/port with url-alias",
+			host:     "192.168.1.1",
+			urlAlias: "llamacpp",
+			wantErr:  true,
+		},
+		{
+			name:      "host/port with url",
+			host:      "192.168.1.1",
+			customURL: "http://test.com/v1",
+			wantErr:   true,
+		},
+		{
+			name:     "invalid url-alias",
+			urlAlias: "invalid",
+			wantErr:  true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if tt.setupEnv != nil {
+				tt.setupEnv()
+			}
+			if tt.cleanupEnv != nil {
+				defer tt.cleanupEnv()
+			}
+
+			url, useOAI, apiKey, err := resolveServerURL(tt.host, tt.customURL, tt.urlAlias, tt.port)
+
+			if (err != nil) != tt.wantErr {
+				t.Errorf("resolveServerURL() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+
+			if err != nil {
+				return
+			}
+
+			if url != tt.expectURL {
+				t.Errorf("resolveServerURL() url = %v, want %v", url, tt.expectURL)
+			}
+
+			if useOAI != tt.expectOAI {
+				t.Errorf("resolveServerURL() useOAI = %v, want %v", useOAI, tt.expectOAI)
+			}
+
+			// For openrouter, check that API key is returned
+			if tt.urlAlias == "openrouter" && !tt.wantErr {
+				if apiKey == "" {
+					t.Errorf("resolveServerURL() expected API key for openrouter, got empty string")
+				}
+			}
+		})
+	}
+}
+
+func TestValidateBackend(t *testing.T) {
+	tests := []struct {
+		name    string
+		backend string
+		wantErr bool
+	}{
+		{
+			name:    "valid backend llama.cpp",
+			backend: "llama.cpp",
+			wantErr: false,
+		},
+		{
+			name:    "valid backend openai",
+			backend: "openai",
+			wantErr: false,
+		},
+		{
+			name:    "invalid backend",
+			backend: "invalid",
+			wantErr: true,
+		},
+		{
+			name:    "empty backend",
+			backend: "",
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := validateBackend(tt.backend)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("validateBackend() error = %v, wantErr %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+func TestEnsureAPIKey(t *testing.T) {
+	tests := []struct {
+		name       string
+		backend    string
+		setupEnv   func()
+		cleanupEnv func()
+		wantErr    bool
+		wantKey    string
+	}{
+		{
+			name:    "non-openai backend",
+			backend: "llama.cpp",
+			wantErr: false,
+			wantKey: "",
+		},
+		{
+			// ensureAPIKey no longer errors when the key is absent; key
+			// requirements are enforced in resolveServerURL instead.
+			name:    "openai backend without key",
+			backend: "openai",
+			setupEnv: func() {
+				os.Unsetenv("OPENAI_API_KEY")
+			},
+			wantErr: false,
+			wantKey: "",
+		},
+		{
+			name:    "openai backend with key",
+			backend: "openai",
+			setupEnv: func() {
+				os.Setenv("OPENAI_API_KEY", "test-key")
+			},
+			cleanupEnv: func() {
+				os.Unsetenv("OPENAI_API_KEY")
+			},
+			wantErr: false,
+			wantKey: "test-key",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if tt.setupEnv != nil {
+				tt.setupEnv()
+			}
+			if tt.cleanupEnv != nil {
+				defer tt.cleanupEnv()
+			}
+
+			key, err := ensureAPIKey(tt.backend)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("ensureAPIKey() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+
+			if key != tt.wantKey {
+				t.Errorf("ensureAPIKey() key = %v, want %v", key, tt.wantKey)
+			}
+		})
+	}
+}
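A possible simplification for the environment-dependent cases above (a sketch; requires Go 1.17+): t.Setenv scopes the variable to the subtest and restores the previous value automatically, which would make the setupEnv/cleanupEnv pairs unnecessary:

    // Equivalent to the setupEnv/cleanupEnv pair, with automatic restoration.
    t.Run("openrouter url-alias with API key", func(t *testing.T) {
        t.Setenv("OPENAI_API_KEY", "test-key")
        // ... call resolveServerURL and assert as in the table above ...
    })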
t.Errorf("resolveServerURL() useOAI = %v, want %v", useOAI, tt.expectOAI) + } + + // For openrouter, check that API key is returned + if tt.urlAlias == "openrouter" && !tt.wantErr { + if apiKey == "" { + t.Errorf("resolveServerURL() expected API key for openrouter, got empty string") + } + } + }) + } +} + +func TestValidateBackend(t *testing.T) { + tests := []struct { + name string + backend string + wantErr bool + }{ + { + name: "valid backend llama.cpp", + backend: "llama.cpp", + wantErr: false, + }, + { + name: "valid backend openai", + backend: "openai", + wantErr: false, + }, + { + name: "invalid backend", + backend: "invalid", + wantErr: true, + }, + { + name: "empty backend", + backend: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateBackend(tt.backend) + if (err != nil) != tt.wantErr { + t.Errorf("validateBackend() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestEnsureAPIKey(t *testing.T) { + tests := []struct { + name string + backend string + setupEnv func() + cleanupEnv func() + wantErr bool + wantKey string + }{ + { + name: "non-openai backend", + backend: "llama.cpp", + wantErr: false, + wantKey: "", + }, + { + name: "openai backend without key", + backend: "openai", + wantErr: true, + }, + { + name: "openai backend with key", + backend: "openai", + setupEnv: func() { + os.Setenv("OPENAI_API_KEY", "test-key") + }, + cleanupEnv: func() { + os.Unsetenv("OPENAI_API_KEY") + }, + wantErr: false, + wantKey: "test-key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setupEnv != nil { + tt.setupEnv() + } + if tt.cleanupEnv != nil { + defer tt.cleanupEnv() + } + + key, err := ensureAPIKey(tt.backend) + if (err != nil) != tt.wantErr { + t.Errorf("ensureAPIKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if key != tt.wantKey { + t.Errorf("ensureAPIKey() key = %v, want %v", key, tt.wantKey) + } + }) + } +} diff --git a/cmd/cli/commands/list.go b/cmd/cli/commands/list.go index b4dfa9e4..d08a7d53 100644 --- a/cmd/cli/commands/list.go +++ b/cmd/cli/commands/list.go @@ -20,11 +20,33 @@ import ( func newListCmd() *cobra.Command { var jsonFormat, openai, quiet bool var backend string + var host string + var port int + var customURL string + var urlAlias string + c := &cobra.Command{ Use: "list [OPTIONS]", Aliases: []string{"ls"}, Short: "List the models pulled to your local environment", RunE: func(cmd *cobra.Command, args []string) error { + // Resolve server URL from flags + serverURL, useOpenAI, apiKey, err := resolveServerURL(host, customURL, urlAlias, port) + if err != nil { + return err + } + + // Override model runner context if server URL is specified + if serverURL != "" { + if err := overrideModelRunnerContextFromURL(serverURL, useOpenAI); err != nil { + return err + } + } else if host != "" || port != 0 { + if err := overrideModelRunnerContext(host, port); err != nil { + return err + } + } + // Validate backend if specified if backend != "" { if err := validateBackend(backend); err != nil { @@ -32,14 +54,25 @@ func newListCmd() *cobra.Command { } } - if (backend == "openai" || openai) && quiet { - return fmt.Errorf("--quiet flag cannot be used with --openai flag or OpenAI backend") + // If using OpenAI-compatible endpoints, set backend to "openai" + // Note: We don't automatically set openai=true here because that controls output format + // Users need to explicitly pass --openai flag for OpenAI JSON format output + if useOpenAI { + if 
backend == "" { + backend = "openai" + } } - // Validate API key for OpenAI backend - apiKey, err := ensureAPIKey(backend) - if err != nil { - return err + if openai && quiet { + return fmt.Errorf("--quiet flag cannot be used with --openai flag") + } + + // Validate API key for OpenAI backend (legacy backend flag) + if backend != "" && apiKey == "" { + apiKey, err = ensureAPIKey(backend) + if err != nil { + return err + } } // If we're doing an automatic install, only show the installation @@ -69,17 +102,36 @@ func newListCmd() *cobra.Command { c.Flags().BoolVarP(&quiet, "quiet", "q", false, "Only show model IDs") c.Flags().StringVar(&backend, "backend", "", fmt.Sprintf("Specify the backend to use (%s)", ValidBackendsKeys())) c.Flags().MarkHidden("backend") + + // Server connection flags + c.Flags().StringVar(&host, "host", "", "Host address to bind Docker Model Runner (default \"127.0.0.1\")") + c.Flags().IntVar(&port, "port", 0, "Docker container port for Docker Model Runner (default: 12434)") + c.Flags().StringVar(&customURL, "url", "", "Base URL for the model API") + c.Flags().StringVar(&urlAlias, "url-alias", "", "Use openai alias server output (llamacpp|ollama|openrouter)") + return c } func listModels(openai bool, backend string, desktopClient *desktop.Client, quiet bool, jsonFormat bool, apiKey string, modelFilter string) (string, error) { - if openai || backend == "openai" { + if backend == "openai" { models, err := desktopClient.ListOpenAI(backend, apiKey) if err != nil { err = handleClientError(err, "Failed to list models") return "", handleNotRunningError(err) } - return formatter.ToStandardJSON(models) + + // Support different output formats + if openai || jsonFormat { + return formatter.ToStandardJSON(models) + } + if quiet { + var modelIDs string + for _, m := range models.Data { + modelIDs += fmt.Sprintf("%s\n", m.ID) + } + return modelIDs, nil + } + return prettyPrintOpenAIModels(models), nil } models, err := desktopClient.List() if err != nil { @@ -130,6 +182,35 @@ func listModels(openai bool, backend string, desktopClient *desktop.Client, quie return prettyPrintModels(models), nil } +func prettyPrintOpenAIModels(modelList dmrm.OpenAIModelList) string { + var buf bytes.Buffer + table := tablewriter.NewWriter(&buf) + + table.SetHeader([]string{"MODEL NAME", "CREATED"}) + + table.SetBorder(false) + table.SetColumnSeparator("") + table.SetHeaderLine(false) + table.SetTablePadding(" ") + table.SetNoWhiteSpace(true) + + table.SetColumnAlignment([]int{ + tablewriter.ALIGN_LEFT, // MODEL NAME + tablewriter.ALIGN_LEFT, // CREATED + }) + table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) + + for _, m := range modelList.Data { + table.Append([]string{ + m.ID, + units.HumanDuration(time.Since(time.Unix(m.Created, 0))) + " ago", + }) + } + + table.Render() + return buf.String() +} + func prettyPrintModels(models []dmrm.Model) string { var buf bytes.Buffer table := tablewriter.NewWriter(&buf) diff --git a/cmd/cli/commands/list_test.go b/cmd/cli/commands/list_test.go new file mode 100644 index 00000000..301cc991 --- /dev/null +++ b/cmd/cli/commands/list_test.go @@ -0,0 +1,73 @@ +package commands + +import ( + "strings" + "testing" + "time" + + dmrm "github.com/docker/model-runner/pkg/inference/models" +) + +func TestPrettyPrintOpenAIModels(t *testing.T) { + modelList := dmrm.OpenAIModelList{ + Object: "list", + Data: []*dmrm.OpenAIModel{ + { + ID: "llama3.2:3b", + Object: "model", + Created: time.Now().Unix() - 3600, // 1 hour ago + OwnedBy: "docker", + }, + { + ID: "qwen2.5:7b", + 
Object: "model", + Created: time.Now().Unix() - 86400, // 1 day ago + OwnedBy: "docker", + }, + }, + } + + output := prettyPrintOpenAIModels(modelList) + + // Verify it's table format (contains headers) + if !strings.Contains(output, "MODEL NAME") { + t.Errorf("Expected output to contain 'MODEL NAME' header, got: %s", output) + } + + if !strings.Contains(output, "CREATED") { + t.Errorf("Expected output to contain 'CREATED' header, got: %s", output) + } + + // Verify model names are in output + if !strings.Contains(output, "llama3.2:3b") { + t.Errorf("Expected output to contain 'llama3.2:3b', got: %s", output) + } + + if !strings.Contains(output, "qwen2.5:7b") { + t.Errorf("Expected output to contain 'qwen2.5:7b', got: %s", output) + } + + // Verify time format (should contain "ago") + if !strings.Contains(output, "ago") { + t.Errorf("Expected output to contain time format with 'ago', got: %s", output) + } + + // Verify it's not JSON format + if strings.Contains(output, "{") || strings.Contains(output, "}") { + t.Errorf("Expected table format, but got JSON-like output: %s", output) + } +} + +func TestPrettyPrintOpenAIModelsEmpty(t *testing.T) { + modelList := dmrm.OpenAIModelList{ + Object: "list", + Data: []*dmrm.OpenAIModel{}, + } + + output := prettyPrintOpenAIModels(modelList) + + // Should still have headers even with no models + if !strings.Contains(output, "MODEL NAME") { + t.Errorf("Expected output to contain 'MODEL NAME' header even with no models, got: %s", output) + } +} diff --git a/cmd/cli/commands/pull.go b/cmd/cli/commands/pull.go index 38f3c7b5..a7006c86 100644 --- a/cmd/cli/commands/pull.go +++ b/cmd/cli/commands/pull.go @@ -13,6 +13,8 @@ import ( func newPullCmd() *cobra.Command { var ignoreRuntimeMemoryCheck bool + var host string + var port int c := &cobra.Command{ Use: "pull MODEL", @@ -28,6 +30,13 @@ func newPullCmd() *cobra.Command { return nil }, RunE: func(cmd *cobra.Command, args []string) error { + // Override model runner context if host/port is specified + if host != "" || port != 0 { + if err := overrideModelRunnerContext(host, port); err != nil { + return err + } + } + if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), cmd); err != nil { return fmt.Errorf("unable to initialize standalone model runner: %w", err) } @@ -37,6 +46,8 @@ func newPullCmd() *cobra.Command { } c.Flags().BoolVar(&ignoreRuntimeMemoryCheck, "ignore-runtime-memory-check", false, "Do not block pull if estimated runtime memory for model exceeds system resources.") + c.Flags().StringVar(&host, "host", "", "Host address to bind Docker Model Runner (default \"127.0.0.1\")") + c.Flags().IntVar(&port, "port", 0, "Docker container port for Docker Model Runner (default: 12434)") return c } diff --git a/cmd/cli/commands/push.go b/cmd/cli/commands/push.go index 72f614c6..26ec9cd1 100644 --- a/cmd/cli/commands/push.go +++ b/cmd/cli/commands/push.go @@ -10,6 +10,9 @@ import ( ) func newPushCmd() *cobra.Command { + var host string + var port int + c := &cobra.Command{ Use: "push MODEL", Short: "Push a model to Docker Hub", @@ -24,6 +27,13 @@ func newPushCmd() *cobra.Command { return nil }, RunE: func(cmd *cobra.Command, args []string) error { + // Override model runner context if host/port is specified + if host != "" || port != 0 { + if err := overrideModelRunnerContext(host, port); err != nil { + return err + } + } + if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), cmd); err != nil { return fmt.Errorf("unable to initialize standalone model runner: %w", err) } @@ -31,6 +41,10 @@ func 
diff --git a/cmd/cli/commands/push.go b/cmd/cli/commands/push.go
index 72f614c6..26ec9cd1 100644
--- a/cmd/cli/commands/push.go
+++ b/cmd/cli/commands/push.go
@@ -10,6 +10,9 @@ import (
 )
 
 func newPushCmd() *cobra.Command {
+	var host string
+	var port int
+
 	c := &cobra.Command{
 		Use:   "push MODEL",
 		Short: "Push a model to Docker Hub",
@@ -24,6 +27,13 @@ func newPushCmd() *cobra.Command {
 			return nil
 		},
 		RunE: func(cmd *cobra.Command, args []string) error {
+			// Override model runner context if host/port is specified
+			if host != "" || port != 0 {
+				if err := overrideModelRunnerContext(host, port); err != nil {
+					return err
+				}
+			}
+
 			if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), cmd); err != nil {
 				return fmt.Errorf("unable to initialize standalone model runner: %w", err)
 			}
@@ -31,6 +41,10 @@ func newPushCmd() *cobra.Command {
 		},
 		ValidArgsFunction: completion.NoComplete,
 	}
+
+	c.Flags().StringVar(&host, "host", "", "Host address of the Docker Model Runner (default \"127.0.0.1\")")
+	c.Flags().IntVar(&port, "port", 0, "Docker container port for Docker Model Runner (default: 12434)")
+
 	return c
 }
diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go
index 2f2ff982..1d2f90da 100644
--- a/cmd/cli/commands/run.go
+++ b/cmd/cli/commands/run.go
@@ -535,6 +535,10 @@ func newRunCmd() *cobra.Command {
 	var ignoreRuntimeMemoryCheck bool
 	var colorMode string
 	var detach bool
+	var host string
+	var port int
+	var customURL string
+	var urlAlias string
 
 	const cmdArgs = "MODEL [PROMPT]"
 	c := &cobra.Command{
@@ -549,6 +553,23 @@ func newRunCmd() *cobra.Command {
 			}
 		},
 		RunE: func(cmd *cobra.Command, args []string) error {
+			// Resolve server URL from flags
+			serverURL, useOpenAI, apiKey, err := resolveServerURL(host, customURL, urlAlias, port)
+			if err != nil {
+				return err
+			}
+
+			// Override model runner context if server URL is specified
+			if serverURL != "" {
+				if err := overrideModelRunnerContextFromURL(serverURL, useOpenAI); err != nil {
+					return err
+				}
+			} else if host != "" || port != 0 {
+				if err := overrideModelRunnerContext(host, port); err != nil {
+					return err
+				}
+			}
+
 			// Validate backend if specified
 			if backend != "" {
 				if err := validateBackend(backend); err != nil {
@@ -556,10 +577,17 @@ func newRunCmd() *cobra.Command {
 				}
 			}
 
-			// Validate API key for OpenAI backend
-			apiKey, err := ensureAPIKey(backend)
-			if err != nil {
-				return err
+			// Validate API key for OpenAI backend (legacy backend flag)
+			if backend != "" && apiKey == "" {
+				apiKey, err = ensureAPIKey(backend)
+				if err != nil {
+					return err
+				}
+			}
+
+			// If using OpenAI-compatible endpoints, set backend to "openai"
+			if useOpenAI && backend == "" {
+				backend = "openai"
 			}
 
 			// Normalize model name to add default org and tag if missing
@@ -665,5 +693,11 @@ func newRunCmd() *cobra.Command {
 	c.Flags().StringVar(&colorMode, "color", "auto", "Use colored output (auto|yes|no)")
 	c.Flags().BoolVarP(&detach, "detach", "d", false, "Load the model in the background without interaction")
 
+	// Server connection flags
+	c.Flags().StringVar(&host, "host", "", "Host address of the Docker Model Runner (default \"127.0.0.1\")")
+	c.Flags().IntVar(&port, "port", 0, "Docker container port for Docker Model Runner (default: 12434)")
+	c.Flags().StringVar(&customURL, "url", "", "Base URL for the model API")
+	c.Flags().StringVar(&urlAlias, "url-alias", "", "Connect to a preset OpenAI-compatible server (llamacpp|ollama|openrouter)")
+
 	return c
 }
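Putting the run-command flags together (illustrative invocations; model names are placeholders):

    docker model run --url-alias llamacpp my-model "Hello"     # llama.cpp server on :8080
    docker model run --url http://10.0.0.5:9000/v1 my-model    # any OpenAI-compatible endpoint
    docker model run --host 10.0.0.5 --port 12435 my-model     # remote Docker Model Runner
    OPENAI_API_KEY=sk-... docker model run --url-alias openrouter some-org/some-model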
diff --git a/cmd/cli/commands/run_test.go b/cmd/cli/commands/run_test.go
index 422f6efa..ad835757 100644
--- a/cmd/cli/commands/run_test.go
+++ b/cmd/cli/commands/run_test.go
@@ -156,3 +156,131 @@ func TestRunCmdDetachFlag(t *testing.T) {
 		t.Errorf("Expected detach flag value to be true, got false")
 	}
 }
+
+func TestRunCmdServerConnectionFlags(t *testing.T) {
+	// Create the run command
+	cmd := newRunCmd()
+
+	// Test --host flag
+	hostFlag := cmd.Flags().Lookup("host")
+	if hostFlag == nil {
+		t.Fatal("--host flag not found")
+	}
+	if hostFlag.Value.Type() != "string" {
+		t.Errorf("Expected host flag type to be 'string', got '%s'", hostFlag.Value.Type())
+	}
+
+	// Test --port flag
+	portFlag := cmd.Flags().Lookup("port")
+	if portFlag == nil {
+		t.Fatal("--port flag not found")
+	}
+	if portFlag.Value.Type() != "int" {
+		t.Errorf("Expected port flag type to be 'int', got '%s'", portFlag.Value.Type())
+	}
+
+	// Test --url flag
+	urlFlag := cmd.Flags().Lookup("url")
+	if urlFlag == nil {
+		t.Fatal("--url flag not found")
+	}
+	if urlFlag.Value.Type() != "string" {
+		t.Errorf("Expected url flag type to be 'string', got '%s'", urlFlag.Value.Type())
+	}
+
+	// Test --url-alias flag
+	urlAliasFlag := cmd.Flags().Lookup("url-alias")
+	if urlAliasFlag == nil {
+		t.Fatal("--url-alias flag not found")
+	}
+	if urlAliasFlag.Value.Type() != "string" {
+		t.Errorf("Expected url-alias flag type to be 'string', got '%s'", urlAliasFlag.Value.Type())
+	}
+}
+
+func TestListCmdServerConnectionFlags(t *testing.T) {
+	// Create the list command
+	cmd := newListCmd()
+
+	// Test --host flag
+	hostFlag := cmd.Flags().Lookup("host")
+	if hostFlag == nil {
+		t.Fatal("--host flag not found")
+	}
+	if hostFlag.Value.Type() != "string" {
+		t.Errorf("Expected host flag type to be 'string', got '%s'", hostFlag.Value.Type())
+	}
+
+	// Test --port flag
+	portFlag := cmd.Flags().Lookup("port")
+	if portFlag == nil {
+		t.Fatal("--port flag not found")
+	}
+	if portFlag.Value.Type() != "int" {
+		t.Errorf("Expected port flag type to be 'int', got '%s'", portFlag.Value.Type())
+	}
+
+	// Test --url flag
+	urlFlag := cmd.Flags().Lookup("url")
+	if urlFlag == nil {
+		t.Fatal("--url flag not found")
+	}
+	if urlFlag.Value.Type() != "string" {
+		t.Errorf("Expected url flag type to be 'string', got '%s'", urlFlag.Value.Type())
+	}
+
+	// Test --url-alias flag
+	urlAliasFlag := cmd.Flags().Lookup("url-alias")
+	if urlAliasFlag == nil {
+		t.Fatal("--url-alias flag not found")
+	}
+	if urlAliasFlag.Value.Type() != "string" {
+		t.Errorf("Expected url-alias flag type to be 'string', got '%s'", urlAliasFlag.Value.Type())
+	}
+}
+
+func TestPullCmdServerConnectionFlags(t *testing.T) {
+	// Create the pull command
+	cmd := newPullCmd()
+
+	// Test --host flag
+	hostFlag := cmd.Flags().Lookup("host")
+	if hostFlag == nil {
+		t.Fatal("--host flag not found")
+	}
+	if hostFlag.Value.Type() != "string" {
+		t.Errorf("Expected host flag type to be 'string', got '%s'", hostFlag.Value.Type())
+	}
+
+	// Test --port flag
+	portFlag := cmd.Flags().Lookup("port")
+	if portFlag == nil {
+		t.Fatal("--port flag not found")
+	}
+	if portFlag.Value.Type() != "int" {
+		t.Errorf("Expected port flag type to be 'int', got '%s'", portFlag.Value.Type())
+	}
+}
+
+func TestPushCmdServerConnectionFlags(t *testing.T) {
+	// Create the push command
+	cmd := newPushCmd()
+
+	// Test --host flag
+	hostFlag := cmd.Flags().Lookup("host")
+	if hostFlag == nil {
+		t.Fatal("--host flag not found")
+	}
+	if hostFlag.Value.Type() != "string" {
+		t.Errorf("Expected host flag type to be 'string', got '%s'", hostFlag.Value.Type())
+	}
+
+	// Test --port flag
+	portFlag := cmd.Flags().Lookup("port")
+	if portFlag == nil {
+		t.Fatal("--port flag not found")
+	}
+	if portFlag.Value.Type() != "int" {
+		t.Errorf("Expected port flag type to be 'int', got '%s'", portFlag.Value.Type())
+	}
+}
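The four tests above repeat the same lookup-and-type assertions; they could collapse into a helper (a hypothetical refactor, relying only on the cobra/pflag API already used here):

    func checkFlag(t *testing.T, cmd *cobra.Command, name, wantType string) {
        t.Helper()
        f := cmd.Flags().Lookup(name)
        if f == nil {
            t.Fatalf("--%s flag not found", name)
        }
        if got := f.Value.Type(); got != wantType {
            t.Errorf("expected --%s flag type %q, got %q", name, wantType, got)
        }
    }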
diff --git a/cmd/cli/commands/utils.go b/cmd/cli/commands/utils.go
index c8b4abbe..4f89cca4 100644
--- a/cmd/cli/commands/utils.go
+++ b/cmd/cli/commands/utils.go
@@ -59,3 +59,40 @@ func stripDefaultsFromModelName(model string) string {
 	// For other cases (ai/ with custom tag, custom org with :latest, etc.), keep as-is
 	return model
 }
+
+// overrideModelRunnerContext updates the model runner context with custom host/port
+func overrideModelRunnerContext(host string, port int) error {
+	if host == "" {
+		host = "127.0.0.1"
+	}
+	if port == 0 {
+		port = 12434
+	}
+
+	// Create a new model runner context with the custom host and port
+	newContext, err := desktop.NewContextWithHostPort(dockerCLI, host, port)
+	if err != nil {
+		return fmt.Errorf("unable to create model runner context with host %s and port %d: %w", host, port, err)
+	}
+
+	// Update global variables
+	modelRunner = newContext
+	desktopClient = desktop.New(newContext)
+
+	return nil
+}
+
+// overrideModelRunnerContextFromURL updates the model runner context with a custom URL
+func overrideModelRunnerContextFromURL(url string, external bool) error {
+	// Create a new model runner context with the custom URL
+	newContext, err := desktop.NewContextWithURLExternal(dockerCLI, url, external)
+	if err != nil {
+		return fmt.Errorf("unable to create model runner context with URL %s: %w", url, err)
+	}
+
+	// Update global variables
+	modelRunner = newContext
+	desktopClient = desktop.New(newContext)
+
+	return nil
+}
diff --git a/cmd/cli/desktop/context.go b/cmd/cli/desktop/context.go
index d3ff0c5b..16d87c4d 100644
--- a/cmd/cli/desktop/context.go
+++ b/cmd/cli/desktop/context.go
@@ -89,6 +89,8 @@ type ModelRunnerContext struct {
 	urlPrefix *url.URL
 	// client is the model runner client.
 	client DockerHttpClient
+	// externalOpenAI indicates if this context points to an external OpenAI-compatible endpoint
+	externalOpenAI bool
 }
 
 // NewContextForMock is a ModelRunnerContext constructor exposed only for the
@@ -167,6 +169,70 @@ func DetectContext(ctx context.Context, cli *command.DockerCli) (*ModelRunnerCon
 	}, nil
 }
 
+// NewContextWithHostPort creates a new ModelRunnerContext with a custom host and port.
+func NewContextWithHostPort(cli *command.DockerCli, host string, port int) (*ModelRunnerContext, error) {
+	if host == "" {
+		host = "127.0.0.1"
+	}
+	if port == 0 {
+		port = 12434
+	}
+
+	// Create URL prefix with custom host and port
+	rawURLPrefix := fmt.Sprintf("http://%s:%d", host, port)
+	urlPrefix, err := url.Parse(rawURLPrefix)
+	if err != nil {
+		return nil, fmt.Errorf("invalid model runner URL (%s): %w", rawURLPrefix, err)
+	}
+
+	// Use the HTTP default client for custom host/port
+	client := http.DefaultClient
+
+	if userAgent := os.Getenv("USER_AGENT"); userAgent != "" {
+		setUserAgent(client, userAgent)
+	}
+
+	return &ModelRunnerContext{
+		kind:      types.ModelRunnerEngineKindMobyManual,
+		urlPrefix: urlPrefix,
+		client:    client,
+	}, nil
+}
+
+// NewContextWithURL creates a new ModelRunnerContext with a custom URL.
+// It delegates to NewContextWithURLExternal with external set to false.
+func NewContextWithURL(cli *command.DockerCli, rawURL string) (*ModelRunnerContext, error) {
+	return NewContextWithURLExternal(cli, rawURL, false)
+}
+
+// NewContextWithURLExternal creates a new ModelRunnerContext with a custom URL.
+// The external parameter indicates if this is an external OpenAI-compatible endpoint.
+func NewContextWithURLExternal(cli *command.DockerCli, rawURL string, external bool) (*ModelRunnerContext, error) {
+	urlPrefix, err := url.Parse(rawURL)
+	if err != nil {
+		return nil, fmt.Errorf("invalid model runner URL (%s): %w", rawURL, err)
+	}
+
+	// Use the HTTP default client for custom URLs
+	client := http.DefaultClient
+
+	if userAgent := os.Getenv("USER_AGENT"); userAgent != "" {
+		setUserAgent(client, userAgent)
+	}
+
+	return &ModelRunnerContext{
+		kind:           types.ModelRunnerEngineKindMobyManual,
+		urlPrefix:      urlPrefix,
+		client:         client,
+		externalOpenAI: external,
+	}, nil
+}
+
+// IsExternalOpenAI returns true if this context points to an external OpenAI-compatible endpoint.
+func (c *ModelRunnerContext) IsExternalOpenAI() bool {
+	return c.externalOpenAI
+}
+
 // EngineKind returns the Docker engine kind associated with the model runner.
 func (c *ModelRunnerContext) EngineKind() types.ModelRunnerEngineKind {
 	return c.kind
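Why the external routes in desktop.go below are so short (illustrative; this assumes requests are formed by joining the context's urlPrefix with the route, which is how the proxy routes are built): the preset URLs already end in /v1, so the external context only needs the trailing path segment, while the built-in runner goes through the inference proxy prefix:

    external (ollama preset):  "http://127.0.0.1:11434/v1" + "/models"
                               -> http://127.0.0.1:11434/v1/models
    built-in runner:           "http://127.0.0.1:12434" + inference.InferencePrefix + "/llama.cpp/v1/models"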
diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go
index 01b7b080..8ef74a8a 100644
--- a/cmd/cli/desktop/desktop.go
+++ b/cmd/cli/desktop/desktop.go
@@ -234,10 +234,18 @@ func (c *Client) List() ([]dmrm.Model, error) {
 }
 
 func (c *Client) ListOpenAI(backend, apiKey string) (dmrm.OpenAIModelList, error) {
-	if backend == "" {
-		backend = DefaultBackend
+	var modelsRoute string
+
+	// For external OpenAI endpoints, use the direct path
+	if c.modelRunner.IsExternalOpenAI() {
+		modelsRoute = "/models"
+	} else {
+		// For the model-runner proxy, use the inference prefix
+		if backend == "" {
+			backend = DefaultBackend
+		}
+		modelsRoute = fmt.Sprintf("%s/%s/v1/models", inference.InferencePrefix, backend)
 	}
-	modelsRoute := fmt.Sprintf("%s/%s/v1/models", inference.InferencePrefix, backend)
 
 	// Use doRequestWithAuth to support API key authentication
 	resp, err := c.doRequestWithAuth(http.MethodGet, modelsRoute, nil, "openai", apiKey)
@@ -388,10 +396,16 @@ func (c *Client) ChatWithContext(ctx context.Context, backend, model, prompt, ap
 	}
 
 	var completionsPath string
-	if backend != "" {
-		completionsPath = inference.InferencePrefix + "/" + backend + "/v1/chat/completions"
+	// For external OpenAI endpoints, use the direct path
+	if c.modelRunner.IsExternalOpenAI() {
+		completionsPath = "/chat/completions"
 	} else {
-		completionsPath = inference.InferencePrefix + "/v1/chat/completions"
+		// For the model-runner proxy, use the inference prefix
+		if backend != "" {
+			completionsPath = inference.InferencePrefix + "/" + backend + "/v1/chat/completions"
+		} else {
+			completionsPath = inference.InferencePrefix + "/v1/chat/completions"
+		}
 	}
 
 	resp, err := c.doRequestWithAuthContext(
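Taken together, the list command now supports three output shapes against an OpenAI-compatible server (illustrative invocations, following the branches in listModels):

    docker model list --url-alias ollama            # table via prettyPrintOpenAIModels
    docker model list --url-alias ollama --openai   # OpenAI-style JSON
    docker model list --url-alias ollama -q         # model IDs only, one per line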