|
| 1 | +// Copyright 2025 Google LLC |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +package cloudsqlpgupgradeprecheck |
| 16 | + |
| 17 | +import ( |
| 18 | + "context" |
| 19 | + "fmt" |
| 20 | + "time" |
| 21 | + |
| 22 | + yaml "github.com/goccy/go-yaml" |
| 23 | + "github.com/googleapis/genai-toolbox/internal/sources" |
| 24 | + "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" |
| 25 | + "github.com/googleapis/genai-toolbox/internal/tools" |
| 26 | + "github.com/googleapis/genai-toolbox/internal/util/parameters" |
| 27 | + sqladmin "google.golang.org/api/sqladmin/v1" |
| 28 | +) |
| 29 | + |
| 30 | +const kind string = "postgres-upgrade-precheck" |
| 31 | + |
| 32 | +func init() { |
| 33 | + if !tools.Register(kind, newConfig) { |
| 34 | + panic(fmt.Sprintf("tool kind %q already registered", kind)) |
| 35 | + } |
| 36 | +} |
| 37 | + |
| 38 | +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { |
| 39 | + actual := Config{Name: name} |
| 40 | + if err := decoder.DecodeContext(ctx, &actual); err != nil { |
| 41 | + return nil, err |
| 42 | + } |
| 43 | + return actual, nil |
| 44 | +} |
| 45 | + |
| 46 | +// Config defines the configuration for the precheck-upgrade tool. |
| 47 | +type Config struct { |
| 48 | + Name string `yaml:"name" validate:"required"` |
| 49 | + Kind string `yaml:"kind" validate:"required"` |
| 50 | + Description string `yaml:"description"` |
| 51 | + Source string `yaml:"source" validate:"required"` |
| 52 | + AuthRequired []string `yaml:"authRequired"` |
| 53 | +} |
| 54 | + |
| 55 | +// validate interface |
| 56 | +var _ tools.ToolConfig = Config{} |
| 57 | + |
| 58 | +// ToolConfigKind returns the kind of the tool. |
| 59 | +func (cfg Config) ToolConfigKind() string { |
| 60 | + return kind |
| 61 | +} |
| 62 | + |
| 63 | +// Initialize initializes the tool from the configuration. |
| 64 | +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { |
| 65 | + rawS, ok := srcs[cfg.Source] |
| 66 | + if !ok { |
| 67 | + return nil, fmt.Errorf("no source named %q configured", cfg.Source) |
| 68 | + } |
| 69 | + s, ok := rawS.(*cloudsqladmin.Source) |
| 70 | + if !ok { |
| 71 | + return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) |
| 72 | + } |
| 73 | + |
| 74 | + allParameters := parameters.Parameters{ |
| 75 | + parameters.NewStringParameter("project", "The project ID"), |
| 76 | + parameters.NewStringParameter("instance", "The name of the instance to check"), |
| 77 | + parameters.NewStringParameterWithDefault("targetDatabaseVersion", "POSTGRES_18", "The target PostgreSQL version for the upgrade (e.g., POSTGRES_18). If not specified, defaults to the PostgreSQL 18."), |
| 78 | + } |
| 79 | + paramManifest := allParameters.Manifest() |
| 80 | + |
| 81 | + description := cfg.Description |
| 82 | + if description == "" { |
| 83 | + description = "Analyzes a Cloud SQL PostgreSQL instance for major version upgrade readiness. Results are provided to guide customer actions:\n" + |
| 84 | + "ERROR: Action Required. These are critical issues blocking the upgrade. Customers must resolve these using the provided actions_required steps before attempting the upgrade.\n" + |
| 85 | + "WARNING: Review Recommended. These are potential issues. Customers should review the message and actions_required. While not blocking, addressing these is advised to prevent future problems or unexpected behavior post-upgrade.\n" + |
| 86 | + "INFO: No Action Needed. Informational messages only. This pre-check helps customers proactively fix problems, preventing upgrade failures and ensuring a smoother transition." |
| 87 | + } |
| 88 | + mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters) |
| 89 | + |
| 90 | + return Tool{ |
| 91 | + Name: cfg.Name, |
| 92 | + Kind: kind, |
| 93 | + AuthRequired: cfg.AuthRequired, |
| 94 | + Source: s, |
| 95 | + AllParams: allParameters, |
| 96 | + manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, |
| 97 | + mcpManifest: mcpManifest, |
| 98 | + }, nil |
| 99 | +} |
| 100 | + |
| 101 | +// Tool represents the precheck-upgrade tool. |
| 102 | +type Tool struct { |
| 103 | + Name string `yaml:"name"` |
| 104 | + Kind string `yaml:"kind"` |
| 105 | + Description string `yaml:"description"` |
| 106 | + AuthRequired []string `yaml:"authRequired"` |
| 107 | + |
| 108 | + Source *cloudsqladmin.Source |
| 109 | + AllParams parameters.Parameters `yaml:"allParams"` |
| 110 | + manifest tools.Manifest |
| 111 | + mcpManifest tools.McpManifest |
| 112 | + Config |
| 113 | +} |
| 114 | + |
| 115 | +// PreCheckResultItem holds the details of a single check result. |
| 116 | +type PreCheckResultItem struct { |
| 117 | + Message string `json:"message"` |
| 118 | + MessageType string `json:"messageType"` // INFO, WARNING, ERROR |
| 119 | + ActionsRequired []string `json:"actionsRequired"` |
| 120 | +} |
| 121 | + |
| 122 | +// PreCheckAPIResponse holds the array of pre-check results. |
| 123 | +type PreCheckAPIResponse struct { |
| 124 | + Items []PreCheckResultItem `json:"preCheckResponse"` |
| 125 | +} |
| 126 | + |
| 127 | +// Helper function to convert from []*sqladmin.PreCheckResponse to []PreCheckResultItem |
| 128 | +func convertResults(items []*sqladmin.PreCheckResponse) []PreCheckResultItem { |
| 129 | + if len(items) == 0 { // Handle nil or empty slice |
| 130 | + return []PreCheckResultItem{} |
| 131 | + } |
| 132 | + results := make([]PreCheckResultItem, len(items)) |
| 133 | + for i, item := range items { |
| 134 | + results[i] = PreCheckResultItem{ |
| 135 | + Message: item.Message, |
| 136 | + MessageType: item.MessageType, |
| 137 | + ActionsRequired: item.ActionsRequired, |
| 138 | + } |
| 139 | + } |
| 140 | + return results |
| 141 | +} |
| 142 | + |
| 143 | +func (t Tool) ToConfig() tools.ToolConfig { |
| 144 | + return t.Config |
| 145 | +} |
| 146 | + |
| 147 | +// Invoke executes the tool's logic. |
| 148 | +func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { |
| 149 | + paramsMap := params.AsMap() |
| 150 | + |
| 151 | + project, ok := paramsMap["project"].(string) |
| 152 | + if !ok || project == "" { |
| 153 | + return nil, fmt.Errorf("missing or empty 'project' parameter") |
| 154 | + } |
| 155 | + instanceName, ok := paramsMap["instance"].(string) |
| 156 | + if !ok || instanceName == "" { |
| 157 | + return nil, fmt.Errorf("missing or empty 'instance' parameter") |
| 158 | + } |
| 159 | + targetVersion, ok := paramsMap["targetDatabaseVersion"].(string) |
| 160 | + if !ok || targetVersion == "" { |
| 161 | + // This should not happen due to the default value |
| 162 | + return nil, fmt.Errorf("missing or empty 'targetDatabaseVersion' parameter") |
| 163 | + } |
| 164 | + |
| 165 | + service, err := t.Source.GetService(ctx, string(accessToken)) |
| 166 | + if err != nil { |
| 167 | + return nil, fmt.Errorf("failed to get HTTP client from source: %w", err) |
| 168 | + } |
| 169 | + |
| 170 | + reqBody := &sqladmin.InstancesPreCheckMajorVersionUpgradeRequest{ |
| 171 | + PreCheckMajorVersionUpgradeContext: &sqladmin.PreCheckMajorVersionUpgradeContext{ |
| 172 | + TargetDatabaseVersion: targetVersion, |
| 173 | + }, |
| 174 | + } |
| 175 | + |
| 176 | + call := service.Instances.PreCheckMajorVersionUpgrade(project, instanceName, reqBody).Context(ctx) |
| 177 | + op, err := call.Do() |
| 178 | + if err != nil { |
| 179 | + return nil, fmt.Errorf("failed to start pre-check operation: %w", err) |
| 180 | + } |
| 181 | + |
| 182 | + const pollTimeout = 20 * time.Second |
| 183 | + cutoffTime := time.Now().Add(pollTimeout) |
| 184 | + |
| 185 | + for time.Now().Before(cutoffTime) { |
| 186 | + currentOp, err := service.Operations.Get(project, op.Name).Context(ctx).Do() |
| 187 | + if err != nil { |
| 188 | + return nil, fmt.Errorf("failed to get operation status: %w", err) |
| 189 | + } |
| 190 | + |
| 191 | + if currentOp.Status == "DONE" { |
| 192 | + if currentOp.Error != nil && len(currentOp.Error.Errors) > 0 { |
| 193 | + errMsg := fmt.Sprintf("pre-check operation LRO failed: %s", currentOp.Error.Errors[0].Message) |
| 194 | + if currentOp.Error.Errors[0].Code != "" { |
| 195 | + errMsg = fmt.Sprintf("%s (Code: %s)", errMsg, currentOp.Error.Errors[0].Code) |
| 196 | + } |
| 197 | + return nil, fmt.Errorf("%s", errMsg) |
| 198 | + } |
| 199 | + |
| 200 | + var preCheckItems []*sqladmin.PreCheckResponse |
| 201 | + if currentOp.PreCheckMajorVersionUpgradeContext != nil { |
| 202 | + preCheckItems = currentOp.PreCheckMajorVersionUpgradeContext.PreCheckResponse |
| 203 | + } |
| 204 | + // convertResults handles nil or empty preCheckItems |
| 205 | + return PreCheckAPIResponse{Items: convertResults(preCheckItems)}, nil |
| 206 | + } |
| 207 | + |
| 208 | + select { |
| 209 | + case <-ctx.Done(): |
| 210 | + return nil, ctx.Err() |
| 211 | + case <-time.After(5 * time.Second): |
| 212 | + } |
| 213 | + } |
| 214 | + return op, nil |
| 215 | +} |
| 216 | + |
| 217 | +// ParseParams parses the parameters for the tool. |
| 218 | +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { |
| 219 | + return parameters.ParseParams(t.AllParams, data, claims) |
| 220 | +} |
| 221 | + |
| 222 | +// Manifest returns the tool's manifest. |
| 223 | +func (t Tool) Manifest() tools.Manifest { |
| 224 | + return t.manifest |
| 225 | +} |
| 226 | + |
| 227 | +// McpManifest returns the tool's MCP manifest. |
| 228 | +func (t Tool) McpManifest() tools.McpManifest { |
| 229 | + return t.mcpManifest |
| 230 | +} |
| 231 | + |
| 232 | +// Authorized checks if the tool is authorized. |
| 233 | +func (t Tool) Authorized(verifiedAuthServices []string) bool { |
| 234 | + return true |
| 235 | +} |
| 236 | + |
| 237 | +func (t Tool) RequiresClientAuthorization() bool { |
| 238 | + return t.Source.UseClientAuthorization() |
| 239 | +} |
| 240 | + |
| 241 | +func (t Tool) GetAuthTokenHeaderName() string { |
| 242 | + return "Authorization" |
| 243 | +} |
0 commit comments