Skip to content

Commit 0dbdc42

Browse files
committed
Add tests for session ID tagging
1 parent f2727e3 commit 0dbdc42

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

contrib/mark3labs/mcp-go/mcpgo.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ func NewToolHandlerMiddleware() server.ToolHandlerMiddleware {
5959
outputText = string(resultJSON)
6060
}
6161

62+
tagWithSessionID(ctx, toolSpan)
63+
6264
toolSpan.AnnotateTextIO(string(inputJSON), outputText)
6365

6466
if err != nil {
@@ -92,6 +94,7 @@ func newHooks() *hooks {
9294

9395
func (h *hooks) onBeforeInitialize(ctx context.Context, id any, request *mcp.InitializeRequest) {
9496
taskSpan, _ := llmobs.StartTaskSpan(ctx, "mcp.initialize", llmobs.WithIntegration("mark3labs/mcp-go"))
97+
tagWithSessionID(ctx, taskSpan)
9598
h.spanCache.Set(id, taskSpan, ttlcache.DefaultTTL)
9699
}
97100

@@ -117,6 +120,14 @@ func (h *hooks) stop() {
117120
h.spanCache.Stop()
118121
}
119122

123+
func tagWithSessionID(ctx context.Context, span llmobs.Span) {
124+
session := server.ClientSessionFromContext(ctx)
125+
if session != nil {
126+
sessionID := session.SessionID()
127+
span.Annotate(llmobs.WithAnnotatedTags(map[string]string{"mcp_session_id": sessionID}))
128+
}
129+
}
130+
120131
func finishSpanWithIO[Req any, Res any](h *hooks, id any, request Req, result Res) {
121132
if item := h.spanCache.Get(id); item != nil {
122133
span := item.Value()

contrib/mark3labs/mcp-go/mcpgo_test.go

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ func TestIntegrationSessionInitialize(t *testing.T) {
5656
server.WithHooks(hooks))
5757

5858
ctx := context.Background()
59+
sessionID := "test-session-init"
60+
session := &mockSession{id: sessionID}
61+
session.Initialize()
62+
ctx = srv.WithContext(ctx, session)
63+
5964
initRequest := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}`
6065

6166
response := srv.HandleMessage(ctx, []byte(initRequest))
@@ -78,6 +83,8 @@ func TestIntegrationSessionInitialize(t *testing.T) {
7883
assert.Equal(t, "mcp.initialize", taskSpan.Name)
7984
assert.Equal(t, "task", taskSpan.Meta["span.kind"])
8085

86+
assert.Contains(t, taskSpan.Tags, "mcp_session_id:test-session-init")
87+
8188
assert.Contains(t, taskSpan.Meta, "input")
8289
assert.Contains(t, taskSpan.Meta, "output")
8390

@@ -101,7 +108,12 @@ func TestIntegrationToolCallSuccess(t *testing.T) {
101108
tt := testTracer(t)
102109
defer tt.Stop()
103110

111+
hooks := &server.Hooks{}
112+
cleanup := AddServerHooks(hooks)
113+
defer cleanup()
114+
104115
srv := server.NewMCPServer("test-server", "1.0.0",
116+
server.WithHooks(hooks),
105117
server.WithToolHandlerMiddleware(NewToolHandlerMiddleware()))
106118

107119
calcTool := mcp.NewTool("calculator",
@@ -133,9 +145,13 @@ func TestIntegrationToolCallSuccess(t *testing.T) {
133145
session.Initialize()
134146
ctx = srv.WithContext(ctx, session)
135147

148+
initRequest := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}`
149+
response := srv.HandleMessage(ctx, []byte(initRequest))
150+
assert.NotNil(t, response)
151+
136152
toolCallRequest := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"calculator","arguments":{"operation":"add","x":5,"y":3}}}`
137153

138-
response := srv.HandleMessage(ctx, []byte(toolCallRequest))
154+
response = srv.HandleMessage(ctx, []byte(toolCallRequest))
139155
assert.NotNil(t, response)
140156

141157
responseBytes, err := json.Marshal(response)
@@ -147,10 +163,25 @@ func TestIntegrationToolCallSuccess(t *testing.T) {
147163
assert.Equal(t, "2.0", resp["jsonrpc"])
148164
assert.NotNil(t, resp["result"])
149165

150-
spans := tt.WaitForLLMObsSpans(t, 1)
151-
require.Len(t, spans, 1)
166+
spans := tt.WaitForLLMObsSpans(t, 2)
167+
require.Len(t, spans, 2)
168+
169+
var initSpan, toolSpan *testtracer.LLMObsSpan
170+
for i := range spans {
171+
if spans[i].Name == "mcp.initialize" {
172+
initSpan = &spans[i]
173+
} else if spans[i].Name == "calculator" {
174+
toolSpan = &spans[i]
175+
}
176+
}
177+
178+
require.NotNil(t, initSpan, "initialize span not found")
179+
require.NotNil(t, toolSpan, "tool span not found")
180+
181+
expectedTag := "mcp_session_id:test-session-123"
182+
assert.Contains(t, initSpan.Tags, expectedTag)
183+
assert.Contains(t, toolSpan.Tags, expectedTag)
152184

153-
toolSpan := spans[0]
154185
assert.Equal(t, "calculator", toolSpan.Name)
155186
assert.Equal(t, "tool", toolSpan.Meta["span.kind"])
156187

@@ -217,6 +248,8 @@ func TestIntegrationToolCallError(t *testing.T) {
217248
assert.Equal(t, "error_tool", toolSpan.Name)
218249
assert.Equal(t, "tool", toolSpan.Meta["span.kind"])
219250

251+
assert.Contains(t, toolSpan.Tags, "mcp_session_id:test-session-456")
252+
220253
assert.Contains(t, toolSpan.Meta, "error.message")
221254
assert.Contains(t, toolSpan.Meta["error.message"], "intentional test error")
222255
assert.Contains(t, toolSpan.Meta, "error.type")

0 commit comments

Comments
 (0)