Skip to content

Commit f8b7326

Browse files
committed
Add a toolset registry in the teamloader
Making it possible to re-define builtin tools when we call the team loader Signed-off-by: Djordje Lukic <[email protected]>
1 parent b0af04d commit f8b7326

File tree

1 file changed

+204
-126
lines changed

1 file changed

+204
-126
lines changed

pkg/teamloader/teamloader.go

Lines changed: 204 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,191 @@ import (
2525
"github.com/docker/cagent/pkg/tools/mcp"
2626
)
2727

28+
// ToolsetCreator is a function that creates a toolset based on the provided configuration
29+
type ToolsetCreator func(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error)
30+
31+
// ToolsetRegistry manages the registration of toolset creators by type
32+
type ToolsetRegistry struct {
33+
creators map[string]ToolsetCreator
34+
}
35+
36+
// NewToolsetRegistry creates a new empty toolset registry
37+
func NewToolsetRegistry() *ToolsetRegistry {
38+
return &ToolsetRegistry{
39+
creators: make(map[string]ToolsetCreator),
40+
}
41+
}
42+
43+
// Register adds a new toolset creator for the given type
44+
func (r *ToolsetRegistry) Register(toolsetType string, creator ToolsetCreator) {
45+
r.creators[toolsetType] = creator
46+
}
47+
48+
// Get retrieves a toolset creator for the given type
49+
func (r *ToolsetRegistry) Get(toolsetType string) (ToolsetCreator, bool) {
50+
creator, ok := r.creators[toolsetType]
51+
return creator, ok
52+
}
53+
54+
// CreateTool creates a toolset using the registered creator for the given type
55+
func (r *ToolsetRegistry) CreateTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
56+
creator, ok := r.Get(toolset.Type)
57+
if !ok {
58+
return nil, fmt.Errorf("unknown toolset type: %s", toolset.Type)
59+
}
60+
return creator(ctx, toolset, parentDir, envProvider, runtimeConfig)
61+
}
62+
63+
func NewDefaultToolsetRegistry() *ToolsetRegistry {
64+
r := NewToolsetRegistry()
65+
// Register all built-in toolset creators
66+
r.Register("todo", createTodoTool)
67+
r.Register("memory", createMemoryTool)
68+
r.Register("think", createThinkTool)
69+
r.Register("shell", createShellTool)
70+
r.Register("script", createScriptTool)
71+
r.Register("filesystem", createFilesystemTool)
72+
r.Register("fetch", createFetchTool)
73+
r.Register("mcp", createMCPTool)
74+
return r
75+
}
76+
77+
func createTodoTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
78+
if toolset.Shared {
79+
return builtin.NewSharedTodoTool(), nil
80+
}
81+
return builtin.NewTodoTool(), nil
82+
}
83+
84+
func createMemoryTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
85+
var memoryPath string
86+
if filepath.IsAbs(toolset.Path) {
87+
memoryPath = ""
88+
} else if wd, err := os.Getwd(); err == nil {
89+
memoryPath = wd
90+
} else {
91+
memoryPath = parentDir
92+
}
93+
94+
validatedMemoryPath, err := path.ValidatePathInDirectory(toolset.Path, memoryPath)
95+
if err != nil {
96+
return nil, fmt.Errorf("invalid memory database path: %w", err)
97+
}
98+
if err := os.MkdirAll(filepath.Dir(validatedMemoryPath), 0o700); err != nil {
99+
return nil, fmt.Errorf("failed to create memory database directory: %w", err)
100+
}
101+
102+
db, err := sqlite.NewMemoryDatabase(validatedMemoryPath)
103+
if err != nil {
104+
return nil, fmt.Errorf("failed to create memory database: %w", err)
105+
}
106+
107+
return builtin.NewMemoryTool(db), nil
108+
}
109+
110+
func createThinkTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
111+
return builtin.NewThinkTool(), nil
112+
}
113+
114+
func createShellTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
115+
env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider)
116+
if err != nil {
117+
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
118+
}
119+
env = append(env, os.Environ()...)
120+
return builtin.NewShellTool(env), nil
121+
}
122+
123+
func createScriptTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
124+
if len(toolset.Shell) == 0 {
125+
return nil, fmt.Errorf("shell is required for script toolset")
126+
}
127+
128+
env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider)
129+
if err != nil {
130+
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
131+
}
132+
env = append(env, os.Environ()...)
133+
return builtin.NewScriptShellTool(toolset.Shell, env), nil
134+
}
135+
136+
func createFilesystemTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
137+
wd := runtimeConfig.WorkingDir
138+
if wd == "" {
139+
var err error
140+
wd, err = os.Getwd()
141+
if err != nil {
142+
return nil, fmt.Errorf("failed to get working directory: %w", err)
143+
}
144+
}
145+
146+
var opts []builtin.FileSystemOpt
147+
if len(toolset.PostEdit) > 0 {
148+
postEditConfigs := make([]builtin.PostEditConfig, len(toolset.PostEdit))
149+
for i, pe := range toolset.PostEdit {
150+
postEditConfigs[i] = builtin.PostEditConfig{
151+
Path: pe.Path,
152+
Cmd: pe.Cmd,
153+
}
154+
}
155+
opts = append(opts, builtin.WithPostEditCommands(postEditConfigs))
156+
}
157+
158+
return builtin.NewFilesystemTool([]string{wd}, opts...), nil
159+
}
160+
161+
func createFetchTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
162+
var opts []builtin.FetchToolOption
163+
if toolset.Timeout > 0 {
164+
timeout := time.Duration(toolset.Timeout) * time.Second
165+
opts = append(opts, builtin.WithTimeout(timeout))
166+
}
167+
return builtin.NewFetchTool(opts...), nil
168+
}
169+
170+
func createMCPTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
171+
// MCP tool has three different modes: ref, command, and remote
172+
if toolset.Ref != "" {
173+
mcpServerName := gateway.ParseServerRef(toolset.Ref)
174+
serverSpec, err := gateway.ServerSpec(ctx, mcpServerName)
175+
if err != nil {
176+
return nil, fmt.Errorf("fetching MCP server spec for %q: %w", mcpServerName, err)
177+
}
178+
179+
// TODO(dga): until the MCP Gateway supports oauth with cagent, we fetch the remote url and directly connect to it.
180+
if serverSpec.Type == "remote" {
181+
return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil, runtimeConfig.RedirectURI), nil
182+
}
183+
184+
return mcp.NewGatewayToolset(ctx, mcpServerName, toolset.Config, envProvider)
185+
}
186+
187+
if toolset.Command != "" {
188+
env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider)
189+
if err != nil {
190+
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
191+
}
192+
env = append(env, os.Environ()...)
193+
return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env), nil
194+
}
195+
196+
if toolset.Remote.URL != "" {
197+
headers := map[string]string{}
198+
for k, v := range toolset.Remote.Headers {
199+
expanded, err := environment.Expand(ctx, v, envProvider)
200+
if err != nil {
201+
return nil, fmt.Errorf("failed to expand header '%s': %w", k, err)
202+
}
203+
204+
headers[k] = expanded
205+
}
206+
207+
return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers, runtimeConfig.RedirectURI), nil
208+
}
209+
210+
return nil, fmt.Errorf("mcp toolset requires either ref, command, or remote configuration")
211+
}
212+
28213
// LoadTeams loads all agent teams from the given directory or file path
29214
func LoadTeams(ctx context.Context, agentsPathOrDirectory string, runtimeConfig config.RuntimeConfig) (map[string]*team.Team, error) {
30215
teams := make(map[string]*team.Team)
@@ -95,7 +280,8 @@ func checkRequiredEnvVars(ctx context.Context, cfg *latest.Config, env environme
95280
}
96281

97282
type loadOptions struct {
98-
modelOverrides []string
283+
modelOverrides []string
284+
toolsetRegistry *ToolsetRegistry
99285
}
100286

101287
type Opt func(*loadOptions) error
@@ -107,13 +293,24 @@ func WithModelOverrides(overrides []string) Opt {
107293
}
108294
}
109295

296+
// WithToolsetRegistry allows using a custom toolset registry instead of the default
297+
func WithToolsetRegistry(registry *ToolsetRegistry) Opt {
298+
return func(opts *loadOptions) error {
299+
opts.toolsetRegistry = registry
300+
return nil
301+
}
302+
}
303+
110304
func Load(ctx context.Context, p string, runtimeConfig config.RuntimeConfig, opts ...Opt) (*team.Team, error) {
111-
var loadOptions loadOptions
305+
var loadOpts loadOptions
306+
loadOpts.toolsetRegistry = NewDefaultToolsetRegistry()
307+
112308
for _, o := range opts {
113-
if err := o(&loadOptions); err != nil {
309+
if err := o(&loadOpts); err != nil {
114310
return nil, err
115311
}
116312
}
313+
117314
fileName := filepath.Base(p)
118315
parentDir := filepath.Dir(p)
119316

@@ -141,7 +338,7 @@ func Load(ctx context.Context, p string, runtimeConfig config.RuntimeConfig, opt
141338
}
142339

143340
// Apply model overrides from CLI flags before checking required env vars
144-
if err := config.ApplyModelOverrides(cfg, loadOptions.modelOverrides); err != nil {
341+
if err := config.ApplyModelOverrides(cfg, loadOpts.modelOverrides); err != nil {
145342
return nil, err
146343
}
147344

@@ -174,7 +371,7 @@ func Load(ctx context.Context, p string, runtimeConfig config.RuntimeConfig, opt
174371
opts = append(opts, agent.WithModel(model))
175372
}
176373

177-
agentTools, err := getToolsForAgent(ctx, &agentConfig, parentDir, env, runtimeConfig)
374+
agentTools, err := getToolsForAgent(ctx, &agentConfig, parentDir, env, runtimeConfig, loadOpts.toolsetRegistry)
178375
if err != nil {
179376
return nil, fmt.Errorf("failed to get tools: %w", err)
180377
}
@@ -239,13 +436,13 @@ func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentC
239436
}
240437

241438
// getToolsForAgent returns the tool definitions for an agent based on its configuration
242-
func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) ([]tools.ToolSet, error) {
439+
func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig, registry *ToolsetRegistry) ([]tools.ToolSet, error) {
243440
var t []tools.ToolSet
244441

245442
for i := range a.Toolsets {
246443
toolset := a.Toolsets[i]
247444

248-
tool, err := createTool(ctx, toolset, parentDir, envProvider, runtimeConfig)
445+
tool, err := registry.CreateTool(ctx, toolset, parentDir, envProvider, runtimeConfig)
249446
if err != nil {
250447
return nil, err
251448
}
@@ -267,122 +464,3 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri
267464
codemode.Wrap(t...),
268465
}, nil
269466
}
270-
271-
func createTool(ctx context.Context, toolset latest.Toolset, parentDir string, envProvider environment.Provider, runtimeConfig config.RuntimeConfig) (tools.ToolSet, error) {
272-
env, err := environment.ExpandAll(ctx, environment.ToValues(toolset.Env), envProvider)
273-
if err != nil {
274-
return nil, fmt.Errorf("failed to expand the tool's environment variables: %w", err)
275-
}
276-
env = append(env, os.Environ()...)
277-
278-
switch {
279-
case toolset.Type == "todo":
280-
if toolset.Shared {
281-
return builtin.NewSharedTodoTool(), nil
282-
}
283-
return builtin.NewTodoTool(), nil
284-
285-
case toolset.Type == "memory":
286-
var memoryPath string
287-
if filepath.IsAbs(toolset.Path) {
288-
memoryPath = ""
289-
} else if wd, err := os.Getwd(); err == nil {
290-
memoryPath = wd
291-
} else {
292-
memoryPath = parentDir
293-
}
294-
295-
validatedMemoryPath, err := path.ValidatePathInDirectory(toolset.Path, memoryPath)
296-
if err != nil {
297-
return nil, fmt.Errorf("invalid memory database path: %w", err)
298-
}
299-
if err := os.MkdirAll(filepath.Dir(validatedMemoryPath), 0o700); err != nil {
300-
return nil, fmt.Errorf("failed to create memory database directory: %w", err)
301-
}
302-
303-
db, err := sqlite.NewMemoryDatabase(validatedMemoryPath)
304-
if err != nil {
305-
return nil, fmt.Errorf("failed to create memory database: %w", err)
306-
}
307-
308-
return builtin.NewMemoryTool(db), nil
309-
310-
case toolset.Type == "think":
311-
return builtin.NewThinkTool(), nil
312-
313-
case toolset.Type == "shell":
314-
return builtin.NewShellTool(env), nil
315-
316-
case toolset.Type == "script":
317-
if len(toolset.Shell) == 0 {
318-
return nil, fmt.Errorf("shell is required for script toolset")
319-
}
320-
321-
return builtin.NewScriptShellTool(toolset.Shell, env), nil
322-
323-
case toolset.Type == "filesystem":
324-
wd := runtimeConfig.WorkingDir
325-
if wd == "" {
326-
var err error
327-
wd, err = os.Getwd()
328-
if err != nil {
329-
return nil, fmt.Errorf("failed to get working directory: %w", err)
330-
}
331-
}
332-
333-
var opts []builtin.FileSystemOpt
334-
if len(toolset.PostEdit) > 0 {
335-
postEditConfigs := make([]builtin.PostEditConfig, len(toolset.PostEdit))
336-
for i, pe := range toolset.PostEdit {
337-
postEditConfigs[i] = builtin.PostEditConfig{
338-
Path: pe.Path,
339-
Cmd: pe.Cmd,
340-
}
341-
}
342-
opts = append(opts, builtin.WithPostEditCommands(postEditConfigs))
343-
}
344-
345-
return builtin.NewFilesystemTool([]string{wd}, opts...), nil
346-
347-
case toolset.Type == "fetch":
348-
var opts []builtin.FetchToolOption
349-
if toolset.Timeout > 0 {
350-
timeout := time.Duration(toolset.Timeout) * time.Second
351-
opts = append(opts, builtin.WithTimeout(timeout))
352-
}
353-
return builtin.NewFetchTool(opts...), nil
354-
355-
case toolset.Type == "mcp" && toolset.Ref != "":
356-
mcpServerName := gateway.ParseServerRef(toolset.Ref)
357-
serverSpec, err := gateway.ServerSpec(ctx, mcpServerName)
358-
if err != nil {
359-
return nil, fmt.Errorf("fetching MCP server spec for %q: %w", mcpServerName, err)
360-
}
361-
362-
// TODO(dga): until the MCP Gateway supports oauth with cagent, we fetch the remote url and directly connect to it.
363-
if serverSpec.Type == "remote" {
364-
return mcp.NewRemoteToolset(serverSpec.Remote.URL, serverSpec.Remote.TransportType, nil, runtimeConfig.RedirectURI), nil
365-
}
366-
367-
return mcp.NewGatewayToolset(ctx, mcpServerName, toolset.Config, envProvider)
368-
369-
case toolset.Type == "mcp" && toolset.Command != "":
370-
return mcp.NewToolsetCommand(toolset.Command, toolset.Args, env), nil
371-
372-
case toolset.Type == "mcp" && toolset.Remote.URL != "":
373-
headers := map[string]string{}
374-
for k, v := range toolset.Remote.Headers {
375-
expanded, err := environment.Expand(ctx, v, envProvider)
376-
if err != nil {
377-
return nil, fmt.Errorf("failed to expand header '%s': %w", k, err)
378-
}
379-
380-
headers[k] = expanded
381-
}
382-
383-
return mcp.NewRemoteToolset(toolset.Remote.URL, toolset.Remote.TransportType, headers, runtimeConfig.RedirectURI), nil
384-
385-
default:
386-
return nil, fmt.Errorf("unknown toolset type: %s", toolset.Type)
387-
}
388-
}

0 commit comments

Comments
 (0)