diff --git a/contrib/mark3labs/mcp-go/mcpgo.go b/contrib/mark3labs/mcp-go/mcpgo.go index fe9ec4dccf..f2214e4fbc 100644 --- a/contrib/mark3labs/mcp-go/mcpgo.go +++ b/contrib/mark3labs/mcp-go/mcpgo.go @@ -59,6 +59,8 @@ func NewToolHandlerMiddleware() server.ToolHandlerMiddleware { outputText = string(resultJSON) } + tagWithSessionID(ctx, toolSpan) + toolSpan.AnnotateTextIO(string(inputJSON), outputText) if err != nil { @@ -86,6 +88,7 @@ func (h *hooks) onBeforeInitialize(ctx context.Context, id any, request *mcp.Ini taskSpan.Annotate(llmobs.WithAnnotatedTags(map[string]string{"client_name": clientName, "client_version": clientName + "_" + clientVersion})) h.spanCache.Store(id, taskSpan) + tagWithSessionID(ctx, taskSpan) } func (h *hooks) onAfterInitialize(ctx context.Context, id any, request *mcp.InitializeRequest, result *mcp.InitializeResult) { @@ -114,6 +117,14 @@ func (h *hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, messa span.Finish(llmobs.WithError(err)) } +func tagWithSessionID(ctx context.Context, span llmobs.Span) { + session := server.ClientSessionFromContext(ctx) + if session != nil { + sessionID := session.SessionID() + span.Annotate(llmobs.WithAnnotatedTags(map[string]string{"mcp_session_id": sessionID})) + } +} + func finishSpanWithIO[Req any, Res any](h *hooks, id any, request Req, result Res) { value, ok := h.spanCache.LoadAndDelete(id) if !ok { diff --git a/contrib/mark3labs/mcp-go/mcpgo_test.go b/contrib/mark3labs/mcp-go/mcpgo_test.go index 1be3a419bb..678feb6974 100644 --- a/contrib/mark3labs/mcp-go/mcpgo_test.go +++ b/contrib/mark3labs/mcp-go/mcpgo_test.go @@ -52,6 +52,11 @@ func TestIntegrationSessionInitialize(t *testing.T) { server.WithHooks(hooks)) ctx := context.Background() + sessionID := "test-session-init" + session := &mockSession{id: sessionID} + session.Initialize() + ctx = srv.WithContext(ctx, session) + initRequest := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}` response := srv.HandleMessage(ctx, []byte(initRequest)) @@ -77,6 +82,8 @@ func TestIntegrationSessionInitialize(t *testing.T) { assert.Contains(t, taskSpan.Tags, "client_name:test-client") assert.Contains(t, taskSpan.Tags, "client_version:test-client_1.0.0") + assert.Contains(t, taskSpan.Tags, "mcp_session_id:test-session-init") + assert.Contains(t, taskSpan.Meta, "input") assert.Contains(t, taskSpan.Meta, "output") @@ -101,7 +108,11 @@ func TestIntegrationToolCallSuccess(t *testing.T) { tt := testTracer(t) defer tt.Stop() + hooks := &server.Hooks{} + AddServerHooks(hooks) + srv := server.NewMCPServer("test-server", "1.0.0", + server.WithHooks(hooks), server.WithToolHandlerMiddleware(NewToolHandlerMiddleware())) calcTool := mcp.NewTool("calculator", @@ -133,9 +144,13 @@ func TestIntegrationToolCallSuccess(t *testing.T) { session.Initialize() ctx = srv.WithContext(ctx, session) + initRequest := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}` + response := srv.HandleMessage(ctx, []byte(initRequest)) + assert.NotNil(t, response) + toolCallRequest := `{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"calculator","arguments":{"operation":"add","x":5,"y":3}}}` - response := srv.HandleMessage(ctx, []byte(toolCallRequest)) + response = srv.HandleMessage(ctx, []byte(toolCallRequest)) assert.NotNil(t, response) responseBytes, err := json.Marshal(response) @@ -147,10 +162,25 @@ func TestIntegrationToolCallSuccess(t *testing.T) { assert.Equal(t, "2.0", resp["jsonrpc"]) assert.NotNil(t, resp["result"]) - spans := tt.WaitForLLMObsSpans(t, 1) - require.Len(t, spans, 1) + spans := tt.WaitForLLMObsSpans(t, 2) + require.Len(t, spans, 2) + + var initSpan, toolSpan *testtracer.LLMObsSpan + for i := range spans { + if spans[i].Name == "mcp.initialize" { + initSpan = &spans[i] + } else if spans[i].Name == "calculator" { + toolSpan = &spans[i] + } + } + + require.NotNil(t, initSpan, "initialize span not found") + require.NotNil(t, toolSpan, "tool span not found") + + expectedTag := "mcp_session_id:test-session-123" + assert.Contains(t, initSpan.Tags, expectedTag) + assert.Contains(t, toolSpan.Tags, expectedTag) - toolSpan := spans[0] assert.Equal(t, "calculator", toolSpan.Name) assert.Equal(t, "tool", toolSpan.Meta["span.kind"]) @@ -218,6 +248,8 @@ func TestIntegrationToolCallError(t *testing.T) { assert.Equal(t, "error_tool", toolSpan.Name) assert.Equal(t, "tool", toolSpan.Meta["span.kind"]) + assert.Contains(t, toolSpan.Tags, "mcp_session_id:test-session-456") + assert.Contains(t, toolSpan.Meta, "error.message") assert.Contains(t, toolSpan.Meta["error.message"], "intentional test error") assert.Contains(t, toolSpan.Meta, "error.type")