Skip to content

Commit b925536

Browse files
dborowitzrahulpinto19
authored andcommitted
refactor(serverless-spark): extract common create batch tool & config
We are planning to add several very similar tools for creating batches like the existing pyspark batches: spark (Java), R, etc. They will use an identical approach of specifying environment and runtime config in the YAML, differing only in how language-specific args are passed. We can streamline this.
1 parent 06ffec2 commit b925536

File tree

5 files changed

+423
-293
lines changed

5 files changed

+423
-293
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package createbatch
16+
17+
import (
18+
"context"
19+
"encoding/json"
20+
"fmt"
21+
22+
dataproc "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
23+
"github.com/goccy/go-yaml"
24+
"google.golang.org/protobuf/encoding/protojson"
25+
"google.golang.org/protobuf/proto"
26+
)
27+
28+
// unmarshalProto is a helper function to unmarshal a generic interface{} into a proto.Message.
29+
func unmarshalProto(data any, m proto.Message) error {
30+
jsonData, err := json.Marshal(data)
31+
if err != nil {
32+
return fmt.Errorf("failed to marshal to JSON: %w", err)
33+
}
34+
return protojson.Unmarshal(jsonData, m)
35+
}
36+
37+
// Config is a common config that can be used with any type of create batch tool. However, each tool
38+
// will still need its own config type, embedding this Config, so it can provide a type-specific
39+
// Initialize implementation.
40+
type Config struct {
41+
Name string `yaml:"name" validate:"required"`
42+
Kind string `yaml:"kind" validate:"required"`
43+
Source string `yaml:"source" validate:"required"`
44+
Description string `yaml:"description"`
45+
RuntimeConfig *dataproc.RuntimeConfig `yaml:"runtimeConfig"`
46+
EnvironmentConfig *dataproc.EnvironmentConfig `yaml:"environmentConfig"`
47+
AuthRequired []string `yaml:"authRequired"`
48+
}
49+
50+
func NewConfig(ctx context.Context, name string, decoder *yaml.Decoder) (Config, error) {
51+
// Use a temporary struct to decode the YAML, so that we can handle the proto
52+
// conversion for RuntimeConfig and EnvironmentConfig.
53+
var ymlCfg struct {
54+
Name string `yaml:"name"`
55+
Kind string `yaml:"kind"`
56+
Source string `yaml:"source"`
57+
Description string `yaml:"description"`
58+
RuntimeConfig any `yaml:"runtimeConfig"`
59+
EnvironmentConfig any `yaml:"environmentConfig"`
60+
AuthRequired []string `yaml:"authRequired"`
61+
}
62+
63+
if err := decoder.DecodeContext(ctx, &ymlCfg); err != nil {
64+
return Config{}, err
65+
}
66+
67+
cfg := Config{
68+
Name: name,
69+
Kind: ymlCfg.Kind,
70+
Source: ymlCfg.Source,
71+
Description: ymlCfg.Description,
72+
AuthRequired: ymlCfg.AuthRequired,
73+
}
74+
75+
if ymlCfg.RuntimeConfig != nil {
76+
rc := &dataproc.RuntimeConfig{}
77+
if err := unmarshalProto(ymlCfg.RuntimeConfig, rc); err != nil {
78+
return Config{}, fmt.Errorf("failed to unmarshal runtimeConfig: %w", err)
79+
}
80+
cfg.RuntimeConfig = rc
81+
}
82+
83+
if ymlCfg.EnvironmentConfig != nil {
84+
ec := &dataproc.EnvironmentConfig{}
85+
if err := unmarshalProto(ymlCfg.EnvironmentConfig, ec); err != nil {
86+
return Config{}, fmt.Errorf("failed to unmarshal environmentConfig: %w", err)
87+
}
88+
cfg.EnvironmentConfig = ec
89+
}
90+
91+
return cfg, nil
92+
}
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package createbatch
16+
17+
import (
18+
"context"
19+
"encoding/json"
20+
"fmt"
21+
22+
dataproc "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
23+
"github.com/googleapis/genai-toolbox/internal/sources"
24+
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
25+
"github.com/googleapis/genai-toolbox/internal/tools"
26+
"github.com/googleapis/genai-toolbox/internal/util/parameters"
27+
"google.golang.org/protobuf/encoding/protojson"
28+
"google.golang.org/protobuf/proto"
29+
)
30+
31+
type BatchBuilder interface {
32+
Parameters() parameters.Parameters
33+
BuildBatch(params parameters.ParamValues) (*dataproc.Batch, error)
34+
}
35+
36+
func NewTool(cfg Config, originalCfg tools.ToolConfig, srcs map[string]sources.Source, builder BatchBuilder) (*Tool, error) {
37+
rawS, ok := srcs[cfg.Source]
38+
if !ok {
39+
return nil, fmt.Errorf("source %q not found", cfg.Source)
40+
}
41+
42+
ds, ok := rawS.(*serverlessspark.Source)
43+
if !ok {
44+
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", cfg.Kind, serverlessspark.SourceKind)
45+
}
46+
47+
desc := cfg.Description
48+
if desc == "" {
49+
desc = fmt.Sprintf("Creates a Serverless Spark (aka Dataproc Serverless) %s operation.", cfg.Kind)
50+
}
51+
52+
allParameters := builder.Parameters()
53+
inputSchema, _ := allParameters.McpManifest()
54+
55+
mcpManifest := tools.McpManifest{
56+
Name: cfg.Name,
57+
Description: desc,
58+
InputSchema: inputSchema,
59+
}
60+
61+
return &Tool{
62+
Config: cfg,
63+
originalConfig: originalCfg,
64+
Source: ds,
65+
Builder: builder,
66+
manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()},
67+
mcpManifest: mcpManifest,
68+
Parameters: allParameters,
69+
}, nil
70+
}
71+
72+
type Tool struct {
73+
Config
74+
originalConfig tools.ToolConfig
75+
76+
Source *serverlessspark.Source
77+
Builder BatchBuilder
78+
79+
manifest tools.Manifest
80+
mcpManifest tools.McpManifest
81+
Parameters parameters.Parameters
82+
}
83+
84+
func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
85+
client := t.Source.GetBatchControllerClient()
86+
87+
batch, err := t.Builder.BuildBatch(params)
88+
if err != nil {
89+
return nil, fmt.Errorf("failed to build batch: %w", err)
90+
}
91+
92+
if t.Config.RuntimeConfig != nil {
93+
batch.RuntimeConfig = proto.Clone(t.Config.RuntimeConfig).(*dataproc.RuntimeConfig)
94+
}
95+
96+
if t.Config.EnvironmentConfig != nil {
97+
batch.EnvironmentConfig = proto.Clone(t.Config.EnvironmentConfig).(*dataproc.EnvironmentConfig)
98+
}
99+
100+
// Common override for version if present in params
101+
paramMap := params.AsMap()
102+
if version, ok := paramMap["version"].(string); ok && version != "" {
103+
if batch.RuntimeConfig == nil {
104+
batch.RuntimeConfig = &dataproc.RuntimeConfig{}
105+
}
106+
batch.RuntimeConfig.Version = version
107+
}
108+
109+
req := &dataproc.CreateBatchRequest{
110+
Parent: fmt.Sprintf("projects/%s/locations/%s", t.Source.Project, t.Source.Location),
111+
Batch: batch,
112+
}
113+
114+
op, err := client.CreateBatch(ctx, req)
115+
if err != nil {
116+
return nil, fmt.Errorf("failed to create batch: %w", err)
117+
}
118+
119+
meta, err := op.Metadata()
120+
if err != nil {
121+
return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
122+
}
123+
124+
jsonBytes, err := protojson.Marshal(meta)
125+
if err != nil {
126+
return nil, fmt.Errorf("failed to marshal create batch op metadata to JSON: %w", err)
127+
}
128+
129+
var result map[string]any
130+
if err := json.Unmarshal(jsonBytes, &result); err != nil {
131+
return nil, fmt.Errorf("failed to unmarshal create batch op metadata JSON: %w", err)
132+
}
133+
134+
return result, nil
135+
}
136+
137+
func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
138+
return parameters.ParseParams(t.Parameters, data, claims)
139+
}
140+
141+
func (t *Tool) Manifest() tools.Manifest {
142+
return t.manifest
143+
}
144+
145+
func (t *Tool) McpManifest() tools.McpManifest {
146+
return t.mcpManifest
147+
}
148+
149+
func (t *Tool) Authorized(services []string) bool {
150+
return tools.IsAuthorized(t.AuthRequired, services)
151+
}
152+
153+
func (t *Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
154+
return false
155+
}
156+
157+
func (t *Tool) ToConfig() tools.ToolConfig {
158+
return t.originalConfig
159+
}
160+
161+
func (t *Tool) GetAuthTokenHeaderName() string {
162+
return "Authorization"
163+
}

0 commit comments

Comments
 (0)