diff --git a/cmd/root.go b/cmd/root.go index 1728aed80ea..347d6948c04 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -363,6 +363,7 @@ func NewCommand(opts ...Option) *Command { flags.BoolVar(&cmd.cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.") flags.BoolVar(&cmd.cfg.DisableReload, "disable-reload", false, "Disables dynamic reloading of tools file.") flags.BoolVar(&cmd.cfg.UI, "ui", false, "Launches the Toolbox UI web server.") + flags.StringSliceVar(&cmd.cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.") // wrap RunE command so that we have access to original Command object cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) } diff --git a/cmd/root_test.go b/cmd/root_test.go index d698a11db41..7914bb27542 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -63,6 +63,9 @@ func withDefaults(c server.ServerConfig) server.ServerConfig { if c.TelemetryServiceName == "" { c.TelemetryServiceName = "toolbox" } + if c.AllowedOrigins == nil { + c.AllowedOrigins = []string{"*"} + } return c } @@ -194,6 +197,13 @@ func TestServerConfigFlags(t *testing.T) { DisableReload: true, }), }, + { + desc: "allowed origin", + args: []string{"--allowed-origins", "http://foo.com,http://bar.com"}, + want: withDefaults(server.ServerConfig{ + AllowedOrigins: []string{"http://foo.com", "http://bar.com"}, + }), + }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { diff --git a/docs/en/how-to/deploy_docker.md b/docs/en/how-to/deploy_docker.md index f7f6ac88274..ff12367572e 100644 --- a/docs/en/how-to/deploy_docker.md +++ b/docs/en/how-to/deploy_docker.md @@ -67,6 +67,13 @@ networks: ``` + {{< notice tip >}} +To prevent DNS rebinding attack, use the `--allowed-origins` flag to specify a +list of origins permitted to access the server. E.g. `command: [ "toolbox", +"--tools-file", "/config/tools.yaml", "--address", "0.0.0.0", +"--allowed-origins", "https://foo.bar"]` +{{< /notice >}} + 1. Run the following command to bring up the Toolbox and Postgres instance ```bash diff --git a/docs/en/how-to/deploy_gke.md b/docs/en/how-to/deploy_gke.md index d717dff787f..4c18e9bfb14 100644 --- a/docs/en/how-to/deploy_gke.md +++ b/docs/en/how-to/deploy_gke.md @@ -188,6 +188,12 @@ description: > path: tools.yaml ``` + {{< notice tip >}} +To prevent DNS rebinding attack, use the `--allowed-origins` flag to specify a +list of origins permitted to access the server. E.g. `args: ["--address", +"0.0.0.0", "--allowed-origins", "https://foo.bar"]` +{{< /notice >}} + 1. Create the deployment. ```bash diff --git a/docs/en/how-to/deploy_toolbox.md b/docs/en/how-to/deploy_toolbox.md index a5941362f60..455f6bd3ff0 100644 --- a/docs/en/how-to/deploy_toolbox.md +++ b/docs/en/how-to/deploy_toolbox.md @@ -104,7 +104,7 @@ section. export IMAGE=us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest ``` -{{< notice note >}} + {{< notice note >}} **The `$PORT` Environment Variable** Google Cloud Run dictates the port your application must listen on by setting the `$PORT` environment variable inside your container. This value defaults to @@ -140,6 +140,45 @@ deployment will time out. # --allow-unauthenticated # https://cloud.google.com/run/docs/authenticating/public#gcloud ``` +### Update deployed server to be secure + +To prevent DNS rebinding attack, use the `--allowed-origins` flag to specify a +list of origins permitted to access the server. In order to do that, you will +have to re-deploy the cloud run service with the new flag. + +1. Set an environment variable to the cloud run url: + + ```bash + export URL= + ``` + +2. Redeploy Toolbox: + + ```bash + gcloud run deploy toolbox \ + --image $IMAGE \ + --service-account toolbox-identity \ + --region us-central1 \ + --set-secrets "/app/tools.yaml=tools:latest" \ + --args="--tools-file=/app/tools.yaml","--address=0.0.0.0","--port=8080","--allowed-origins=$URL" + # --allow-unauthenticated # https://cloud.google.com/run/docs/authenticating/public#gcloud + ``` + + If you are using a VPC network, use the command below: + + ```bash + gcloud run deploy toolbox \ + --image $IMAGE \ + --service-account toolbox-identity \ + --region us-central1 \ + --set-secrets "/app/tools.yaml=tools:latest" \ + --args="--tools-file=/app/tools.yaml","--address=0.0.0.0","--port=8080","--allowed-origins=$URL" \ + # TODO(dev): update the following to match your VPC if necessary + --network default \ + --subnet default + # --allow-unauthenticated # https://cloud.google.com/run/docs/authenticating/public#gcloud + ``` + ## Connecting with Toolbox Client SDK You can connect to Toolbox Cloud Run instances directly through the SDK. diff --git a/docs/en/reference/cli.md b/docs/en/reference/cli.md index 91c4f2b2401..490e63fe2a6 100644 --- a/docs/en/reference/cli.md +++ b/docs/en/reference/cli.md @@ -25,6 +25,7 @@ description: > | | `--tools-files` | Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --prebuilt, --tools-file, or --tools-folder. | | | | `--tools-folder` | Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --prebuilt, --tools-file, or --tools-files. | | | | `--ui` | Launches the Toolbox UI web server. | | +| | `--allowed-origins` | Specifies a list of origins permitted to access this server. | `*` | | `-v` | `--version` | version for toolbox | | ## Examples diff --git a/go.mod b/go.mod index 7b6adf3074c..09c08680772 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/elastic/go-elasticsearch/v9 v9.2.0 github.com/fsnotify/fsnotify v1.9.0 github.com/go-chi/chi/v5 v5.2.3 + github.com/go-chi/cors v1.2.2 github.com/go-chi/httplog/v2 v2.1.1 github.com/go-chi/render v1.0.3 github.com/go-goquery/goquery v1.0.1 diff --git a/go.sum b/go.sum index 0440488508d..ba0cf4d8ce2 100644 --- a/go.sum +++ b/go.sum @@ -858,6 +858,8 @@ github.com/gabriel-vasile/mimetype v1.4.10/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9t github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= +github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE= +github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-chi/httplog/v2 v2.1.1 h1:ojojiu4PIaoeJ/qAO4GWUxJqvYUTobeo7zmuHQJAxRk= github.com/go-chi/httplog/v2 v2.1.1/go.mod h1:/XXdxicJsp4BA5fapgIC3VuTD+z0Z/VzukoB3VDc1YE= github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= diff --git a/internal/server/config.go b/internal/server/config.go index e16d2e83273..fa0f1952a7f 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -62,6 +62,8 @@ type ServerConfig struct { DisableReload bool // UI indicates if Toolbox UI endpoints (/ui) are available UI bool + // Specifies a list of origins permitted to access this server. + AllowedOrigins []string } type logFormat string diff --git a/internal/server/server.go b/internal/server/server.go index 8b5427d17c8..7bbe218f284 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -20,6 +20,7 @@ import ( "io" "net" "net/http" + "slices" "strconv" "strings" "sync" @@ -27,6 +28,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/cors" "github.com/go-chi/httplog/v2" "github.com/googleapis/genai-toolbox/internal/auth" "github.com/googleapis/genai-toolbox/internal/log" @@ -388,6 +390,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { // set up http serving r := chi.NewRouter() r.Use(middleware.Recoverer) + // logging logLevel, err := log.SeverityToLevel(cfg.LogLevel.String()) if err != nil { @@ -440,6 +443,21 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { sseManager: sseManager, ResourceMgr: resourceManager, } + + // cors + if slices.Contains(cfg.AllowedOrigins, "*") { + s.logger.WarnContext(ctx, "wildcard (`*`) allows all origin to access the resource and is not secure. Use it with cautious for public, non-sensitive data, or during local development. Recommended to use `--allowed-origins` flag to prevent DNS rebinding attacks") + } + corsOpts := cors.Options{ + AllowedOrigins: cfg.AllowedOrigins, + AllowedMethods: []string{"GET", "POST", "DELETE", "OPTIONS"}, + AllowCredentials: true, // required since Toolbox uses auth headers + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token", "Mcp-Session-Id", "MCP-Protocol-Version"}, + ExposedHeaders: []string{"Mcp-Session-Id"}, // headers that are sent to clients + MaxAge: 300, // cache preflight results for 5 minutes + } + r.Use(cors.Handler(corsOpts)) + // control plane apiR, err := apiRouter(s) if err != nil {