@@ -25,6 +25,191 @@ import (
2525 "github.com/docker/cagent/pkg/tools/mcp"
2626)
2727
28+ // ToolsetCreator is a function that creates a toolset based on the provided configuration
29+ type ToolsetCreator func (ctx context.Context , toolset latest.Toolset , parentDir string , envProvider environment.Provider , runtimeConfig config.RuntimeConfig ) (tools.ToolSet , error )
30+
31+ // ToolsetRegistry manages the registration of toolset creators by type
32+ type ToolsetRegistry struct {
33+ creators map [string ]ToolsetCreator
34+ }
35+
36+ // NewToolsetRegistry creates a new empty toolset registry
37+ func NewToolsetRegistry () * ToolsetRegistry {
38+ return & ToolsetRegistry {
39+ creators : make (map [string ]ToolsetCreator ),
40+ }
41+ }
42+
43+ // Register adds a new toolset creator for the given type
44+ func (r * ToolsetRegistry ) Register (toolsetType string , creator ToolsetCreator ) {
45+ r .creators [toolsetType ] = creator
46+ }
47+
48+ // Get retrieves a toolset creator for the given type
49+ func (r * ToolsetRegistry ) Get (toolsetType string ) (ToolsetCreator , bool ) {
50+ creator , ok := r .creators [toolsetType ]
51+ return creator , ok
52+ }
53+
54+ // CreateTool creates a toolset using the registered creator for the given type
55+ func (r * ToolsetRegistry ) CreateTool (ctx context.Context , toolset latest.Toolset , parentDir string , envProvider environment.Provider , runtimeConfig config.RuntimeConfig ) (tools.ToolSet , error ) {
56+ creator , ok := r .Get (toolset .Type )
57+ if ! ok {
58+ return nil , fmt .Errorf ("unknown toolset type: %s" , toolset .Type )
59+ }
60+ return creator (ctx , toolset , parentDir , envProvider , runtimeConfig )
61+ }
62+
63+ func NewDefaultToolsetRegistry () * ToolsetRegistry {
64+ r := NewToolsetRegistry ()
65+ // Register all built-in toolset creators
66+ r .Register ("todo" , createTodoTool )
67+ r .Register ("memory" , createMemoryTool )
68+ r .Register ("think" , createThinkTool )
69+ r .Register ("shell" , createShellTool )
70+ r .Register ("script" , createScriptTool )
71+ r .Register ("filesystem" , createFilesystemTool )
72+ r .Register ("fetch" , createFetchTool )
73+ r .Register ("mcp" , createMCPTool )
74+ return r
75+ }
76+
77+ func createTodoTool (ctx context.Context , toolset latest.Toolset , parentDir string , envProvider environment.Provider , runtimeConfig config.RuntimeConfig ) (tools.ToolSet , error ) {
78+ if toolset .Shared {
79+ return builtin .NewSharedTodoTool (), nil
80+ }
81+ return builtin .NewTodoTool (), nil
82+ }
83+
84+ func createMemoryTool (ctx context.Context , toolset latest.Toolset , parentDir string , envProvider environment.Provider , runtimeConfig config.RuntimeConfig ) (tools.ToolSet , error ) {
85+ var memoryPath string
86+ if filepath .IsAbs (toolset .Path ) {
87+ memoryPath = ""
88+ } else if wd , err := os .Getwd (); err == nil {
89+ memoryPath = wd
90+ } else {
91+ memoryPath = parentDir
92+ }
93+
94+ validatedMemoryPath , err := path .ValidatePathInDirectory (toolset .Path , memoryPath )
95+ if err != nil {
96+ return nil , fmt .Errorf ("invalid memory database path: %w" , err )
97+ }
98+ if err := os .MkdirAll (filepath .Dir (validatedMemoryPath ), 0o700 ); err != nil {
99+ return nil , fmt .Errorf ("failed to create memory database directory: %w" , err )
100+ }
101+
102+ db , err := sqlite .NewMemoryDatabase (validatedMemoryPath )
103+ if err != nil {
104+ return nil , fmt .Errorf ("failed to create memory database: %w" , err )
105+ }
106+
107+ return builtin .NewMemoryTool (db ), nil
108+ }
109+
110+ func createThinkTool (ctx context.Context , toolset latest.Toolset , parentDir string , envProvider environment.Provider , runtimeConfig config.RuntimeConfig ) (tools.ToolSet , error ) {
111+ return builtin .NewThinkTool (), nil
112+ }
113+
114+ func createShellTool (ctx context.Context , toolset latest.Toolset , parentDir string , envProvider environment.Provider , runtimeConfig config.RuntimeConfig ) (tools.ToolSet , error ) {
115+ env , err := environment .ExpandAll (ctx , environment .ToValues (toolset .Env ), envProvider )
116+ if err != nil {
117+ return nil , fmt .Errorf ("failed to expand the tool's environment variables: %w" , err )
118+ }
119+ env = append (env , os .Environ ()... )
120+ return builtin .NewShellTool (env ), nil
121+ }
122+
123+ func createScriptTool (ctx context.Context , toolset latest.Toolset , parentDir string , envProvider environment.Provider , runtimeConfig config.RuntimeConfig ) (tools.ToolSet , error ) {
124+ if len (toolset .Shell ) == 0 {
125+ return nil , fmt .Errorf ("shell is required for script toolset" )
126+ }
127+
128+ env , err := environment .ExpandAll (ctx , environment .ToValues (toolset .Env ), envProvider )
129+ if err != nil {
130+ return nil , fmt .Errorf ("failed to expand the tool's environment variables: %w" , err )
131+ }
132+ env = append (env , os .Environ ()... )
133+ return builtin .NewScriptShellTool (toolset .Shell , env ), nil
134+ }
135+
136+ func createFilesystemTool (ctx context.Context , toolset latest.Toolset , parentDir string , envProvider environment.Provider , runtimeConfig config.RuntimeConfig ) (tools.ToolSet , error ) {
137+ wd := runtimeConfig .WorkingDir
138+ if wd == "" {
139+ var err error
140+ wd , err = os .Getwd ()
141+ if err != nil {
142+ return nil , fmt .Errorf ("failed to get working directory: %w" , err )
143+ }
144+ }
145+
146+ var opts []builtin.FileSystemOpt
147+ if len (toolset .PostEdit ) > 0 {
148+ postEditConfigs := make ([]builtin.PostEditConfig , len (toolset .PostEdit ))
149+ for i , pe := range toolset .PostEdit {
150+ postEditConfigs [i ] = builtin.PostEditConfig {
151+ Path : pe .Path ,
152+ Cmd : pe .Cmd ,
153+ }
154+ }
155+ opts = append (opts , builtin .WithPostEditCommands (postEditConfigs ))
156+ }
157+
158+ return builtin .NewFilesystemTool ([]string {wd }, opts ... ), nil
159+ }
160+
161+ func createFetchTool (ctx context.Context , toolset latest.Toolset , parentDir string , envProvider environment.Provider , runtimeConfig config.RuntimeConfig ) (tools.ToolSet , error ) {
162+ var opts []builtin.FetchToolOption
163+ if toolset .Timeout > 0 {
164+ timeout := time .Duration (toolset .Timeout ) * time .Second
165+ opts = append (opts , builtin .WithTimeout (timeout ))
166+ }
167+ return builtin .NewFetchTool (opts ... ), nil
168+ }
169+
170+ func createMCPTool (ctx context.Context , toolset latest.Toolset , parentDir string , envProvider environment.Provider , runtimeConfig config.RuntimeConfig ) (tools.ToolSet , error ) {
171+ // MCP tool has three different modes: ref, command, and remote
172+ if toolset .Ref != "" {
173+ mcpServerName := gateway .ParseServerRef (toolset .Ref )
174+ serverSpec , err := gateway .ServerSpec (ctx , mcpServerName )
175+ if err != nil {
176+ return nil , fmt .Errorf ("fetching MCP server spec for %q: %w" , mcpServerName , err )
177+ }
178+
179+ // TODO(dga): until the MCP Gateway supports oauth with cagent, we fetch the remote url and directly connect to it.
180+ if serverSpec .Type == "remote" {
181+ return mcp .NewRemoteToolset (serverSpec .Remote .URL , serverSpec .Remote .TransportType , nil , runtimeConfig .RedirectURI ), nil
182+ }
183+
184+ return mcp .NewGatewayToolset (ctx , mcpServerName , toolset .Config , envProvider )
185+ }
186+
187+ if toolset .Command != "" {
188+ env , err := environment .ExpandAll (ctx , environment .ToValues (toolset .Env ), envProvider )
189+ if err != nil {
190+ return nil , fmt .Errorf ("failed to expand the tool's environment variables: %w" , err )
191+ }
192+ env = append (env , os .Environ ()... )
193+ return mcp .NewToolsetCommand (toolset .Command , toolset .Args , env ), nil
194+ }
195+
196+ if toolset .Remote .URL != "" {
197+ headers := map [string ]string {}
198+ for k , v := range toolset .Remote .Headers {
199+ expanded , err := environment .Expand (ctx , v , envProvider )
200+ if err != nil {
201+ return nil , fmt .Errorf ("failed to expand header '%s': %w" , k , err )
202+ }
203+
204+ headers [k ] = expanded
205+ }
206+
207+ return mcp .NewRemoteToolset (toolset .Remote .URL , toolset .Remote .TransportType , headers , runtimeConfig .RedirectURI ), nil
208+ }
209+
210+ return nil , fmt .Errorf ("mcp toolset requires either ref, command, or remote configuration" )
211+ }
212+
28213// LoadTeams loads all agent teams from the given directory or file path
29214func LoadTeams (ctx context.Context , agentsPathOrDirectory string , runtimeConfig config.RuntimeConfig ) (map [string ]* team.Team , error ) {
30215 teams := make (map [string ]* team.Team )
@@ -95,7 +280,8 @@ func checkRequiredEnvVars(ctx context.Context, cfg *latest.Config, env environme
95280}
96281
97282type loadOptions struct {
98- modelOverrides []string
283+ modelOverrides []string
284+ toolsetRegistry * ToolsetRegistry
99285}
100286
101287type Opt func (* loadOptions ) error
@@ -107,13 +293,24 @@ func WithModelOverrides(overrides []string) Opt {
107293 }
108294}
109295
296+ // WithToolsetRegistry allows using a custom toolset registry instead of the default
297+ func WithToolsetRegistry (registry * ToolsetRegistry ) Opt {
298+ return func (opts * loadOptions ) error {
299+ opts .toolsetRegistry = registry
300+ return nil
301+ }
302+ }
303+
110304func Load (ctx context.Context , p string , runtimeConfig config.RuntimeConfig , opts ... Opt ) (* team.Team , error ) {
111- var loadOptions loadOptions
305+ var loadOpts loadOptions
306+ loadOpts .toolsetRegistry = NewDefaultToolsetRegistry ()
307+
112308 for _ , o := range opts {
113- if err := o (& loadOptions ); err != nil {
309+ if err := o (& loadOpts ); err != nil {
114310 return nil , err
115311 }
116312 }
313+
117314 fileName := filepath .Base (p )
118315 parentDir := filepath .Dir (p )
119316
@@ -141,7 +338,7 @@ func Load(ctx context.Context, p string, runtimeConfig config.RuntimeConfig, opt
141338 }
142339
143340 // Apply model overrides from CLI flags before checking required env vars
144- if err := config .ApplyModelOverrides (cfg , loadOptions .modelOverrides ); err != nil {
341+ if err := config .ApplyModelOverrides (cfg , loadOpts .modelOverrides ); err != nil {
145342 return nil , err
146343 }
147344
@@ -174,7 +371,7 @@ func Load(ctx context.Context, p string, runtimeConfig config.RuntimeConfig, opt
174371 opts = append (opts , agent .WithModel (model ))
175372 }
176373
177- agentTools , err := getToolsForAgent (ctx , & agentConfig , parentDir , env , runtimeConfig )
374+ agentTools , err := getToolsForAgent (ctx , & agentConfig , parentDir , env , runtimeConfig , loadOpts . toolsetRegistry )
178375 if err != nil {
179376 return nil , fmt .Errorf ("failed to get tools: %w" , err )
180377 }
@@ -239,13 +436,13 @@ func getModelsForAgent(ctx context.Context, cfg *latest.Config, a *latest.AgentC
239436}
240437
241438// getToolsForAgent returns the tool definitions for an agent based on its configuration
242- func getToolsForAgent (ctx context.Context , a * latest.AgentConfig , parentDir string , envProvider environment.Provider , runtimeConfig config.RuntimeConfig ) ([]tools.ToolSet , error ) {
439+ func getToolsForAgent (ctx context.Context , a * latest.AgentConfig , parentDir string , envProvider environment.Provider , runtimeConfig config.RuntimeConfig , registry * ToolsetRegistry ) ([]tools.ToolSet , error ) {
243440 var t []tools.ToolSet
244441
245442 for i := range a .Toolsets {
246443 toolset := a .Toolsets [i ]
247444
248- tool , err := createTool (ctx , toolset , parentDir , envProvider , runtimeConfig )
445+ tool , err := registry . CreateTool (ctx , toolset , parentDir , envProvider , runtimeConfig )
249446 if err != nil {
250447 return nil , err
251448 }
@@ -267,122 +464,3 @@ func getToolsForAgent(ctx context.Context, a *latest.AgentConfig, parentDir stri
267464 codemode .Wrap (t ... ),
268465 }, nil
269466}
270-
271- func createTool (ctx context.Context , toolset latest.Toolset , parentDir string , envProvider environment.Provider , runtimeConfig config.RuntimeConfig ) (tools.ToolSet , error ) {
272- env , err := environment .ExpandAll (ctx , environment .ToValues (toolset .Env ), envProvider )
273- if err != nil {
274- return nil , fmt .Errorf ("failed to expand the tool's environment variables: %w" , err )
275- }
276- env = append (env , os .Environ ()... )
277-
278- switch {
279- case toolset .Type == "todo" :
280- if toolset .Shared {
281- return builtin .NewSharedTodoTool (), nil
282- }
283- return builtin .NewTodoTool (), nil
284-
285- case toolset .Type == "memory" :
286- var memoryPath string
287- if filepath .IsAbs (toolset .Path ) {
288- memoryPath = ""
289- } else if wd , err := os .Getwd (); err == nil {
290- memoryPath = wd
291- } else {
292- memoryPath = parentDir
293- }
294-
295- validatedMemoryPath , err := path .ValidatePathInDirectory (toolset .Path , memoryPath )
296- if err != nil {
297- return nil , fmt .Errorf ("invalid memory database path: %w" , err )
298- }
299- if err := os .MkdirAll (filepath .Dir (validatedMemoryPath ), 0o700 ); err != nil {
300- return nil , fmt .Errorf ("failed to create memory database directory: %w" , err )
301- }
302-
303- db , err := sqlite .NewMemoryDatabase (validatedMemoryPath )
304- if err != nil {
305- return nil , fmt .Errorf ("failed to create memory database: %w" , err )
306- }
307-
308- return builtin .NewMemoryTool (db ), nil
309-
310- case toolset .Type == "think" :
311- return builtin .NewThinkTool (), nil
312-
313- case toolset .Type == "shell" :
314- return builtin .NewShellTool (env ), nil
315-
316- case toolset .Type == "script" :
317- if len (toolset .Shell ) == 0 {
318- return nil , fmt .Errorf ("shell is required for script toolset" )
319- }
320-
321- return builtin .NewScriptShellTool (toolset .Shell , env ), nil
322-
323- case toolset .Type == "filesystem" :
324- wd := runtimeConfig .WorkingDir
325- if wd == "" {
326- var err error
327- wd , err = os .Getwd ()
328- if err != nil {
329- return nil , fmt .Errorf ("failed to get working directory: %w" , err )
330- }
331- }
332-
333- var opts []builtin.FileSystemOpt
334- if len (toolset .PostEdit ) > 0 {
335- postEditConfigs := make ([]builtin.PostEditConfig , len (toolset .PostEdit ))
336- for i , pe := range toolset .PostEdit {
337- postEditConfigs [i ] = builtin.PostEditConfig {
338- Path : pe .Path ,
339- Cmd : pe .Cmd ,
340- }
341- }
342- opts = append (opts , builtin .WithPostEditCommands (postEditConfigs ))
343- }
344-
345- return builtin .NewFilesystemTool ([]string {wd }, opts ... ), nil
346-
347- case toolset .Type == "fetch" :
348- var opts []builtin.FetchToolOption
349- if toolset .Timeout > 0 {
350- timeout := time .Duration (toolset .Timeout ) * time .Second
351- opts = append (opts , builtin .WithTimeout (timeout ))
352- }
353- return builtin .NewFetchTool (opts ... ), nil
354-
355- case toolset .Type == "mcp" && toolset .Ref != "" :
356- mcpServerName := gateway .ParseServerRef (toolset .Ref )
357- serverSpec , err := gateway .ServerSpec (ctx , mcpServerName )
358- if err != nil {
359- return nil , fmt .Errorf ("fetching MCP server spec for %q: %w" , mcpServerName , err )
360- }
361-
362- // TODO(dga): until the MCP Gateway supports oauth with cagent, we fetch the remote url and directly connect to it.
363- if serverSpec .Type == "remote" {
364- return mcp .NewRemoteToolset (serverSpec .Remote .URL , serverSpec .Remote .TransportType , nil , runtimeConfig .RedirectURI ), nil
365- }
366-
367- return mcp .NewGatewayToolset (ctx , mcpServerName , toolset .Config , envProvider )
368-
369- case toolset .Type == "mcp" && toolset .Command != "" :
370- return mcp .NewToolsetCommand (toolset .Command , toolset .Args , env ), nil
371-
372- case toolset .Type == "mcp" && toolset .Remote .URL != "" :
373- headers := map [string ]string {}
374- for k , v := range toolset .Remote .Headers {
375- expanded , err := environment .Expand (ctx , v , envProvider )
376- if err != nil {
377- return nil , fmt .Errorf ("failed to expand header '%s': %w" , k , err )
378- }
379-
380- headers [k ] = expanded
381- }
382-
383- return mcp .NewRemoteToolset (toolset .Remote .URL , toolset .Remote .TransportType , headers , runtimeConfig .RedirectURI ), nil
384-
385- default :
386- return nil , fmt .Errorf ("unknown toolset type: %s" , toolset .Type )
387- }
388- }
0 commit comments