Skip to content

Commit e08564f

Browse files
authored
Merge pull request #353 from dgageot/add-output-schema
Add output schema
2 parents 179b092 + cb2be57 commit e08564f

File tree

15 files changed

+342
-9
lines changed

15 files changed

+342
-9
lines changed

pkg/memory/database/database.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ import (
88
var ErrEmptyID = errors.New("memory ID cannot be empty")
99

1010
type UserMemory struct {
11-
ID string
12-
CreatedAt string
13-
Memory string
11+
ID string `description:"The ID of the memory"`
12+
CreatedAt string `description:"The creation timestamp of the memory"`
13+
Memory string `description:"The content of the memory"`
1414
}
1515

1616
type Memory interface {

pkg/tools/builtin/filesystem.go

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"os"
1010
"os/exec"
1111
"path/filepath"
12+
"reflect"
1213
"regexp"
1314
"slices"
1415
"strings"
@@ -99,6 +100,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
99100
},
100101
Required: []string{"path"},
101102
},
103+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
102104
},
103105
Handler: t.handleCreateDirectory,
104106
},
@@ -124,6 +126,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
124126
},
125127
Required: []string{"path"},
126128
},
129+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[*TreeNode]()),
127130
},
128131
Handler: t.handleDirectoryTree,
129132
},
@@ -162,6 +165,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
162165
},
163166
Required: []string{"path", "edits"},
164167
},
168+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
165169
},
166170
Handler: t.handleEditFile,
167171
},
@@ -183,6 +187,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
183187
},
184188
Required: []string{"path"},
185189
},
190+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[FileInfo]()),
186191
},
187192
Handler: t.handleGetFileInfo,
188193
},
@@ -194,6 +199,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
194199
ReadOnlyHint: &[]bool{true}[0],
195200
Title: "List Allowed Directories",
196201
},
202+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[[]string]()),
197203
},
198204
Handler: t.handleListAllowedDirectories,
199205
},
@@ -222,6 +228,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
222228
},
223229
Required: []string{"path", "reason"},
224230
},
231+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
225232
},
226233
Handler: t.handleAddAllowedDirectory,
227234
},
@@ -243,6 +250,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
243250
},
244251
Required: []string{"path"},
245252
},
253+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
246254
},
247255
Handler: t.handleListDirectory,
248256
},
@@ -264,6 +272,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
264272
},
265273
Required: []string{"path"},
266274
},
275+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
267276
},
268277
Handler: t.handleListDirectoryWithSizes,
269278
},
@@ -288,6 +297,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
288297
},
289298
Required: []string{"source", "destination"},
290299
},
300+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
291301
},
292302
Handler: t.handleMoveFile,
293303
},
@@ -309,6 +319,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
309319
},
310320
Required: []string{"path"},
311321
},
322+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
312323
},
313324
Handler: t.handleReadFile,
314325
},
@@ -337,6 +348,8 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
337348
},
338349
Required: []string{"paths"},
339350
},
351+
// TODO(dga): depends on the json param
352+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
340353
},
341354
Handler: t.handleReadMultipleFiles,
342355
},
@@ -369,6 +382,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
369382
},
370383
Required: []string{"path", "pattern"},
371384
},
385+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
372386
},
373387
Handler: t.handleSearchFiles,
374388
},
@@ -405,6 +419,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
405419
},
406420
Required: []string{"path", "query"},
407421
},
422+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
408423
},
409424
Handler: t.handleSearchFilesContent,
410425
},
@@ -429,6 +444,7 @@ func (t *FilesystemTool) Tools(context.Context) ([]tools.Tool, error) {
429444
},
430445
Required: []string{"path", "content"},
431446
},
447+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
432448
},
433449
Handler: t.handleWriteFile,
434450
},
@@ -640,6 +656,14 @@ func (t *FilesystemTool) handleEditFile(ctx context.Context, toolCall tools.Tool
640656
return &tools.ToolCallResult{Output: fmt.Sprintf("File edited successfully. Changes:\n%s", strings.Join(changes, "\n"))}, nil
641657
}
642658

659+
type FileInfo struct {
660+
Name string `json:"name"`
661+
Size int64 `json:"size"`
662+
Mode string `json:"mode"`
663+
ModTime string `json:"modTime"`
664+
IsDir bool `json:"isDir"`
665+
}
666+
643667
func (t *FilesystemTool) handleGetFileInfo(_ context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
644668
var args struct {
645669
Path string `json:"path"`
@@ -657,12 +681,12 @@ func (t *FilesystemTool) handleGetFileInfo(_ context.Context, toolCall tools.Too
657681
return &tools.ToolCallResult{Output: fmt.Sprintf("Error getting file info: %s", err)}, nil
658682
}
659683

660-
fileInfo := map[string]any{
661-
"name": info.Name(),
662-
"size": info.Size(),
663-
"mode": info.Mode().String(),
664-
"modTime": info.ModTime().Format(time.RFC3339),
665-
"isDir": info.IsDir(),
684+
fileInfo := FileInfo{
685+
Name: info.Name(),
686+
Size: info.Size(),
687+
Mode: info.Mode().String(),
688+
ModTime: info.ModTime().Format(time.RFC3339),
689+
IsDir: info.IsDir(),
666690
}
667691

668692
result, err := json.MarshalIndent(fileInfo, "", " ")

pkg/tools/builtin/filesystem_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,3 +924,15 @@ func TestMatchExcludePattern(t *testing.T) {
924924
})
925925
}
926926
}
927+
928+
func TestFilesystemTool_OutputSchema(t *testing.T) {
929+
tool := NewFilesystemTool(nil)
930+
931+
allTools, err := tool.Tools(t.Context())
932+
require.NoError(t, err)
933+
require.NotEmpty(t, allTools)
934+
935+
for _, tool := range allTools {
936+
assert.NotEmpty(t, tool.Function.OutputSchema.Type)
937+
}
938+
}

pkg/tools/builtin/memory.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"reflect"
78
"time"
89

910
"github.com/docker/cagent/pkg/memory/database"
@@ -53,6 +54,7 @@ func (t *MemoryTool) Tools(context.Context) ([]tools.Tool, error) {
5354
},
5455
Required: []string{"memory"},
5556
},
57+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
5658
},
5759
Handler: t.handleAddMemory,
5860
},
@@ -64,6 +66,7 @@ func (t *MemoryTool) Tools(context.Context) ([]tools.Tool, error) {
6466
ReadOnlyHint: &[]bool{true}[0],
6567
Title: "Get Memories",
6668
},
69+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[[]database.UserMemory]()),
6770
},
6871
Handler: t.handleGetMemories,
6972
},
@@ -84,6 +87,7 @@ func (t *MemoryTool) Tools(context.Context) ([]tools.Tool, error) {
8487
},
8588
Required: []string{"id"},
8689
},
90+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
8791
},
8892
Handler: t.handleDeleteMemory,
8993
},

pkg/tools/builtin/memory_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,15 @@ func TestMemoryTool_StartStop(t *testing.T) {
227227
err = tool.Stop()
228228
require.NoError(t, err)
229229
}
230+
231+
func TestMemoryTool_OutputSchema(t *testing.T) {
232+
tool := NewMemoryTool(nil)
233+
234+
allTools, err := tool.Tools(t.Context())
235+
require.NoError(t, err)
236+
require.NotEmpty(t, allTools)
237+
238+
for _, tool := range allTools {
239+
assert.NotEmpty(t, tool.Function.OutputSchema.Type)
240+
}
241+
}

pkg/tools/builtin/shell.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"os"
88
"os/exec"
9+
"reflect"
910
"runtime"
1011

1112
"github.com/docker/cagent/pkg/tools"
@@ -181,6 +182,7 @@ func (t *ShellTool) Tools(context.Context) ([]tools.Tool, error) {
181182
},
182183
Required: []string{"cmd", "cwd"},
183184
},
185+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
184186
},
185187
Handler: t.handler.CallTool,
186188
},

pkg/tools/builtin/shell_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,15 @@ func TestShellTool_StartStop(t *testing.T) {
216216
err = tool.Stop()
217217
require.NoError(t, err)
218218
}
219+
220+
func TestShellTool_OutputSchema(t *testing.T) {
221+
tool := NewShellTool(nil)
222+
223+
allTools, err := tool.Tools(t.Context())
224+
require.NoError(t, err)
225+
require.NotEmpty(t, allTools)
226+
227+
for _, tool := range allTools {
228+
assert.NotEmpty(t, tool.Function.OutputSchema.Type)
229+
}
230+
}

pkg/tools/builtin/think.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"reflect"
78
"strings"
89

910
"github.com/docker/cagent/pkg/tools"
@@ -74,6 +75,7 @@ func (t *ThinkTool) Tools(context.Context) ([]tools.Tool, error) {
7475
},
7576
Required: []string{"thought"},
7677
},
78+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
7779
},
7880
Handler: t.handler.CallTool,
7981
},

pkg/tools/builtin/think_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,15 @@ func TestThinkTool_StartStop(t *testing.T) {
141141
err = tool.Stop()
142142
require.NoError(t, err)
143143
}
144+
145+
func TestThinkTool_OutputSchema(t *testing.T) {
146+
tool := NewThinkTool()
147+
148+
allTools, err := tool.Tools(t.Context())
149+
require.NoError(t, err)
150+
require.NotEmpty(t, allTools)
151+
152+
for _, tool := range allTools {
153+
assert.NotEmpty(t, tool.Function.OutputSchema.Type)
154+
}
155+
}

pkg/tools/builtin/todo.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"reflect"
78
"strings"
89

910
"github.com/docker/cagent/pkg/tools"
@@ -163,6 +164,7 @@ func (t *TodoTool) Tools(context.Context) ([]tools.Tool, error) {
163164
},
164165
Required: []string{"description"},
165166
},
167+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
166168
},
167169
Handler: t.handler.createTodo,
168170
},
@@ -194,6 +196,7 @@ func (t *TodoTool) Tools(context.Context) ([]tools.Tool, error) {
194196
},
195197
Required: []string{"todos"},
196198
},
199+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
197200
},
198201
Handler: t.handler.createTodos,
199202
},
@@ -220,6 +223,7 @@ func (t *TodoTool) Tools(context.Context) ([]tools.Tool, error) {
220223
},
221224
Required: []string{"id", "status"},
222225
},
226+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
223227
},
224228
Handler: t.handler.updateTodo,
225229
},
@@ -231,6 +235,7 @@ func (t *TodoTool) Tools(context.Context) ([]tools.Tool, error) {
231235
ReadOnlyHint: &[]bool{true}[0],
232236
Title: "List TODOs",
233237
},
238+
OutputSchema: tools.ToOutputSchemaSchema(reflect.TypeFor[string]()),
234239
},
235240
Handler: t.handler.listTodos,
236241
},

0 commit comments

Comments
 (0)