diff --git a/.gitignore b/.gitignore index 7ef6ce3..aba8ac5 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,9 @@ mcp-language-server # Temporary files *~ + +# MCP Server config +.mcp.json + +# Markdown files with notes +notes diff --git a/.mcp.json.example b/.mcp.json.example new file mode 100644 index 0000000..811ba11 --- /dev/null +++ b/.mcp.json.example @@ -0,0 +1,15 @@ +{ + "mcpServers": { + "language-server": { + "type": "stdio", + "command": "/Users/orsen/Develop/mcp-language-server/mcp-language-server", + "args": [ + "--workspace", + "/Users/orsen/Develop/mcp-language-server", + "--lsp", + "gopls" + ], + "env": {} + } + } +} diff --git a/internal/lsp/client.go b/internal/lsp/client.go index 3cbbe43..73dafb5 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -154,7 +154,9 @@ func (c *Client) InitializeLSPClient(ctx context.Context, workspaceDir string) ( CodeLens: &protocol.CodeLensClientCapabilities{ DynamicRegistration: true, }, - DocumentSymbol: protocol.DocumentSymbolClientCapabilities{}, + DocumentSymbol: protocol.DocumentSymbolClientCapabilities{ + HierarchicalDocumentSymbolSupport: true, + }, CodeAction: protocol.CodeActionClientCapabilities{ CodeActionLiteralSupport: protocol.ClientCodeActionLiteralOptions{ CodeActionKind: protocol.ClientCodeActionKindOptions{ diff --git a/internal/tools/diagnostics.go b/internal/tools/diagnostics.go index d033aab..f747a9d 100644 --- a/internal/tools/diagnostics.go +++ b/internal/tools/diagnostics.go @@ -42,70 +42,91 @@ func GetDiagnosticsForFile(ctx context.Context, client *lsp.Client, filePath str return "No diagnostics found for " + filePath, nil } + // Create a summary header + summary := fmt.Sprintf("Diagnostics for %s (%d issues)\n", + filePath, + len(diagnostics)) + // Format the diagnostics var formattedDiagnostics []string - for _, diag := range diagnostics { + formattedDiagnostics = append(formattedDiagnostics, summary) + + for i, diag := range diagnostics { severity := getSeverityString(diag.Severity) - location := fmt.Sprintf("Line %d, Column %d", + location := fmt.Sprintf("L%d:C%d", diag.Range.Start.Line+1, diag.Range.Start.Character+1) // Get the file content for context if needed var codeContext string - startLine := diag.Range.Start.Line + 1 + var startLine uint32 + + // Always get at least the line with the diagnostic + content, err := os.ReadFile(filePath) + if err == nil { + lines := strings.Split(string(content), "\n") + if int(diag.Range.Start.Line) < len(lines) { + codeContext = strings.TrimSpace(lines[diag.Range.Start.Line]) + + // Truncate line if it's too long + const maxLineLength = 80 + if len(codeContext) > maxLineLength { + startChar := int(diag.Range.Start.Character) + if startChar > maxLineLength/2 { + codeContext = "..." + codeContext[startChar-maxLineLength/2:] + } + if len(codeContext) > maxLineLength { + codeContext = codeContext[:maxLineLength] + "..." + } + } + } + } + + // Get more context if requested if includeContext { - content, loc, err := GetFullDefinition(ctx, client, protocol.Location{ + extendedContext, loc, err := GetFullDefinition(ctx, client, protocol.Location{ URI: uri, Range: diag.Range, }) - startLine = loc.Range.Start.Line + 1 - if err != nil { - log.Printf("failed to get file content: %v", err) - } else { - codeContext = content - } - } else { - // Read just the line with the error - content, err := os.ReadFile(filePath) if err == nil { - lines := strings.Split(string(content), "\n") - if int(diag.Range.Start.Line) < len(lines) { - codeContext = lines[diag.Range.Start.Line] + startLine = loc.Range.Start.Line + 1 + if showLineNumbers { + extendedContext = addLineNumbers(extendedContext, int(startLine)) } + codeContext = extendedContext } } - formattedDiag := fmt.Sprintf( - "%s\n[%s] %s\n"+ - "Location: %s\n"+ - "Message: %s\n", - strings.Repeat("=", 60), + // Create a concise diagnostic entry + var formattedDiag strings.Builder + formattedDiag.WriteString(fmt.Sprintf("%d. [%s] %s - %s\n", + i+1, severity, - filePath, location, - diag.Message) + diag.Message)) + // Add source and code if present, but keep it compact + var details []string if diag.Source != "" { - formattedDiag += fmt.Sprintf("Source: %s\n", diag.Source) + details = append(details, fmt.Sprintf("Source: %s", diag.Source)) } - if diag.Code != nil { - formattedDiag += fmt.Sprintf("Code: %v\n", diag.Code) + details = append(details, fmt.Sprintf("Code: %v", diag.Code)) } - formattedDiag += strings.Repeat("=", 60) + if len(details) > 0 { + formattedDiag.WriteString(fmt.Sprintf(" %s\n", strings.Join(details, ", "))) + } + // Add code context if codeContext != "" { - if showLineNumbers { - codeContext = addLineNumbers(codeContext, int(startLine)) - } - formattedDiag += fmt.Sprintf("\n%s\n", codeContext) + formattedDiag.WriteString(fmt.Sprintf(" > %s\n", codeContext)) } - formattedDiagnostics = append(formattedDiagnostics, formattedDiag) + formattedDiagnostics = append(formattedDiagnostics, formattedDiag.String()) } - return strings.Join(formattedDiagnostics, "\n"), nil + return strings.Join(formattedDiagnostics, ""), nil } func getSeverityString(severity protocol.DiagnosticSeverity) string { diff --git a/internal/tools/document_symbols.go b/internal/tools/document_symbols.go new file mode 100644 index 0000000..879c8a7 --- /dev/null +++ b/internal/tools/document_symbols.go @@ -0,0 +1,93 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/isaacphi/mcp-language-server/internal/lsp" + "github.com/isaacphi/mcp-language-server/internal/protocol" + "github.com/isaacphi/mcp-language-server/internal/utilities" +) + +// GetDocumentSymbols retrieves all symbols in a document and formats them in a hierarchical structure +func GetDocumentSymbols(ctx context.Context, client *lsp.Client, filePath string, showLineNumbers bool) (string, error) { + // Open the file if not already open + err := client.OpenFile(ctx, filePath) + if err != nil { + return "", fmt.Errorf("could not open file: %v", err) + } + + // Convert to URI format for LSP protocol + uri := protocol.DocumentUri("file://" + filePath) + + // Create the document symbol parameters + symParams := protocol.DocumentSymbolParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: uri, + }, + } + + // Execute the document symbol request + symResult, err := client.DocumentSymbol(ctx, symParams) + if err != nil { + return "", fmt.Errorf("failed to get document symbols: %v", err) + } + + symbols, err := symResult.Results() + if err != nil { + return "", fmt.Errorf("failed to process document symbols: %v", err) + } + + if len(symbols) == 0 { + return fmt.Sprintf("No symbols found in %s", filePath), nil + } + + var result strings.Builder + result.WriteString(fmt.Sprintf("Symbols in %s\n\n", filePath)) + + // Format symbols hierarchically + formatSymbols(&result, symbols, 0, showLineNumbers) + + return result.String(), nil +} + +// formatSymbols recursively formats symbols with proper indentation +func formatSymbols(sb *strings.Builder, symbols []protocol.DocumentSymbolResult, level int, showLineNumbers bool) { + indent := strings.Repeat(" ", level) + + for _, sym := range symbols { + // Get symbol information + name := sym.GetName() + + // Format location information + location := "" + if showLineNumbers { + r := sym.GetRange() + if r.Start.Line == r.End.Line { + location = fmt.Sprintf("Line %d", r.Start.Line+1) + } else { + location = fmt.Sprintf("Lines %d-%d", r.Start.Line+1, r.End.Line+1) + } + } + + // Use the shared utility to extract kind information + kindStr := utilities.ExtractSymbolKind(sym) + + // Format the symbol entry + if location != "" { + sb.WriteString(fmt.Sprintf("%s%s %s (%s)\n", indent, kindStr, name, location)) + } else { + sb.WriteString(fmt.Sprintf("%s%s %s\n", indent, kindStr, name)) + } + + // Format children if it's a DocumentSymbol + if ds, ok := sym.(*protocol.DocumentSymbol); ok && len(ds.Children) > 0 { + childSymbols := make([]protocol.DocumentSymbolResult, len(ds.Children)) + for i := range ds.Children { + childSymbols[i] = &ds.Children[i] + } + formatSymbols(sb, childSymbols, level+1, showLineNumbers) + } + } +} diff --git a/internal/tools/find-references.go b/internal/tools/find-references.go index e356c5d..e9898db 100644 --- a/internal/tools/find-references.go +++ b/internal/tools/find-references.go @@ -3,97 +3,600 @@ package tools import ( "context" "fmt" + "io" + "log" + "os" + "sort" "strings" "github.com/isaacphi/mcp-language-server/internal/lsp" "github.com/isaacphi/mcp-language-server/internal/protocol" + "github.com/isaacphi/mcp-language-server/internal/utilities" + // "github.com/davecgh/go-spew/spew" // Useful for debugging complex structs ) +// --- At the top of your tools package --- +var debugLogger *log.Logger + +// ScopeIdentifier uniquely identifies a scope (function, method, etc.) in a file +type ScopeIdentifier struct { + URI protocol.DocumentUri + StartLine uint32 + EndLine uint32 + // Adding StartChar and EndChar might make it more unique if needed, but Line is usually enough + // StartChar uint32 + // EndChar uint32 +} + +// ReferencePosition represents a single reference position within a scope +type ReferencePosition struct { + Line uint32 + Character uint32 +} + +// ScopeInfo stores information about a code scope including its name and kind +type ScopeInfo struct { + Name string // Name of the scope (from DocumentSymbol) + Kind protocol.SymbolKind // Kind of the symbol (from DocumentSymbol) + HasKind bool // Whether we have kind information (always true if found via symbol) +} + +func init() { + debugLogger = log.New(io.Discard, "DEBUG_FIND_REFS: ", log.LstdFlags|log.Lmicroseconds) + + enableDebug := os.Getenv("MCP_DEBUG_LOG") == "true" + if enableDebug { + logFilePath := "debug_find_refs.log" + logFileHandle, err := os.OpenFile(logFilePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + // Fallback to stderr if file fails + debugLogger.SetOutput(os.Stderr) // Change output to stderr + debugLogger.Printf("!!! FAILED TO OPEN DEBUG LOG FILE '%s': %v - Logging to Stderr !!!\n", logFilePath, err) + } else { + debugLogger.SetOutput(logFileHandle) // Change output to the file + // Optionally close the file handle gracefully on server shutdown + debugLogger.Printf("--- Debug logging explicitly enabled to file: %s ---", logFilePath) + } + } +} + +// Helper function to find the smallest DocumentSymbol containing the target position +// Returns the symbol and a boolean indicating if found. +func findSymbolContainingPosition(symbols []protocol.DocumentSymbolResult, targetPos protocol.Position, level int) (*protocol.DocumentSymbol, bool) { + indent := strings.Repeat(" ", level) + debugLogger.Printf("%sDEBUG: [%d] findSymbolContainingPosition called for TargetPos: L%d:C%d (0-based)\n", indent, level, targetPos.Line, targetPos.Character) + + var bestMatch *protocol.DocumentSymbol = nil + + for i, symResult := range symbols { + debugLogger.Printf("%sDEBUG: [%d] Checking symbol #%d: Name='%s'\n", indent, level, i, symResult.GetName()) + + ds, ok := symResult.(*protocol.DocumentSymbol) + if !ok { + debugLogger.Printf("%sDEBUG: [%d] Skipping symbol '%s' - not *protocol.DocumentSymbol\n", indent, level, symResult.GetName()) + continue // Skip if it's not the expected type + } + + symRange := ds.GetRange() + debugLogger.Printf("%sDEBUG: [%d] Symbol: '%s', Kind: %d, Range: L%d:C%d - L%d:C%d (0-based)\n", + indent, level, ds.Name, ds.Kind, + symRange.Start.Line, symRange.Start.Character, + symRange.End.Line, symRange.End.Character) + + // Check if the symbol's range contains the target position + posInLineRange := targetPos.Line >= symRange.Start.Line && targetPos.Line <= symRange.End.Line + posInRange := false + if posInLineRange { + // Must be strictly *after* start on start line, and strictly *before* end on end line? No, LSP range includes boundaries. + // Check: Not before start on start line AND not after end on end line. + if targetPos.Line == symRange.Start.Line && targetPos.Character < symRange.Start.Character { + // Before start char on start line - NO MATCH + debugLogger.Printf("%sDEBUG: [%d] RangeCheck: posInLineRange=true. Target char %d < Start char %d on Start Line %d.\n", indent, level, targetPos.Character, symRange.Start.Character, targetPos.Line) + } else if targetPos.Line == symRange.End.Line && targetPos.Character > symRange.End.Character { + // After end char on end line - NO MATCH + debugLogger.Printf("%sDEBUG: [%d] RangeCheck: posInLineRange=true. Target char %d > End char %d on End Line %d.\n", indent, level, targetPos.Character, symRange.End.Character, targetPos.Line) + } else { + posInRange = true + } + } + debugLogger.Printf("%sDEBUG: [%d] RangeCheck Result: posInLineRange=%v, posInRange=%v\n", indent, level, posInLineRange, posInRange) // Log the crucial check result + + if posInRange { + debugLogger.Printf("%sDEBUG: [%d] Position IS within '%s' range. Checking children...\n", indent, level, ds.Name) + // This symbol contains the position. Check children for a more specific match. + var childMatch *protocol.DocumentSymbol = nil + var childFound bool = false + if len(ds.Children) > 0 { + childSymbols := make([]protocol.DocumentSymbolResult, len(ds.Children)) + for i := range ds.Children { + childSymbols[i] = &ds.Children[i] + } + // Pass level + 1 for indentation + childMatch, childFound = findSymbolContainingPosition(childSymbols, targetPos, level+1) + debugLogger.Printf("%sDEBUG: [%d] Recursive call for children of '%s' returned: found=%v, childMatch=%p\n", indent, level, ds.Name, childFound, childMatch) + if childFound { + debugLogger.Printf("%sDEBUG: [%d] Child match name: '%s'\n", indent, level, childMatch.Name) + } + } else { + debugLogger.Printf("%sDEBUG: [%d] Symbol '%s' has no children.\n", indent, level, ds.Name) + } + + if childFound { + // A child is a more specific match + debugLogger.Printf("%sDEBUG: [%d] Updating bestMatch to child: '%s' (%p)\n", indent, level, childMatch.Name, childMatch) + bestMatch = childMatch // Use the child match + } else { + // This symbol is the best match found so far *at or below this level*. + // Compare its size with any existing bestMatch (which could be from a sibling branch's child). + debugLogger.Printf("%sDEBUG: [%d] No better child found for '%s'. Comparing with current bestMatch (%p).\n", indent, level, ds.Name, bestMatch) + + // Determine if this symbol (ds) is better (smaller) than the current bestMatch + isBetter := bestMatch == nil || + // Smaller line range is better + (symRange.End.Line-symRange.Start.Line < bestMatch.Range.End.Line-bestMatch.Range.Start.Line) || + // Same line range, smaller character range is better (more specific) + ((symRange.End.Line-symRange.Start.Line == bestMatch.Range.End.Line-bestMatch.Range.Start.Line) && + (symRange.End.Character-symRange.Start.Character < bestMatch.Range.End.Character-bestMatch.Range.Start.Character)) + + if isBetter { + debugLogger.Printf("%sDEBUG: [%d] Current symbol '%s' IS better than bestMatch (%p). Updating bestMatch.\n", indent, level, ds.Name, bestMatch) + bestMatch = ds // Update bestMatch to this symbol + } else { + debugLogger.Printf("%sDEBUG: [%d] Current symbol '%s' is NOT better than bestMatch ('%s').\n", indent, level, ds.Name, bestMatch.Name) + } + } + } else { + debugLogger.Printf("%sDEBUG: [%d] Position is NOT within '%s' range.\n", indent, level, ds.Name) + } + debugLogger.Printf("%sDEBUG: [%d] --- End Check for Symbol '%s' ---\n", indent, level, ds.Name) + } // End loop through symbols at this level + + debugLogger.Printf("%sDEBUG: [%d] findSymbolContainingPosition returning: found=%v, bestMatch=%p\n", indent, level, bestMatch != nil, bestMatch) + if bestMatch != nil { + debugLogger.Printf("%sDEBUG: [%d] Return Symbol Name: '%s'\n", indent, level, bestMatch.Name) + } + return bestMatch, bestMatch != nil +} + +// Helper function to get text content for a specific range (implementation needed) +// This might use file reading or potentially a custom LSP request if available. +// For simplicity, we'll read the file content here. Could be optimized. +func getTextForRange(ctx context.Context, uri protocol.DocumentUri, fileContent []byte, targetRange protocol.Range) (string, error) { + lines := strings.Split(string(fileContent), "\n") // Assumes LF endings for simplicity here + + startLine := int(targetRange.Start.Line) + endLine := int(targetRange.End.Line) + startChar := int(targetRange.Start.Character) + endChar := int(targetRange.End.Character) + + if startLine < 0 || startLine >= len(lines) || endLine < 0 || endLine >= len(lines) || startLine > endLine { + return "", fmt.Errorf("invalid range for file content: lines %d-%d (file has %d lines)", startLine+1, endLine+1, len(lines)) + } + + var sb strings.Builder + + if startLine == endLine { + // Single line range + line := lines[startLine] + if startChar > len(line) { + startChar = len(line) + } + if endChar > len(line) { + endChar = len(line) + } + if startChar < 0 { + startChar = 0 + } + if endChar < 0 { + endChar = 0 + } + if startChar > endChar { + startChar = endChar + } // Ensure start <= end + sb.WriteString(line[startChar:endChar]) + } else { + // Multi-line range + // Start line: from startChar to end + firstLine := lines[startLine] + if startChar > len(firstLine) { + startChar = len(firstLine) + } + if startChar < 0 { + startChar = 0 + } + sb.WriteString(firstLine[startChar:]) + sb.WriteString("\n") // Add newline separator + + // Middle lines: entire lines + for i := startLine + 1; i < endLine; i++ { + sb.WriteString(lines[i]) + sb.WriteString("\n") + } + + // End line: from beginning to endChar + lastLine := lines[endLine] + if endChar > len(lastLine) { + endChar = len(lastLine) + } + if endChar < 0 { + endChar = 0 + } + sb.WriteString(lastLine[:endChar]) + } + + return sb.String(), nil +} + func FindReferences(ctx context.Context, client *lsp.Client, symbolName string, showLineNumbers bool) (string, error) { - // First get the symbol location like ReadDefinition does - symbolResult, err := client.Symbol(ctx, protocol.WorkspaceSymbolParams{ - Query: symbolName, - }) + // --- Stage 1: Find Symbol Definitions --- + symbolResult, err := client.Symbol(ctx, protocol.WorkspaceSymbolParams{Query: symbolName}) if err != nil { return "", fmt.Errorf("Failed to fetch symbol: %v", err) } - results, err := symbolResult.Results() if err != nil { return "", fmt.Errorf("Failed to parse results: %v", err) } - var allReferences []string + processedLocations := make(map[protocol.Location]struct{}) + var uniqueLocations []protocol.Location for _, symbol := range results { if symbol.GetName() != symbolName { continue } - - // Get the location of the symbol loc := symbol.GetLocation() + // Ensure loc is valid (sometimes workspace/symbol might return incomplete info) + if loc.URI == "" || loc.Range.Start.Line == 0 && loc.Range.Start.Character == 0 && loc.Range.End.Line == 0 && loc.Range.End.Character == 0 { + // debugLogger.Printf( "Warning: Skipping invalid location for symbol %s\n", symbolName) + continue + } + if _, exists := processedLocations[loc]; !exists { + processedLocations[loc] = struct{}{} + uniqueLocations = append(uniqueLocations, loc) + } + } + if len(uniqueLocations) == 0 { + return fmt.Sprintf("Symbol definition not found for: %s", symbolName), nil + } - // Use LSP references request with correct params structure - refsParams := protocol.ReferenceParams{ + // --- Stage 2: Find All References --- + var allFoundRefs []protocol.Location + for _, loc := range uniqueLocations { + refsParams := protocol.ReferenceParams{ /* ... as before ... */ TextDocumentPositionParams: protocol.TextDocumentPositionParams{ - TextDocument: protocol.TextDocumentIdentifier{ - URI: loc.URI, - }, - Position: loc.Range.Start, - }, - Context: protocol.ReferenceContext{ - IncludeDeclaration: false, + TextDocument: protocol.TextDocumentIdentifier{URI: loc.URI}, + Position: loc.Range.Start, }, + Context: protocol.ReferenceContext{IncludeDeclaration: false}, } - refs, err := client.References(ctx, refsParams) if err != nil { - return "", fmt.Errorf("Failed to get references: %v", err) + // Log or report, but continue if possible + debugLogger.Printf("Warning: Failed to get references for definition at %s:%d: %v\n", + loc.URI, loc.Range.Start.Line+1, err) + continue + } + allFoundRefs = append(allFoundRefs, refs...) + } + totalRefs := len(allFoundRefs) + if totalRefs == 0 { + return fmt.Sprintf("No references found for symbol: %s (definition found at %d location(s))", symbolName, len(uniqueLocations)), nil + } + + // --- Stage 3: Group References by File and Scope --- + refsByFile := make(map[protocol.DocumentUri][]protocol.Location) + for _, ref := range allFoundRefs { + refsByFile[ref.URI] = append(refsByFile[ref.URI], ref) + } + + allReferences := []string{fmt.Sprintf("Symbol: %s (%d references in %d files)", symbolName, totalRefs, len(refsByFile))} + + filesProcessed := 0 + for uri, fileRefs := range refsByFile { + filesProcessed++ + filePath := strings.TrimPrefix(string(uri), "file://") + // Sort refs by position within the file + sort.Slice(fileRefs, func(i, j int) bool { /* ... as before ... */ + if fileRefs[i].Range.Start.Line != fileRefs[j].Range.Start.Line { + return fileRefs[i].Range.Start.Line < fileRefs[j].Range.Start.Line + } + return fileRefs[i].Range.Start.Character < fileRefs[j].Range.Start.Character + }) + allReferences = append(allReferences, fmt.Sprintf("File: %s (%d references)", filePath, len(fileRefs))) + + // --- Sub-Stage 3a: Get Symbols and File Content Once Per File --- + var docSymbols []protocol.DocumentSymbolResult + symParams := protocol.DocumentSymbolParams{TextDocument: protocol.TextDocumentIdentifier{URI: uri}} + symResult, symErr := client.DocumentSymbol(ctx, symParams) + if symErr == nil { + docSymbols, _ = symResult.Results() + // Check if we got DocumentSymbol, not SymbolInformation + if len(docSymbols) > 0 { + if _, ok := docSymbols[0].(*protocol.DocumentSymbol); !ok { + debugLogger.Printf("Warning: Received SymbolInformation instead of DocumentSymbol for %s, scope identification might be limited.\n", uri) + docSymbols = nil // Treat as no symbols found for our purpose + } + } + } else { + debugLogger.Printf("Warning: Failed to get document symbols for %s: %v\n", uri, symErr) } - // Group references by file - refsByFile := make(map[protocol.DocumentUri][]protocol.Location) - for _, ref := range refs { - refsByFile[ref.URI] = append(refsByFile[ref.URI], ref) + // Read file content once for fetching scope text later + fileContent, readErr := os.ReadFile(filePath) + if readErr != nil { + debugLogger.Printf("Warning: Failed to read file content for %s: %v. Scope text will be unavailable.\n", filePath, readErr) + fileContent = nil // Mark content as unavailable } - // Process each file's references - for uri, fileRefs := range refsByFile { - // Format file header similarly to ReadDefinition style - fileInfo := fmt.Sprintf("\n%s\nFile: %s\nReferences in File: %d\n%s\n", - strings.Repeat("=", 60), - strings.TrimPrefix(string(uri), "file://"), - len(fileRefs), - strings.Repeat("=", 60)) - allReferences = append(allReferences, fileInfo) + // --- Sub-Stage 3b: Group References by Symbol Scope --- + scopeRefs := make(map[ScopeIdentifier][]ReferencePosition) + scopeInfos := make(map[ScopeIdentifier]ScopeInfo) + scopeTexts := make(map[ScopeIdentifier]string) // Store text based on symbol range - for _, ref := range fileRefs { - // Use GetFullDefinition but with a smaller context window - snippet, _, err := GetFullDefinition(ctx, client, ref) + for _, ref := range fileRefs { + var containingSymbol *protocol.DocumentSymbol + var foundSymbol bool + + // ** KEY CHANGE: Find the symbol containing the *reference position* ** + if len(docSymbols) > 0 { + // Call the debugged function with initial level 0 + debugLogger.Printf("\n--- Searching for symbol containing reference at L%d:C%d (0-based Line %d) ---\n", ref.Range.Start.Line+1, ref.Range.Start.Character+1, ref.Range.Start.Line) + containingSymbol, foundSymbol = findSymbolContainingPosition(docSymbols, ref.Range.Start, 0) // Start recursion level at 0 + debugLogger.Printf("--- Search complete for L%d:C%d. Found: %v ---\n\n", ref.Range.Start.Line+1, ref.Range.Start.Character+1, foundSymbol) + } + + var scopeID ScopeIdentifier + var scopeRange protocol.Range // The range used for fetching text + + if foundSymbol { + // --- Case 1: Reference is within a known symbol --- + scopeRange = containingSymbol.Range // Use the symbol's range + scopeID = ScopeIdentifier{ + URI: uri, + StartLine: containingSymbol.Range.Start.Line, + EndLine: containingSymbol.Range.End.Line, + // Optional: Add character info if needed for uniqueness: + // StartChar: containingSymbol.Range.Start.Character, + // EndChar: containingSymbol.Range.End.Character, + } + + // Store scope info only once per symbol + if _, exists := scopeInfos[scopeID]; !exists { + scopeInfos[scopeID] = ScopeInfo{ + Name: containingSymbol.Name, + Kind: containingSymbol.Kind, + HasKind: true, // We got it from a symbol + } + // Fetch and store text for this symbol's range + if fileContent != nil { + text, err := getTextForRange(ctx, uri, fileContent, scopeRange) + if err == nil { + scopeTexts[scopeID] = text + } else { + debugLogger.Printf("Warning: Failed to get text for symbol %s range (%d-%d): %v\n", containingSymbol.Name, scopeRange.Start.Line+1, scopeRange.End.Line+1, err) + scopeTexts[scopeID] = fmt.Sprintf("Error fetching text for symbol '%s'", containingSymbol.Name) + } + } else { + scopeTexts[scopeID] = "[File content unavailable]" + } + } + + } else { + // --- Case 2: Reference is NOT within a known symbol (e.g., top-level, import, comment) --- + // Fallback: Use context snippet approach + contextLines := 5 + scopeText, scopeLoc, err := GetDefinitionWithContext(ctx, client, ref, contextLines) if err != nil { + debugLogger.Printf("Warning: Could not get context for reference outside symbol at %s:%d: %v\n", ref.URI, ref.Range.Start.Line+1, err) + // Create a dummy scopeID just for this reference if needed, or skip continue } - if showLineNumbers { - snippet = addLineNumbers(snippet, int(ref.Range.Start.Line)+1) + scopeRange = scopeLoc.Range // Use the context range + scopeID = ScopeIdentifier{ // Create ID based on context range + URI: uri, + StartLine: scopeLoc.Range.Start.Line, + EndLine: scopeLoc.Range.End.Line, } - // Format reference location info - refInfo := fmt.Sprintf("Reference at Line %d, Column %d:\n%s\n%s\n", - ref.Range.Start.Line+1, - ref.Range.Start.Character+1, - strings.Repeat("-", 40), - snippet) + // Store info for this fallback scope only once + if _, exists := scopeInfos[scopeID]; !exists { + scopeInfos[scopeID] = ScopeInfo{ + Name: fmt.Sprintf("Context near L%d", ref.Range.Start.Line+1), + Kind: 0, // Unknown kind + HasKind: false, + } + scopeTexts[scopeID] = scopeText // Store the fetched context text + } + } - allReferences = append(allReferences, refInfo) + // Add the reference position to the determined scope (symbol-based or context-based) + position := ReferencePosition{ + Line: ref.Range.Start.Line, + Character: ref.Range.Start.Character, } + scopeRefs[scopeID] = append(scopeRefs[scopeID], position) + + } // End loop through references in file + + // --- Stage 4: Format Output --- + // Get the keys (scopeIDs) and sort them by starting line + scopeIDs := make([]ScopeIdentifier, 0, len(scopeRefs)) + for id := range scopeRefs { + scopeIDs = append(scopeIDs, id) } - } + sort.Slice(scopeIDs, func(i, j int) bool { /* ... as before ... */ + return scopeIDs[i].StartLine < scopeIDs[j].StartLine + }) - if len(allReferences) == 0 { - banner := strings.Repeat("=", 80) + "\n" - return fmt.Sprintf("%sNo references found for symbol: %s\n%s", - banner, symbolName, banner), nil - } + // Loop through sorted scopes and format output + for _, scopeID := range scopeIDs { + positions := scopeRefs[scopeID] + scopeInfo := scopeInfos[scopeID] + scopeText := scopeTexts[scopeID] // Get the stored text + + // Debug info (now reflects symbol finding) + // debugInfo := fmt.Sprintf("DEBUG: Scope='%s', HasKind=%v, Kind=%d (L%d-%d)", + // scopeInfo.Name, scopeInfo.HasKind, scopeInfo.Kind, scopeID.StartLine+1, scopeID.EndLine+1) + // allReferences = append(allReferences, " "+debugInfo) + + // Format scope header (using Kind if HasKind is true) + var scopeHeader string + if scopeInfo.HasKind { + kindStr := utilities.GetSymbolKindString(scopeInfo.Kind) + displayName := scopeInfo.Name + if kindStr != "" && kindStr != "Unknown" { + displayName = fmt.Sprintf("%s %s", kindStr, scopeInfo.Name) + } + scopeHeader = fmt.Sprintf(" %s (lines %d-%d, %d references)", displayName, scopeID.StartLine+1, scopeID.EndLine+1, len(positions)) + } else { + scopeHeader = fmt.Sprintf(" Scope: %s (lines %d-%d, %d references)", scopeInfo.Name, scopeID.StartLine+1, scopeID.EndLine+1, len(positions)) + } + allReferences = append(allReferences, scopeHeader) + + // Format reference positions (no changes) + var positionStrs []string + var highlightLineIndices []int // Relative to the start of the scopeText + for _, pos := range positions { + positionStrs = append(positionStrs, fmt.Sprintf("L%d:C%d", pos.Line+1, pos.Character+1)) + // Calculate highlight index relative to scope start + highlightLineIndices = append(highlightLineIndices, int(pos.Line-scopeID.StartLine)) + } + // ... (chunking logic as before) ... + const chunkSize = 4 + for i := 0; i < len(positionStrs); i += chunkSize { + end := i + chunkSize + if end > len(positionStrs) { + end = len(positionStrs) + } + positionChunk := positionStrs[i:end] + allReferences = append(allReferences, fmt.Sprintf(" References: %s", strings.Join(positionChunk, ", "))) + } + + // Format scope text (truncation, line numbers, highlighting) + scopeLines := strings.Split(scopeText, "\n") // Use the stored text + + // --- Truncation Logic --- (needs adjustment for highlightLineIndices) + finalScopeLines := scopeLines // Start with original lines + finalHighlightIndices := highlightLineIndices // Start with original indices + if len(scopeLines) > 50 { + // ... (Existing truncation logic, BUT ensure it correctly maps original highlightLineIndices to the indices in the *truncated* output) ... + + // Simplified recalculation (can be improved for precision) + importantLines := make(map[int]bool) + for i := 0; i < 5 && i < len(scopeLines); i++ { + importantLines[i] = true + } + for i := len(scopeLines) - 3; i < len(scopeLines) && i >= 0; i++ { + importantLines[i] = true + } + for _, hlLine := range highlightLineIndices { // Use original indices here + for offset := -2; offset <= 2; offset++ { + lineIdx := hlLine + offset + if lineIdx >= 0 && lineIdx < len(scopeLines) { + importantLines[lineIdx] = true + } + } + } + + var truncatedLines []string + originalToTruncatedIndexMap := make(map[int]int) + currentTruncatedIndex := 0 + inSkipSection := false + lastShownIndex := -1 + + for i := 0; i < len(scopeLines); i++ { + if importantLines[i] { + if inSkipSection { + truncatedLines = append(truncatedLines, fmt.Sprintf(" ... %d lines skipped ...", i-lastShownIndex-1)) + currentTruncatedIndex++ // Account for the skip line + inSkipSection = false + } + truncatedLines = append(truncatedLines, scopeLines[i]) + originalToTruncatedIndexMap[i] = currentTruncatedIndex // Map original index to truncated index + currentTruncatedIndex++ + lastShownIndex = i + } else if !inSkipSection && lastShownIndex >= 0 { + inSkipSection = true + } + } + if inSkipSection && lastShownIndex < len(scopeLines)-1 { + skippedLines := len(scopeLines) - lastShownIndex - 1 + if skippedLines > 0 { + truncatedLines = append(truncatedLines, fmt.Sprintf(" ... %d lines skipped ...", skippedLines)) + } + } + + // Recalculate highlight indices based on the map + newHighlightIndices := []int{} + for _, origIdx := range highlightLineIndices { + if truncatedIdx, ok := originalToTruncatedIndexMap[origIdx]; ok { + newHighlightIndices = append(newHighlightIndices, truncatedIdx) + } + } + + finalScopeLines = truncatedLines // Use the truncated lines for display + finalHighlightIndices = newHighlightIndices // Use the new indices for highlighting + + } // End truncation + + // --- Line Numbering / Formatting --- + var formattedScope strings.Builder + lineNum := int(scopeID.StartLine) + 1 // Start numbering from original scope start + + for i, line := range finalScopeLines { + isRef := false + for _, hl := range finalHighlightIndices { // Use potentially recalculated indices + if i == hl { + isRef = true + break + } + } + + if strings.Contains(line, "lines skipped") { + // Handle skip marker line + if showLineNumbers { + var skipped int + fmt.Sscanf(line, " ... %d lines skipped ...", &skipped) // Ignore error, default skip is 1 line display adjust + formattedScope.WriteString(line + "\n") + lineNum += skipped // Adjust line number count + } else { + formattedScope.WriteString(line + "\n") // Show skip marker even without line nums + } + } else { + // Handle regular code line + if showLineNumbers { + numStr := fmt.Sprintf("%d", lineNum) + padding := strings.Repeat(" ", 5-len(numStr)) + marker := "|" + if isRef { + marker = ">" + } + formattedScope.WriteString(fmt.Sprintf("%s%s%s %s\n", padding, numStr, marker, line)) + } else { + // Add simple marker even without line numbers + marker := " " // Indent non-ref lines + if isRef { + marker = "> " + } + formattedScope.WriteString(marker + line + "\n") + } + lineNum++ // Increment for the next actual code line + } + } + + // Add the formatted scope with indentation + trimmedFormattedScope := strings.TrimRight(formattedScope.String(), " \n\t") + allReferences = append(allReferences, " "+strings.ReplaceAll(trimmedFormattedScope, "\n", "\n ")) + + } // End loop through scopes + + // Add blank line between files + if filesProcessed < len(refsByFile) { + allReferences = append(allReferences, "") + } + + } // End loop through files return strings.Join(allReferences, "\n"), nil } diff --git a/internal/tools/get-codelens.go b/internal/tools/get-codelens.go index bc71536..c6f3611 100644 --- a/internal/tools/get-codelens.go +++ b/internal/tools/get-codelens.go @@ -39,8 +39,7 @@ func GetCodeLens(ctx context.Context, client *lsp.Client, filePath string) (stri // Format the code lens results var output strings.Builder - output.WriteString(fmt.Sprintf("Code Lens results for %s:\n", filePath)) - output.WriteString(strings.Repeat("=", 80) + "\n\n") + output.WriteString(fmt.Sprintf("Code Lens results for %s:\n\n", filePath)) for i, lens := range codeLensResult { output.WriteString(fmt.Sprintf("[%d] Location: Lines %d-%d\n", diff --git a/internal/tools/hover.go b/internal/tools/hover.go new file mode 100644 index 0000000..913060e --- /dev/null +++ b/internal/tools/hover.go @@ -0,0 +1,56 @@ +package tools + +import ( + "context" + "fmt" + "strings" + + "github.com/isaacphi/mcp-language-server/internal/lsp" + "github.com/isaacphi/mcp-language-server/internal/protocol" +) + +// GetHoverInfo retrieves hover information (type, documentation) for a symbol at the specified position +func GetHoverInfo(ctx context.Context, client *lsp.Client, filePath string, line, column int) (string, error) { + // Open the file if not already open + err := client.OpenFile(ctx, filePath) + if err != nil { + return "", fmt.Errorf("could not open file: %v", err) + } + + // Convert 1-indexed line/column to 0-indexed for LSP protocol + uri := protocol.DocumentUri("file://" + filePath) + position := protocol.Position{ + Line: uint32(line - 1), + Character: uint32(column - 1), + } + + // Create the hover parameters + params := protocol.HoverParams{} + + // Set TextDocument and Position via embedded struct + params.TextDocument = protocol.TextDocumentIdentifier{ + URI: uri, + } + params.Position = position + + // Execute the hover request + hoverResult, err := client.Hover(ctx, params) + if err != nil { + return "", fmt.Errorf("failed to get hover information: %v", err) + } + + var result strings.Builder + result.WriteString("Hover Information\n") + + // Process the hover contents based on Markup content + if hoverResult.Contents.Value == "" { + result.WriteString("No hover information available for this position") + } else { + if hoverResult.Contents.Kind != "" { + result.WriteString(fmt.Sprintf("Kind: %s\n\n", hoverResult.Contents.Kind)) + } + result.WriteString(hoverResult.Contents.Value) + } + + return result.String(), nil +} diff --git a/internal/tools/read-definition.go b/internal/tools/read-definition.go index 776014c..11d1f45 100644 --- a/internal/tools/read-definition.go +++ b/internal/tools/read-definition.go @@ -3,88 +3,351 @@ package tools import ( "context" "fmt" - "log" + "os" + "sort" // Needed for sorting definitions if multiple found "strings" "github.com/isaacphi/mcp-language-server/internal/lsp" "github.com/isaacphi/mcp-language-server/internal/protocol" + "github.com/isaacphi/mcp-language-server/internal/utilities" + // "github.com/davecgh/go-spew/spew" // Useful for debugging complex structs ) +// DefinitionInfo holds the refined information for a single definition +type DefinitionInfo struct { + SymbolName string + SymbolKind protocol.SymbolKind + HasKind bool + FilePath string + Range protocol.Range // The precise range of the definition symbol + DefinitionText string + // ContainerName string // Can be added if needed by traversing DocumentSymbol parents +} + +// ReadDefinition intelligently finds and extracts the definition text for a symbol. +// It prioritizes using documentSymbol for precise range finding. func ReadDefinition(ctx context.Context, client *lsp.Client, symbolName string, showLineNumbers bool) (string, error) { - symbolResult, err := client.Symbol(ctx, protocol.WorkspaceSymbolParams{ - Query: symbolName, - }) + debugLogger.Printf("--- GetDefinition called for symbol: %s ---\n", symbolName) + + // --- Stage 1: Find *potential* symbol locations --- + // We use workspace/symbol first to get *any* location (definition or usage) to start the process. + wsSymbolResult, err := client.Symbol(ctx, protocol.WorkspaceSymbolParams{Query: symbolName}) if err != nil { - return "", fmt.Errorf("Failed to fetch symbol: %v", err) + return "", fmt.Errorf("failed to fetch workspace symbols for '%s': %w", symbolName, err) } - - results, err := symbolResult.Results() + wsSymbols, err := wsSymbolResult.Results() if err != nil { - return "", fmt.Errorf("Failed to parse results: %v", err) + return "", fmt.Errorf("failed to parse workspace symbol results for '%s': %w", symbolName, err) } - var definitions []string - for _, symbol := range results { - kind := "" - container := "" - - // Skip symbols that we are not looking for. workspace/symbol may return - // a large number of fuzzy matches. - switch v := symbol.(type) { - case *protocol.SymbolInformation: - // SymbolInformation results have richer data. - kind = fmt.Sprintf("Kind: %s\n", protocol.TableKindMap[v.Kind]) - if v.ContainerName != "" { - container = fmt.Sprintf("Container Name: %s\n", v.ContainerName) - } - if v.Kind == protocol.Method && strings.HasSuffix(symbol.GetName(), symbolName) { - break + var initialLocations []protocol.Location + processedURIs := make(map[protocol.DocumentUri]bool) // Avoid hitting definition/documentSymbol multiple times for the same file if symbol has multiple entries there + + debugLogger.Printf("Found %d potential workspace symbols for '%s'\n", len(wsSymbols), symbolName) + for _, symbol := range wsSymbols { + // Strict name match is crucial here + if symbol.GetName() != symbolName { + continue + } + loc := symbol.GetLocation() + // Skip invalid locations or already processed files + if loc.URI == "" || processedURIs[loc.URI] { + continue + } + + // We only need one good starting point per file. + // Using the first match is usually sufficient. + initialLocations = append(initialLocations, loc) + processedURIs[loc.URI] = true + debugLogger.Printf(" -> Found potential initial location in %s at L%d\n", loc.URI, loc.Range.Start.Line+1) + // Optimization: If we only need *one* definition, we could potentially break here. + // But let's find all distinct definitions for completeness. + } + + if len(initialLocations) == 0 { + debugLogger.Printf("No initial locations found via workspace/symbol matching name '%s' exactly.\n", symbolName) + return fmt.Sprintf("Symbol '%s' not found in workspace.", symbolName), nil + } + + // --- Stage 2 & 3: Refine Location & Find Precise Scope --- + var foundDefinitions []DefinitionInfo + processedDefinitionRanges := make(map[string]bool) // Key: "URI:StartLine:StartChar" + + for _, startLoc := range initialLocations { + debugLogger.Printf("\n--- Processing initial location: %s:%d ---\n", startLoc.URI, startLoc.Range.Start.Line+1) + + // --- Stage 2: Use textDocument/definition for canonical location --- + defParams := protocol.DefinitionParams{ + TextDocumentPositionParams: protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: startLoc.URI}, + Position: startLoc.Range.Start, // Use the start of the workspace symbol's range + }, + } + defResult, err := client.Definition(ctx, defParams) + if err != nil { + debugLogger.Printf("Warning: textDocument/definition call failed for %s:%d: %v. Skipping this path.\n", startLoc.URI, startLoc.Range.Start.Line+1, err) + continue // Try next initial location if any + } + + // --- Stage 3: Process each definition location found --- + var definitionLocations []protocol.Location + + // --- Unpack the result --- + // Helper function to extract locations from the potentially nested value + extractLocations := func(value interface{}) ([]protocol.Location, bool) { + var extracted []protocol.Location + switch v := value.(type) { + case nil: + debugLogger.Printf(" Inner definition value is nil.\n") + return nil, true // Successfully processed null, result is empty list + case protocol.Location: + extracted = []protocol.Location{v} + debugLogger.Printf(" Inner definition resolved to Single Location: %s L%d:%d\n", v.URI, v.Range.Start.Line+1, v.Range.Start.Character+1) + return extracted, true + case []protocol.Location: + if len(v) == 0 { + debugLogger.Printf(" Inner definition resolved to an EMPTY slice of Locations.\n") + } else { + debugLogger.Printf(" Inner definition resolved to Multiple Locations (%d)\n", len(v)) + // Optionally log the first few locations + for i := 0; i < len(v) && i < 3; i++ { + debugLogger.Printf(" Loc %d: %s L%d:%d\n", i, v[i].URI, v[i].Range.Start.Line+1, v[i].Range.Start.Character+1) + } + } + extracted = v + return extracted, true + case []protocol.LocationLink: + if len(v) == 0 { + debugLogger.Printf(" Inner definition resolved to an EMPTY slice of LocationLinks.\n") + extracted = []protocol.Location{} // Initialize empty slice + } else { + debugLogger.Printf(" Inner definition resolved to LocationLinks (%d), extracting targets...\n", len(v)) + extracted = make([]protocol.Location, 0, len(v)) // Initialize slice + for linkIdx, link := range v { + targetRange := link.TargetSelectionRange + zeroRange := protocol.Range{} + if targetRange == zeroRange || (targetRange.Start.Line == 0 && targetRange.Start.Character == 0 && targetRange.End.Line == 0 && targetRange.End.Character == 0) { + debugLogger.Printf(" Link %d: TargetSelectionRange is zero/empty, falling back to TargetRange.\n", linkIdx) + targetRange = link.TargetRange + } + + if link.TargetURI == "" { + debugLogger.Printf(" Link %d: Skipping because TargetURI is empty.\n", linkIdx) + continue + } + + if targetRange.Start.Line > targetRange.End.Line || (targetRange.Start.Line == targetRange.End.Line && targetRange.Start.Character > targetRange.End.Character) { + debugLogger.Printf(" Link %d: Skipping Link Target '%s' due to invalid range: L%d:%d - L%d:%d\n", + linkIdx, link.TargetURI, targetRange.Start.Line+1, targetRange.Start.Character+1, targetRange.End.Line+1, targetRange.End.Character+1) + continue + } + + extractedLoc := protocol.Location{ + URI: link.TargetURI, + Range: targetRange, + } + extracted = append(extracted, extractedLoc) + debugLogger.Printf(" Link %d: Extracted Target: %s L%d:%d - L%d:%d\n", + linkIdx, + extractedLoc.URI, + extractedLoc.Range.Start.Line+1, extractedLoc.Range.Start.Character+1, + extractedLoc.Range.End.Line+1, extractedLoc.Range.End.Character+1) + } + if len(extracted) == 0 { + debugLogger.Printf(" Finished processing LocationLinks, but none resulted in a valid Location.\n") + } + } + return extracted, true // Return the (potentially empty) extracted list + + default: + // This case means the *inner* value was unexpected + debugLogger.Printf("Error: Inner definition value contained an unexpected type (%T).\n", value) + return nil, false // Indicate failure to extract } - if symbol.GetName() != symbolName { + } + + // --- Main Type Switch on defResult.Value --- + var ok bool + // ** Adjust the type name 'protocol.Or_Definition' if it's different in your library! ** + switch v := defResult.Value.(type) { + case protocol.Or_Definition: // Check for the nested "Or" type first + debugLogger.Printf(" Definition result Value is type %T, extracting inner value...\n", v) + // Recursively (or directly) check the inner value + definitionLocations, ok = extractLocations(v.Value) + if !ok { + // The inner extraction failed + debugLogger.Printf("Error: Failed to extract locations from nested %T. Skipping this path.\n", v) continue } default: - if symbol.GetName() != symbolName { + // Try extracting directly if it wasn't the nested type + debugLogger.Printf(" Definition result Value is type %T, attempting direct extraction...\n", v) + definitionLocations, ok = extractLocations(v) // v here is defResult.Value + if !ok { + // Direct extraction failed (e.g., default case in extractLocations hit) + debugLogger.Printf("Error: Direct extraction failed for type %T. Skipping this path.\n", v) continue } } - log.Printf("Symbol: %s\n", symbol.GetName()) - loc := symbol.GetLocation() + // Now, check if we successfully extracted any locations after handling potential nesting + if len(definitionLocations) == 0 { + debugLogger.Printf("Warning: No valid definition locations were extracted after processing the response for %s:%d. Skipping to next initial location (if any).\n", startLoc.URI, startLoc.Range.Start.Line+1) + continue // Try next initial location + } - banner := strings.Repeat("=", 80) + "\n" - definition, loc, err := GetFullDefinition(ctx, client, loc) - locationInfo := fmt.Sprintf( - "Symbol: %s\n"+ - "File: %s\n"+ - kind+ - container+ - "Start Position: Line %d, Column %d\n"+ - "End Position: Line %d, Column %d\n"+ - "%s\n", - symbol.GetName(), - strings.TrimPrefix(string(loc.URI), "file://"), - loc.Range.Start.Line+1, - loc.Range.Start.Character+1, - loc.Range.End.Line+1, - loc.Range.End.Character+1, - strings.Repeat("=", 80)) + // --- Proceed with the rest of the loop using the definitionLocations slice --- + processedAnyInThisBatch := false // Track if we successfully process at least one defLoc from this batch + for _, defLoc := range definitionLocations { + // ... (rest of the code: checking defLoc, processedRanges, getting symbols, reading file, getting text, appending results) + // ... (No changes needed in the rest of the loop below this point) ... - if err != nil { - log.Printf("Error getting definition: %v\n", err) - continue - } + // Check if defLoc itself is valid (sometimes servers return empty locations) + if defLoc.URI == "" { + debugLogger.Printf(" -> Skipping an empty/invalid location received from definition result.\n") + continue + } - if showLineNumbers { - definition = addLineNumbers(definition, int(loc.Range.Start.Line)+1) + defLocKey := fmt.Sprintf("%s:%d:%d", defLoc.URI, defLoc.Range.Start.Line, defLoc.Range.Start.Character) + if processedDefinitionRanges[defLocKey] { + debugLogger.Printf(" -> Skipping already processed definition location: %s\n", defLocKey) + continue // Avoid processing the exact same definition multiple times + } + // Mark immediately *before* trying file IO etc. + processedDefinitionRanges[defLocKey] = true + debugLogger.Printf(" -> Processing definition location: %s L%d:%d - L%d:%d\n", defLoc.URI, defLoc.Range.Start.Line+1, defLoc.Range.Start.Character+1, defLoc.Range.End.Line+1, defLoc.Range.End.Character+1) + filePath := strings.TrimPrefix(string(defLoc.URI), "file://") + + // --- Stage 3a: Get Document Symbols for the definition's file --- + var preciseRange protocol.Range = defLoc.Range // Default to definition result range + var defSymbolKind protocol.SymbolKind = 0 + var hasKind bool = false + + docSymParams := protocol.DocumentSymbolParams{TextDocument: protocol.TextDocumentIdentifier{URI: defLoc.URI}} + docSymResult, docSymErr := client.DocumentSymbol(ctx, docSymParams) + + if docSymErr == nil { + docSymbols, _ := docSymResult.Results() + if len(docSymbols) > 0 { + if _, ok := docSymbols[0].(*protocol.DocumentSymbol); ok { + debugLogger.Printf(" -> Searching document symbols in %s for position L%d:%d\n", defLoc.URI, defLoc.Range.Start.Line+1, defLoc.Range.Start.Character+1) + containingSymbol, foundSymbol := findSymbolContainingPosition(docSymbols, defLoc.Range.Start, 0) + + if foundSymbol { + if containingSymbol.Name == symbolName { + debugLogger.Printf(" --> Found matching DocumentSymbol: '%s' (%s), Range: L%d:%d - L%d:%d\n", + containingSymbol.Name, utilities.GetSymbolKindString(containingSymbol.Kind), + containingSymbol.Range.Start.Line+1, containingSymbol.Range.Start.Character+1, + containingSymbol.Range.End.Line+1, containingSymbol.Range.End.Character+1) + preciseRange = containingSymbol.Range + defSymbolKind = containingSymbol.Kind + hasKind = true + } else { + debugLogger.Printf(" --> Found containing DocumentSymbol '%s' but name mismatch (expected '%s'). Using its range: L%d:%d - L%d:%d\n", + containingSymbol.Name, symbolName, + containingSymbol.Range.Start.Line+1, containingSymbol.Range.Start.Character+1, + containingSymbol.Range.End.Line+1, containingSymbol.Range.End.Character+1) + preciseRange = containingSymbol.Range + defSymbolKind = containingSymbol.Kind + hasKind = true + } + } else { + debugLogger.Printf(" --> No specific DocumentSymbol found containing L%d:%d. Using range from textDocument/definition.\n", defLoc.Range.Start.Line+1, defLoc.Range.Start.Character+1) + } + } else { + debugLogger.Printf(" -> Received SymbolInformation instead of DocumentSymbol for %s. Using range from textDocument/definition.\n", defLoc.URI) + } + } else { + debugLogger.Printf(" -> No document symbols returned for %s. Using range from textDocument/definition.\n", defLoc.URI) + } + } else { + debugLogger.Printf("Warning: Failed to get document symbols for %s: %v. Using range from textDocument/definition.\n", defLoc.URI, docSymErr) + } + + // --- Stage 4: Fetch Definition Text using the determined range --- + debugLogger.Printf(" Attempting to read file: %s\n", filePath) + fileContent, readErr := os.ReadFile(filePath) + if readErr != nil { + debugLogger.Printf("Error: Failed to read file content for %s: %v. Skipping this definition location.\n", filePath, readErr) + continue // Skip this defLoc + } + debugLogger.Printf(" Successfully read %d bytes from %s\n", len(fileContent), filePath) + + debugLogger.Printf(" Attempting to extract text for range: L%d:%d - L%d:%d\n", preciseRange.Start.Line+1, preciseRange.Start.Character+1, preciseRange.End.Line+1, preciseRange.End.Character+1) + definitionText, textErr := getTextForRange(ctx, defLoc.URI, fileContent, preciseRange) + if textErr != nil { + debugLogger.Printf("Error: Failed to extract text for range L%d-L%d in %s: %v. Skipping this definition location.\n", preciseRange.Start.Line+1, preciseRange.End.Line+1, filePath, textErr) + continue // Skip this defLoc + } + debugLogger.Printf(" Successfully extracted text (length %d).\n", len(definitionText)) + + // --- Append to Results --- + debugLogger.Printf(" --> SUCCESS: Appending definition to results.\n") + foundDefinitions = append(foundDefinitions, DefinitionInfo{ + SymbolName: symbolName, // Use the requested name + SymbolKind: defSymbolKind, + HasKind: hasKind, + FilePath: filePath, + Range: preciseRange, + DefinitionText: definitionText, + }) + processedAnyInThisBatch = true // Mark success for this batch + + } // End loop through definitionLocations + + if !processedAnyInThisBatch { + debugLogger.Printf(" -> Finished processing all extracted locations for initial location %s:%d, but none resulted in a successful definition append.\n", startLoc.URI, startLoc.Range.Start.Line+1) } - definitions = append(definitions, banner+locationInfo+definition+"\n") + } // End loop through initialLocations + + if len(foundDefinitions) == 0 { + debugLogger.Printf("--- No definitions found after refining locations for '%s' ---\n", symbolName) + // Provide a more informative message if possible + if len(initialLocations) > 0 { + return fmt.Sprintf("Symbol '%s' found in workspace, but could not resolve its precise definition location.", symbolName), nil + } + // Fallback to the original message if even workspace symbols failed + return fmt.Sprintf("Symbol '%s' not found.", symbolName), nil } - if len(definitions) == 0 { - return fmt.Sprintf("%s not found", symbolName), nil + // --- Stage 5: Format Output --- + // Sort definitions by file path then start line for consistent output + sort.Slice(foundDefinitions, func(i, j int) bool { + if foundDefinitions[i].FilePath != foundDefinitions[j].FilePath { + return foundDefinitions[i].FilePath < foundDefinitions[j].FilePath + } + return foundDefinitions[i].Range.Start.Line < foundDefinitions[j].Range.Start.Line + }) + + var output strings.Builder + for i, defInfo := range foundDefinitions { + if i > 0 { + output.WriteString("\n---\n\n") // Separator for multiple definitions + } + + // Header + output.WriteString(fmt.Sprintf("Symbol: %s\n", defInfo.SymbolName)) + if defInfo.HasKind { + kindStr := utilities.GetSymbolKindString(defInfo.SymbolKind) + if kindStr != "" && kindStr != "Unknown" { + output.WriteString(fmt.Sprintf("Kind: %s\n", kindStr)) + } + } + output.WriteString(fmt.Sprintf("File: %s\n", defInfo.FilePath)) + output.WriteString(fmt.Sprintf("Location: Lines %d-%d\n", + defInfo.Range.Start.Line+1, + defInfo.Range.End.Line+1)) + output.WriteString("\n") // Separator before code + + // Code + codeBlock := defInfo.DefinitionText + if showLineNumbers { + codeBlock = addLineNumbers(codeBlock, int(defInfo.Range.Start.Line)+1) + } + output.WriteString(codeBlock) } - return strings.Join(definitions, "\n"), nil + debugLogger.Printf("--- GetDefinition finished for '%s', found %d definition(s) ---\n", symbolName, len(foundDefinitions)) + return output.String(), nil } diff --git a/internal/tools/rename-symbol.go b/internal/tools/rename-symbol.go new file mode 100644 index 0000000..c2c73c9 --- /dev/null +++ b/internal/tools/rename-symbol.go @@ -0,0 +1,74 @@ +package tools + +import ( + "context" + "fmt" + + "github.com/isaacphi/mcp-language-server/internal/lsp" + "github.com/isaacphi/mcp-language-server/internal/protocol" + "github.com/isaacphi/mcp-language-server/internal/utilities" +) + +// RenameSymbol renames a symbol (variable, function, class, etc.) at the specified position +// It uses the LSP rename functionality to handle all references across files +func RenameSymbol(ctx context.Context, client *lsp.Client, filePath string, line, column int, newName string) (string, error) { + // Open the file if not already open + err := client.OpenFile(ctx, filePath) + if err != nil { + return "", fmt.Errorf("could not open file: %v", err) + } + + // Convert 1-indexed line/column to 0-indexed for LSP protocol + uri := protocol.DocumentUri("file://" + filePath) + position := protocol.Position{ + Line: uint32(line - 1), + Character: uint32(column - 1), + } + + // Create the rename parameters + params := protocol.RenameParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: uri, + }, + Position: position, + NewName: newName, + } + + // Skip the PrepareRename check as it might not be supported by all language servers + // Execute the rename directly + + // Execute the rename operation + workspaceEdit, err := client.Rename(ctx, params) + if err != nil { + return "", fmt.Errorf("failed to rename symbol: %v", err) + } + + // Count the changes that will be made + changeCount := 0 + fileCount := 0 + + // Count changes in Changes field + if workspaceEdit.Changes != nil { + fileCount = len(workspaceEdit.Changes) + for _, edits := range workspaceEdit.Changes { + changeCount += len(edits) + } + } + + // Count changes in DocumentChanges field + for _, change := range workspaceEdit.DocumentChanges { + if change.TextDocumentEdit != nil { + fileCount++ + changeCount += len(change.TextDocumentEdit.Edits) + } + } + + // Apply the workspace edit to files + if err := utilities.ApplyWorkspaceEdit(workspaceEdit); err != nil { + return "", fmt.Errorf("failed to apply changes: %v", err) + } + + // Generate a summary of changes made + return fmt.Sprintf("Successfully renamed symbol to '%s'.\nUpdated %d occurrences across %d files.", + newName, changeCount, fileCount), nil +} diff --git a/internal/tools/utilities.go b/internal/tools/utilities.go index fcb44c3..d96cabc 100644 --- a/internal/tools/utilities.go +++ b/internal/tools/utilities.go @@ -216,18 +216,175 @@ func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation pr } // addLineNumbers adds line numbers to each line of text with proper padding, starting from startLine -func addLineNumbers(text string, startLine int) string { +// If highlightLines is provided, those line numbers (0-indexed relative to the start of the text) will be marked +func addLineNumbers(text string, startLine int, highlightLines ...int) string { lines := strings.Split(text, "\n") // Calculate padding width based on the number of digits in the last line number - lastLineNum := startLine + len(lines) + lastLineNum := startLine + len(lines) - 1 padding := len(strconv.Itoa(lastLineNum)) + // Convert highlight lines to a map for efficient lookup + highlights := make(map[int]bool) + for _, line := range highlightLines { + highlights[line] = true + } + var result strings.Builder for i, line := range lines { // Format line number with padding and separator - lineNum := strconv.Itoa(startLine + i) - linePadding := strings.Repeat(" ", padding-len(lineNum)) - result.WriteString(fmt.Sprintf("%s%s|%s\n", linePadding, lineNum, line)) + lineNum := startLine + i + lineNumStr := strconv.Itoa(lineNum) + linePadding := strings.Repeat(" ", padding-len(lineNumStr)) + + // Determine if this line should be highlighted + marker := "|" + if highlights[i] { + marker = ">" // Use '>' to indicate highlighted lines + } + + result.WriteString(fmt.Sprintf("%s%s%s %s\n", linePadding, lineNumStr, marker, line)) } return result.String() } + +// GetDefinitionWithContext returns the text around a given position with configurable context, +// along with the location (Range) corresponding to that returned text. +// contextLines specifies how many lines before and after the reference line to include. +// loc is the location of the original reference point. +func GetDefinitionWithContext(ctx context.Context, client *lsp.Client /* Remove client if not used */, loc protocol.Location, contextLines int) (string, protocol.Location, error) { + // Convert URI to filesystem path + filePath, err := url.PathUnescape(strings.TrimPrefix(string(loc.URI), "file://")) + if err != nil { + return "", protocol.Location{}, fmt.Errorf("failed to unescape URI: %w", err) + } + + // Read the file content + content, err := os.ReadFile(filePath) + if err != nil { + // Return zero location on error + return "", protocol.Location{}, fmt.Errorf("failed to read file '%s': %w", filePath, err) + } + + // It's generally safer to handle different line endings + // Replace CRLF with LF for consistent splitting + normalizedContent := strings.ReplaceAll(string(content), "\r\n", "\n") + fileLines := strings.Split(normalizedContent, "\n") + + // Calculate the range to show, ensuring we don't go out of bounds + refLine := int(loc.Range.Start.Line) // The line where the reference occurs + + // Check if the reference line itself is valid + if refLine < 0 || refLine >= len(fileLines) { + return "", protocol.Location{}, fmt.Errorf("reference line %d is out of bounds for file %s (0-%d)", refLine+1, filePath, len(fileLines)-1) + } + + startLine := refLine - contextLines + if startLine < 0 { + startLine = 0 + } + + endLine := refLine + contextLines + if endLine >= len(fileLines) { + endLine = len(fileLines) - 1 + } + + // Ensure startLine is not greater than endLine (can happen if contextLines is large and file is small) + if startLine > endLine { + startLine = endLine + } + + // Extract the lines + selectedLines := fileLines[startLine : endLine+1] + contextText := strings.Join(selectedLines, "\n") + + // Create the location corresponding to the extracted text + // Start position: beginning of the startLine + // End position: end of the endLine (use a large character number or actual length if needed, + // but for scope identification, just the lines are often sufficient). + // Using length of last line for slightly more accuracy. + endChar := uint32(0) + if endLine >= 0 && endLine < len(fileLines) { // Check bounds for fileLines[endLine] + endChar = uint32(len(fileLines[endLine])) + } + + contextLocation := protocol.Location{ + URI: loc.URI, // Use the original URI + Range: protocol.Range{ + Start: protocol.Position{ + Line: uint32(startLine), + Character: 0, // Start of the line + }, + End: protocol.Position{ + Line: uint32(endLine), + Character: endChar, // End of the last included line + }, + }, + } + + // Return the extracted text, its location, and nil error + return contextText, contextLocation, nil +} + +// TruncateDefinition shortens a definition if it's too long +// It keeps the beginning, the context around targetLine, and the end +func TruncateDefinition(definition string, targetLine int, contextSize int, maxLines int) string { + lines := strings.Split(definition, "\n") + + // If the definition is already short enough, just return it + if len(lines) <= maxLines { + return definition + } + + // Calculate the range to keep around the target line + contextStart := targetLine - contextSize + if contextStart < 0 { + contextStart = 0 + } + + contextEnd := targetLine + contextSize + if contextEnd >= len(lines) { + contextEnd = len(lines) - 1 + } + + // Decide how many lines to keep from beginning and end + remainingLines := maxLines - (contextEnd - contextStart + 1) - 2 // -2 for ellipsis markers + startLines := remainingLines / 2 + endLines := remainingLines - startLines + + // Adjust if context overlaps with start/end segments + if contextStart < startLines { + startLines = contextStart + endLines = remainingLines - startLines + } + + if contextEnd > (len(lines) - 1 - endLines) { + endLines = len(lines) - 1 - contextEnd + startLines = remainingLines - endLines + } + + // Create the resulting truncated definition + var result []string + + // Add beginning lines if not overlapping with context + if contextStart > startLines { + result = append(result, lines[:startLines]...) + result = append(result, "...") + } else { + // Just use all lines up to context start + result = append(result, lines[:contextStart]...) + } + + // Add the context around the target line + result = append(result, lines[contextStart:contextEnd+1]...) + + // Add end lines if not overlapping with context + if contextEnd < len(lines)-1-endLines { + result = append(result, "...") + result = append(result, lines[len(lines)-endLines:]...) + } else { + // Just use all lines from context end + result = append(result, lines[contextEnd+1:]...) + } + + return strings.Join(result, "\n") +} diff --git a/internal/utilities/symbol.go b/internal/utilities/symbol.go new file mode 100644 index 0000000..da31c7e --- /dev/null +++ b/internal/utilities/symbol.go @@ -0,0 +1,107 @@ +package utilities + +import ( + "fmt" + "reflect" + + "github.com/isaacphi/mcp-language-server/internal/protocol" +) + +// Symbol Kind String Mapping +// This is a map of LSP SymbolKind values to their human-readable string representation +// Used by both document_symbols.go and find-references.go + +// GetSymbolKindString converts a SymbolKind to a descriptive format string with brackets +func GetSymbolKindString(kind protocol.SymbolKind) string { + switch kind { + case 1: // File + return "[File]" + case 2: // Module + return "[Module]" + case 3: // Namespace + return "[Namespace]" + case 4: // Package + return "[Package]" + case 5: // Class + return "[Class]" + case 6: // Method + return "[Method]" + case 7: // Property + return "[Property]" + case 8: // Field + return "[Field]" + case 9: // Constructor + return "[Constructor]" + case 10: // Enum + return "[Enum]" + case 11: // Interface + return "[Interface]" + case 12: // Function + return "[Function]" + case 13: // Variable + return "[Variable]" + case 14: // Constant + return "[Constant]" + case 15: // String + return "[String]" + case 16: // Number + return "[Number]" + case 17: // Boolean + return "[Boolean]" + case 18: // Array + return "[Array]" + case 19: // Object + return "[Object]" + case 20: // Key + return "[Key]" + case 21: // Null + return "[Null]" + case 22: // EnumMember + return "[EnumMember]" + case 23: // Struct + return "[Struct]" + case 24: // Event + return "[Event]" + case 25: // Operator + return "[Operator]" + case 26: // TypeParameter + return "[TypeParameter]" + default: + return "[Unknown]" + } +} + +// FormatSymbolWithKind formats a symbol with its kind in a consistent way across the codebase +func FormatSymbolWithKind(kind, name string) string { + if kind == "" { + return name + } + return fmt.Sprintf("%s %s", kind, name) +} + +// ExtractSymbolKind attempts to get the SymbolKind from a DocumentSymbolResult using reflection +// Returns the formatted kind string with brackets (e.g. [Function]) +func ExtractSymbolKind(sym protocol.DocumentSymbolResult) string { + // Default to Symbol + kindStr := "[Symbol]" + + // Try to extract kind through reflection since we have different struct types + // with different ways to access Kind + symValue := reflect.ValueOf(sym).Elem() + + // Try direct Kind field + if kindField := symValue.FieldByName("Kind"); kindField.IsValid() { + kind := protocol.SymbolKind(kindField.Uint()) + return GetSymbolKindString(kind) + } + + // Try BaseSymbolInformation.Kind + if baseField := symValue.FieldByName("BaseSymbolInformation"); baseField.IsValid() { + if kindField := baseField.FieldByName("Kind"); kindField.IsValid() { + kind := protocol.SymbolKind(kindField.Uint()) + return GetSymbolKindString(kind) + } + } + + return kindStr +} diff --git a/main.go b/main.go index c880c54..253f596 100644 --- a/main.go +++ b/main.go @@ -14,7 +14,7 @@ import ( "github.com/isaacphi/mcp-language-server/internal/lsp" "github.com/isaacphi/mcp-language-server/internal/watcher" - "github.com/metoro-io/mcp-golang" + mcp_golang "github.com/metoro-io/mcp-golang" "github.com/metoro-io/mcp-golang/transport/stdio" ) diff --git a/mcp-client/.python-version b/mcp-client/.python-version new file mode 100644 index 0000000..24ee5b1 --- /dev/null +++ b/mcp-client/.python-version @@ -0,0 +1 @@ +3.13 diff --git a/mcp-client/README.md b/mcp-client/README.md new file mode 100644 index 0000000..1c58267 --- /dev/null +++ b/mcp-client/README.md @@ -0,0 +1,66 @@ +In vscode you have to `code .` inside this project for python environment to load. + +# Find references +python main.py find_references symbolName=ScopeInfo showLineNumbers=true +python main.py find_references symbolName=debugLogger showLineNumbers=true +python main.py find_references symbolName=server showLineNumbers=true + +# Get definition +python main.py read_definition symbolName=ApplyTextEditArgs showLineNumbers=true +python main.py read_definition symbolName=server showLineNumbers=true + +# Get diagnostics for a specific file +python main.py get_diagnostics filePath=internal/tools/diagnostics.go showLineNumbers=true includeContext=false + +# Get hover info (assuming line/column are 1-based as per Go comments) +python main.py hover filePath=internal/tools/find-references.go line=65 column=6 + +# Get document symbols +python main.py document_symbols filePath=internal/tools/find-references.go + +# Apply a hypothetical edit (ensure JSON structure is correct if needed) +# Note: Passing complex structures like lists of objects via key=value is hard. +# This tool might require modifications or a different input method (e.g., reading JSON from a file) +# python main.py apply_text_edit filePath=myfile.go edits='[{"range": ...}]' # <-- This simple parsing won't work well for JSON + +# Call a tool with no arguments (if any exist) +# python main.py some_tool_with_no_args + +# Example for a large Rust project +python main.py --workspace /Users/orsen/Develop/ato \ + --lsp rust-analyzer \ + --delay 10 \ + find_references symbolName=WalletManager showLineNumbers=true + +python main.py --workspace /Users/orsen/Develop/ato \ + --lsp rust-analyzer \ + --delay 6 \ + read_definition symbolName=run_collect_holders_with_progress showLineNumbers=true + +python main.py --workspace /Users/orsen/Develop/ato \ + --lsp rust-analyzer \ + --delay 6 \ + find_references symbolName=run_collect_holders_with_progress showLineNumbers=true + +python main.py --workspace /Users/orsen/Develop/ato \ + --lsp rust-analyzer \ + --delay 6 \ + document_symbols filePath=/Users/orsen/Develop/ato/bot/src/wallet_manager.rs showLineNumbers=true + +python main.py --workspace /Users/orsen/Develop/ato \ + --lsp rust-analyzer \ + --delay 6 \ + hover filePath=/Users/orsen/Develop/ato/bot/src/wallet_manager.rs line=2983 column=28 + + +# Example for the Go project (might need less delay) +python main.py --workspace /Users/orsen/Develop/mcp-language-server \ + --lsp gopls \ + --delay 5 \ + find_references symbolName=ScopeInfo showLineNumbers=true + +# Example hover call with delay +python main.py --workspace /path/to/your/large-rust-project \ + --lsp rust-analyzer \ + --delay 15 \ + hover filePath=src/some_module/file.rs line=123 column=15 diff --git a/mcp-client/main.py b/mcp-client/main.py new file mode 100644 index 0000000..3ac6fba --- /dev/null +++ b/mcp-client/main.py @@ -0,0 +1,262 @@ +import argparse +import asyncio +import json +import sys +import time +from mcp import ClientSession, StdioServerParameters, types +from mcp.client.stdio import stdio_client + +# --- Server Configuration (Modify as needed) --- +SERVER_COMMAND = "/Users/orsen/Develop/mcp-language-server/mcp-language-server" +SERVER_ARGS = [ + "--workspace", + # Use a placeholder or make this configurable too if needed + # "/Users/orsen/Develop/ato", # Rust project + "/Users/orsen/Develop/mcp-language-server", # Default back to original + "--lsp", + # "rust-analyzer", + "gopls", +] +SERVER_ENV = { + # "MCP_DEBUG_LOG": "true", # Enable debug logging +} +SERVER_NAME = "language-server" # Used for logging/identification if needed + + +def parse_value(value_str): + """Attempts to parse a string value into bool, int, float, or keeps as string.""" + val_lower = value_str.lower() + if val_lower == 'true': + return True + if val_lower == 'false': + return False + try: + return int(value_str) + except ValueError: + pass + try: + return float(value_str) + except ValueError: + pass + # If it's quoted, remove quotes (basic handling) + if len(value_str) >= 2 and value_str.startswith('"') and value_str.endswith('"'): + return value_str[1:-1] + if len(value_str) >= 2 and value_str.startswith("'") and value_str.endswith("'"): + return value_str[1:-1] + return value_str + + +def parse_tool_arguments(arg_list): + """Parses a list of 'key=value' strings into a dictionary.""" + parsed_args = {} + if not arg_list: + return parsed_args, None # No arguments provided is valid + + for arg_pair in arg_list: + if '=' not in arg_pair: + return None, f"Invalid argument format: '{arg_pair}'. Expected 'key=value'." + key, value_str = arg_pair.split('=', 1) + if not key: + return None, f"Argument key cannot be empty in '{arg_pair}'." + parsed_args[key] = parse_value(value_str) + + return parsed_args, None + + +async def run_mcp_tool_cli(tool_name, tool_arguments, initial_delay_s, workspace_path, lsp_name): + """Connects to the MCP server and executes the specified tool after an initial delay.""" + + # --- Update Server Args dynamically --- + # Find workspace arg index + try: + ws_idx = SERVER_ARGS.index("--workspace") + SERVER_ARGS[ws_idx + 1] = workspace_path + except (ValueError, IndexError): + print( + "Warning: --workspace argument not found/updated in SERVER_ARGS template.", + file=sys.stderr, + ) + # Optionally add them if not found + # SERVER_ARGS.extend(["--workspace", workspace_path]) + + # Find lsp arg index + try: + lsp_idx = SERVER_ARGS.index("--lsp") + SERVER_ARGS[lsp_idx + 1] = lsp_name + except (ValueError, IndexError): + print("Warning: --lsp argument not found/updated in SERVER_ARGS template.", file=sys.stderr) + # Optionally add them if not found + # SERVER_ARGS.extend(["--lsp", lsp_name]) + + print(f"--- Configuration ---") + print(f"Server Command: {SERVER_COMMAND}") + print(f"Server Args: {SERVER_ARGS}") + print(f"Target Tool: {tool_name}") + print(f"Tool Arguments: {json.dumps(tool_arguments, indent=2)}") + print(f"Initial Delay: {initial_delay_s} seconds") + print("-" * 20 + "\n") + + server_params = StdioServerParameters( + command=SERVER_COMMAND, + args=SERVER_ARGS, + env=SERVER_ENV, + ) + + try: + print("Attempting to start server via stdio...") + async with stdio_client(server_params) as (read_stream, write_stream): + print("Server process likely started. Establishing MCP session...") + + async with ClientSession(read_stream, write_stream) as session: + print("Initializing MCP session...") + init_result = await session.initialize() + print(f"Session initialized successfully!") + + # --- ADDED DELAY --- + if initial_delay_s > 0: + print(f"Waiting {initial_delay_s} seconds for server initialization...") + # Optional: Add a simple progress indicator + for i in range(initial_delay_s): + print(f" Waiting... {i+1}/{initial_delay_s}s", end='\r') + await asyncio.sleep(1) + print("\nWait finished.") # Newline after progress indication + else: + print("No initial delay specified.") + # --- END ADDED DELAY --- + + print(f"\nCalling tool '{tool_name}'...") + + result = await session.call_tool(tool_name, arguments=tool_arguments) + + # --- Pretty Print Result (same as before) --- + print("\n--- Tool Result ---") + # The result type depends on what the tool returns. + # It could be a primitive, dict, list, etc. + if isinstance(result, (dict, list)): + print(json.dumps(result, indent=2)) + else: + print(result) + print("-------------------\n") + + # --- New Pretty Printing Logic --- + print("\n--- Tool Result (Formatted) ---") + if hasattr(result, 'isError') and result.isError: + print("Tool call resulted in an error.") + if hasattr(result, 'content') and result.content: + error_text = "" + for content_item in result.content: + if ( + hasattr(content_item, 'type') + and content_item.type == 'text' + and hasattr(content_item, 'text') + ): + error_text += content_item.text + elif isinstance(content_item, str): + error_text += content_item + if error_text: + print("Error details:") + print(error_text) + else: + print(f"Raw error result object: {result}") + else: + print(f"Raw error result object: {result}") + + elif hasattr(result, 'content') and result.content: + full_text_output = "" + for content_item in result.content: + if ( + hasattr(content_item, 'type') + and content_item.type == 'text' + and hasattr(content_item, 'text') + ): + full_text_output += content_item.text + else: + full_text_output += f"\n[Unsupported Content Type: {type(content_item)}]\n{content_item}\n" + print(full_text_output.strip()) + + else: + print("Tool returned a result without standard content structure:") + if isinstance(result, (dict, list)): + print(json.dumps(result, indent=2)) + else: + print(result) + + print("-------------------\n") + print("Client finished.") + + except Exception as e: + print(f"\n--- An Error Occurred ---", file=sys.stderr) + print(f"Error type: {type(e).__name__}", file=sys.stderr) + print(f"Error details: {e}", file=sys.stderr) + import traceback + + traceback.print_exc(file=sys.stderr) + print("-------------------------\n", file=sys.stderr) + sys.exit(1) # Exit with error code + + +def main(): + parser = argparse.ArgumentParser( + description="MCP Client CLI to interact with a language server.", + epilog="Example: python %(prog)s --workspace /path/to/proj --lsp rust-analyzer --delay 20 find_references symbolName=MyStruct", + ) + + # --- Added Arguments for Configuration --- + parser.add_argument( + "--workspace", + required=True, + help="Path to the project workspace directory for the language server.", + ) + parser.add_argument( + "--lsp", + required=True, + choices=['gopls', 'rust-analyzer'], # Add more LSP names if needed + help="Name of the Language Server Protocol implementation to use.", + ) + parser.add_argument( + "--delay", + type=int, + default=0, + metavar='SECONDS', + help="Initial delay in seconds to wait for server initialization before sending the tool request. Default: 0", + ) + # --- End Added Arguments --- + + parser.add_argument( + "tool_name", + help="The name of the MCP tool to call (e.g., 'find_references', 'read_definition').", + ) + parser.add_argument( + "tool_args", + nargs='*', # 0 or more arguments + help="Arguments for the tool, specified as 'key=value' pairs. " + "Values 'true'/'false' are parsed as booleans, numbers as int/float if possible, otherwise as strings. " + "Use quotes for values with spaces if your shell requires it (e.g., 'filePath=\"my file.go\"').", + ) + + args = parser.parse_args() + + # Parse the key=value arguments into a dictionary + arguments_dict, error_msg = parse_tool_arguments(args.tool_args) + + if error_msg: + parser.error(error_msg) # argparse handles printing usage and exiting + + # Run the async main function, passing the new config options + try: + asyncio.run( + run_mcp_tool_cli( + args.tool_name, + arguments_dict, + args.delay, + args.workspace, # Pass workspace path + args.lsp, # Pass LSP name + ) + ) + except KeyboardInterrupt: + print("\nClient interrupted by user.") + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/mcp-client/pyproject.toml b/mcp-client/pyproject.toml new file mode 100644 index 0000000..2dcf5ed --- /dev/null +++ b/mcp-client/pyproject.toml @@ -0,0 +1,9 @@ +[project] +name = "mcp-client" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.13" +dependencies = [ + "mcp[cli]>=1.6.0", +] diff --git a/mcp-client/uv.lock b/mcp-client/uv.lock new file mode 100644 index 0000000..3379929 --- /dev/null +++ b/mcp-client/uv.lock @@ -0,0 +1,346 @@ +version = 1 +revision = 1 +requires-python = ">=3.13" + +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, +] + +[[package]] +name = "anyio" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "sniffio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/95/7d/4c1bd541d4dffa1b52bd83fb8527089e097a106fc90b467a7313b105f840/anyio-4.9.0.tar.gz", hash = "sha256:673c0c244e15788651a4ff38710fea9675823028a6f08a5eda409e0c9840a028", size = 190949 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a1/ee/48ca1a7c89ffec8b6a0c5d02b89c305671d5ffd8d3c94acf8b8c408575bb/anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c", size = 100916 }, +] + +[[package]] +name = "certifi" +version = "2025.1.31" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1c/ab/c9f1e32b7b1bf505bf26f0ef697775960db7932abeb7b516de930ba2705f/certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651", size = 167577 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/38/fc/bce832fd4fd99766c04d1ee0eead6b0ec6486fb100ae5e74c1d91292b982/certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe", size = 166393 }, +] + +[[package]] +name = "click" +version = "8.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, +] + +[[package]] +name = "h11" +version = "0.14.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259 }, +] + +[[package]] +name = "httpcore" +version = "1.0.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9f/45/ad3e1b4d448f22c0cff4f5692f5ed0666658578e358b8d58a19846048059/httpcore-1.0.8.tar.gz", hash = "sha256:86e94505ed24ea06514883fd44d2bc02d90e77e7979c8eb71b90f41d364a1bad", size = 85385 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/8d/f052b1e336bb2c1fc7ed1aaed898aa570c0b61a09707b108979d9fc6e308/httpcore-1.0.8-py3-none-any.whl", hash = "sha256:5254cf149bcb5f75e9d1b2b9f729ea4a4b883d1ad7379fc632b727cec23674be", size = 78732 }, +] + +[[package]] +name = "httpx" +version = "0.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "certifi" }, + { name = "httpcore" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 }, +] + +[[package]] +name = "httpx-sse" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/60/8f4281fa9bbf3c8034fd54c0e7412e66edbab6bc74c4996bd616f8d0406e/httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721", size = 12624 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/9b/a181f281f65d776426002f330c31849b86b31fc9d848db62e16f03ff739f/httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f", size = 7819 }, +] + +[[package]] +name = "idna" +version = "3.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 }, +] + +[[package]] +name = "markdown-it-py" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528 }, +] + +[[package]] +name = "mcp" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "uvicorn" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/95/d2/f587cb965a56e992634bebc8611c5b579af912b74e04eb9164bd49527d21/mcp-1.6.0.tar.gz", hash = "sha256:d9324876de2c5637369f43161cd71eebfd803df5a95e46225cab8d280e366723", size = 200031 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/30/20a7f33b0b884a9d14dd3aa94ff1ac9da1479fe2ad66dd9e2736075d2506/mcp-1.6.0-py3-none-any.whl", hash = "sha256:7bd24c6ea042dbec44c754f100984d186620d8b841ec30f1b19eda9b93a634d0", size = 76077 }, +] + +[package.optional-dependencies] +cli = [ + { name = "python-dotenv" }, + { name = "typer" }, +] + +[[package]] +name = "mcp-client" +version = "0.1.0" +source = { virtual = "." } +dependencies = [ + { name = "mcp", extra = ["cli"] }, +] + +[package.metadata] +requires-dist = [{ name = "mcp", extras = ["cli"], specifier = ">=1.6.0" }] + +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 }, +] + +[[package]] +name = "pydantic" +version = "2.11.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/10/2e/ca897f093ee6c5f3b0bee123ee4465c50e75431c3d5b6a3b44a47134e891/pydantic-2.11.3.tar.gz", hash = "sha256:7471657138c16adad9322fe3070c0116dd6c3ad8d649300e3cbdfe91f4db4ec3", size = 785513 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b0/1d/407b29780a289868ed696d1616f4aad49d6388e5a77f567dcd2629dcd7b8/pydantic-2.11.3-py3-none-any.whl", hash = "sha256:a082753436a07f9ba1289c6ffa01cd93db3548776088aa917cc43b63f68fa60f", size = 443591 }, +] + +[[package]] +name = "pydantic-core" +version = "2.33.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/17/19/ed6a078a5287aea7922de6841ef4c06157931622c89c2a47940837b5eecd/pydantic_core-2.33.1.tar.gz", hash = "sha256:bcc9c6fdb0ced789245b02b7d6603e17d1563064ddcfc36f046b61c0c05dd9df", size = 434395 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/24/eed3466a4308d79155f1cdd5c7432c80ddcc4530ba8623b79d5ced021641/pydantic_core-2.33.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:70af6a21237b53d1fe7b9325b20e65cbf2f0a848cf77bed492b029139701e66a", size = 2033551 }, + { url = "https://files.pythonhosted.org/packages/ab/14/df54b1a0bc9b6ded9b758b73139d2c11b4e8eb43e8ab9c5847c0a2913ada/pydantic_core-2.33.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:282b3fe1bbbe5ae35224a0dbd05aed9ccabccd241e8e6b60370484234b456266", size = 1852785 }, + { url = "https://files.pythonhosted.org/packages/fa/96/e275f15ff3d34bb04b0125d9bc8848bf69f25d784d92a63676112451bfb9/pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b315e596282bbb5822d0c7ee9d255595bd7506d1cb20c2911a4da0b970187d3", size = 1897758 }, + { url = "https://files.pythonhosted.org/packages/b7/d8/96bc536e975b69e3a924b507d2a19aedbf50b24e08c80fb00e35f9baaed8/pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1dfae24cf9921875ca0ca6a8ecb4bb2f13c855794ed0d468d6abbec6e6dcd44a", size = 1986109 }, + { url = "https://files.pythonhosted.org/packages/90/72/ab58e43ce7e900b88cb571ed057b2fcd0e95b708a2e0bed475b10130393e/pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6dd8ecfde08d8bfadaea669e83c63939af76f4cf5538a72597016edfa3fad516", size = 2129159 }, + { url = "https://files.pythonhosted.org/packages/dc/3f/52d85781406886c6870ac995ec0ba7ccc028b530b0798c9080531b409fdb/pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f593494876eae852dc98c43c6f260f45abdbfeec9e4324e31a481d948214764", size = 2680222 }, + { url = "https://files.pythonhosted.org/packages/f4/56/6e2ef42f363a0eec0fd92f74a91e0ac48cd2e49b695aac1509ad81eee86a/pydantic_core-2.33.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:948b73114f47fd7016088e5186d13faf5e1b2fe83f5e320e371f035557fd264d", size = 2006980 }, + { url = "https://files.pythonhosted.org/packages/4c/c0/604536c4379cc78359f9ee0aa319f4aedf6b652ec2854953f5a14fc38c5a/pydantic_core-2.33.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e11f3864eb516af21b01e25fac915a82e9ddad3bb0fb9e95a246067398b435a4", size = 2120840 }, + { url = "https://files.pythonhosted.org/packages/1f/46/9eb764814f508f0edfb291a0f75d10854d78113fa13900ce13729aaec3ae/pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:549150be302428b56fdad0c23c2741dcdb5572413776826c965619a25d9c6bde", size = 2072518 }, + { url = "https://files.pythonhosted.org/packages/42/e3/fb6b2a732b82d1666fa6bf53e3627867ea3131c5f39f98ce92141e3e3dc1/pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:495bc156026efafd9ef2d82372bd38afce78ddd82bf28ef5276c469e57c0c83e", size = 2248025 }, + { url = "https://files.pythonhosted.org/packages/5c/9d/fbe8fe9d1aa4dac88723f10a921bc7418bd3378a567cb5e21193a3c48b43/pydantic_core-2.33.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ec79de2a8680b1a67a07490bddf9636d5c2fab609ba8c57597e855fa5fa4dacd", size = 2254991 }, + { url = "https://files.pythonhosted.org/packages/aa/99/07e2237b8a66438d9b26482332cda99a9acccb58d284af7bc7c946a42fd3/pydantic_core-2.33.1-cp313-cp313-win32.whl", hash = "sha256:ee12a7be1742f81b8a65b36c6921022301d466b82d80315d215c4c691724986f", size = 1915262 }, + { url = "https://files.pythonhosted.org/packages/8a/f4/e457a7849beeed1e5defbcf5051c6f7b3c91a0624dd31543a64fc9adcf52/pydantic_core-2.33.1-cp313-cp313-win_amd64.whl", hash = "sha256:ede9b407e39949d2afc46385ce6bd6e11588660c26f80576c11c958e6647bc40", size = 1956626 }, + { url = "https://files.pythonhosted.org/packages/20/d0/e8d567a7cff7b04e017ae164d98011f1e1894269fe8e90ea187a3cbfb562/pydantic_core-2.33.1-cp313-cp313-win_arm64.whl", hash = "sha256:aa687a23d4b7871a00e03ca96a09cad0f28f443690d300500603bd0adba4b523", size = 1909590 }, + { url = "https://files.pythonhosted.org/packages/ef/fd/24ea4302d7a527d672c5be06e17df16aabfb4e9fdc6e0b345c21580f3d2a/pydantic_core-2.33.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:401d7b76e1000d0dd5538e6381d28febdcacb097c8d340dde7d7fc6e13e9f95d", size = 1812963 }, + { url = "https://files.pythonhosted.org/packages/5f/95/4fbc2ecdeb5c1c53f1175a32d870250194eb2fdf6291b795ab08c8646d5d/pydantic_core-2.33.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7aeb055a42d734c0255c9e489ac67e75397d59c6fbe60d155851e9782f276a9c", size = 1986896 }, + { url = "https://files.pythonhosted.org/packages/71/ae/fe31e7f4a62431222d8f65a3bd02e3fa7e6026d154a00818e6d30520ea77/pydantic_core-2.33.1-cp313-cp313t-win_amd64.whl", hash = "sha256:338ea9b73e6e109f15ab439e62cb3b78aa752c7fd9536794112e14bee02c8d18", size = 1931810 }, +] + +[[package]] +name = "pydantic-settings" +version = "2.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/88/82/c79424d7d8c29b994fb01d277da57b0a9b09cc03c3ff875f9bd8a86b2145/pydantic_settings-2.8.1.tar.gz", hash = "sha256:d5c663dfbe9db9d5e1c646b2e161da12f0d734d422ee56f567d0ea2cee4e8585", size = 83550 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/53/a64f03044927dc47aafe029c42a5b7aabc38dfb813475e0e1bf71c4a59d0/pydantic_settings-2.8.1-py3-none-any.whl", hash = "sha256:81942d5ac3d905f7f3ee1a70df5dfb62d5569c12f51a5a647defc1c3d9ee2e9c", size = 30839 }, +] + +[[package]] +name = "pygments" +version = "2.19.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7c/2d/c3338d48ea6cc0feb8446d8e6937e1408088a72a39937982cc6111d17f84/pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f", size = 4968581 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293 }, +] + +[[package]] +name = "python-dotenv" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/88/2c/7bb1416c5620485aa793f2de31d3df393d3686aa8a8506d11e10e13c5baf/python_dotenv-1.1.0.tar.gz", hash = "sha256:41f90bc6f5f177fb41f53e87666db362025010eb28f60a01c9143bfa33a2b2d5", size = 39920 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/18/98a99ad95133c6a6e2005fe89faedf294a748bd5dc803008059409ac9b1e/python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d", size = 20256 }, +] + +[[package]] +name = "rich" +version = "14.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/53/830aa4c3066a8ab0ae9a9955976fb770fe9c6102117c8ec4ab3ea62d89e8/rich-14.0.0.tar.gz", hash = "sha256:82f1bc23a6a21ebca4ae0c45af9bdbc492ed20231dcb63f297d6d1021a9d5725", size = 224078 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/9b/63f4c7ebc259242c89b3acafdb37b41d1185c07ff0011164674e9076b491/rich-14.0.0-py3-none-any.whl", hash = "sha256:1c9491e1951aac09caffd42f448ee3d04e58923ffe14993f6e83068dc395d7e0", size = 243229 }, +] + +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755 }, +] + +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, +] + +[[package]] +name = "sse-starlette" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "starlette" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/a4/80d2a11af59fe75b48230846989e93979c892d3a20016b42bb44edb9e398/sse_starlette-2.2.1.tar.gz", hash = "sha256:54470d5f19274aeed6b2d473430b08b4b379ea851d953b11d7f1c4a2c118b419", size = 17376 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/e0/5b8bd393f27f4a62461c5cf2479c75a2cc2ffa330976f9f00f5f6e4f50eb/sse_starlette-2.2.1-py3-none-any.whl", hash = "sha256:6410a3d3ba0c89e7675d4c273a301d64649c03a5ef1ca101f10b47f895fd0e99", size = 10120 }, +] + +[[package]] +name = "starlette" +version = "0.46.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/0c/9d30a4ebeb6db2b25a841afbb80f6ef9a854fc3b41be131d249a977b4959/starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35", size = 72037 }, +] + +[[package]] +name = "typer" +version = "0.15.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8b/6f/3991f0f1c7fcb2df31aef28e0594d8d54b05393a0e4e34c65e475c2a5d41/typer-0.15.2.tar.gz", hash = "sha256:ab2fab47533a813c49fe1f16b1a370fd5819099c00b119e0633df65f22144ba5", size = 100711 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7f/fc/5b29fea8cee020515ca82cc68e3b8e1e34bb19a3535ad854cac9257b414c/typer-0.15.2-py3-none-any.whl", hash = "sha256:46a499c6107d645a9c13f7ee46c5d5096cae6f5fc57dd11eccbbb9ae3e44ddfc", size = 45061 }, +] + +[[package]] +name = "typing-extensions" +version = "4.13.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f6/37/23083fcd6e35492953e8d2aaaa68b860eb422b34627b13f2ce3eb6106061/typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef", size = 106967 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/54/b1ae86c0973cc6f0210b53d508ca3641fb6d0c56823f288d108bc7ab3cc8/typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c", size = 45806 }, +] + +[[package]] +name = "typing-inspection" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/82/5c/e6082df02e215b846b4b8c0b887a64d7d08ffaba30605502639d44c06b82/typing_inspection-0.4.0.tar.gz", hash = "sha256:9765c87de36671694a67904bf2c96e395be9c6439bb6c87b5142569dcdd65122", size = 76222 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 }, +] + +[[package]] +name = "uvicorn" +version = "0.34.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/37/dd92f1f9cedb5eaf74d9999044306e06abe65344ff197864175dbbd91871/uvicorn-0.34.1.tar.gz", hash = "sha256:af981725fc4b7ffc5cb3b0e9eda6258a90c4b52cb2a83ce567ae0a7ae1757afc", size = 76755 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5f/38/a5801450940a858c102a7ad9e6150146a25406a119851c993148d56ab041/uvicorn-0.34.1-py3-none-any.whl", hash = "sha256:984c3a8c7ca18ebaad15995ee7401179212c59521e67bfc390c07fa2b8d2e065", size = 62404 }, +] diff --git a/tools.go b/tools.go index b3ecc1a..4839137 100644 --- a/tools.go +++ b/tools.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/isaacphi/mcp-language-server/internal/tools" - "github.com/metoro-io/mcp-golang" + mcp_golang "github.com/metoro-io/mcp-golang" ) type ReadDefinitionArgs struct { @@ -37,8 +37,25 @@ type ExecuteCodeLensArgs struct { Index int `json:"index" jsonschema:"required,description=The index of the code lens to execute (from get_codelens output), 1 indexed"` } -func (s *server) registerTools() error { +type RenameSymbolArgs struct { + FilePath string `json:"filePath" jsonschema:"required,description=The path to the file containing the symbol to rename"` + Line int `json:"line" jsonschema:"required,description=The line number (1-indexed) where the symbol appears"` + Column int `json:"column" jsonschema:"required,description=The column number (1-indexed) where the symbol appears"` + NewName string `json:"newName" jsonschema:"required,description=The new name for the symbol"` +} + +type HoverArgs struct { + FilePath string `json:"filePath" jsonschema:"required,description=The path to the file containing the symbol to get hover information for"` + Line int `json:"line" jsonschema:"required,description=The line number (1-indexed) where the symbol appears"` + Column int `json:"column" jsonschema:"required,description=The column number (1-indexed) where the symbol appears"` +} + +type DocumentSymbolsArgs struct { + FilePath string `json:"filePath" jsonschema:"required,description=The path to the file to list symbols for"` + ShowLineNumbers bool `json:"showLineNumbers" jsonschema:"required,default=true,description=Include line numbers in the output"` +} +func (s *server) registerTools() error { err := s.mcpServer.RegisterTool( "apply_text_edit", "Apply multiple text edits to a file.", @@ -126,5 +143,50 @@ func (s *server) registerTools() error { return fmt.Errorf("failed to register tool: %v", err) } + err = s.mcpServer.RegisterTool( + "rename_symbol", + "Rename a symbol (variable, function, class, etc.) and all its references across files.", + func(args RenameSymbolArgs) (*mcp_golang.ToolResponse, error) { + text, err := tools.RenameSymbol(s.ctx, s.lspClient, args.FilePath, args.Line, args.Column, args.NewName) + if err != nil { + return nil, fmt.Errorf("Failed to rename symbol: %v", err) + } + return mcp_golang.NewToolResponse(mcp_golang.NewTextContent(text)), nil + }, + ) + if err != nil { + return fmt.Errorf("failed to register tool: %v", err) + } + + err = s.mcpServer.RegisterTool( + "hover", + "Get hover information (type, documentation) for a symbol at the specified position.", + func(args HoverArgs) (*mcp_golang.ToolResponse, error) { + text, err := tools.GetHoverInfo(s.ctx, s.lspClient, args.FilePath, args.Line, args.Column) + if err != nil { + return nil, fmt.Errorf("Failed to get hover information: %v", err) + } + return mcp_golang.NewToolResponse(mcp_golang.NewTextContent(text)), nil + }, + ) + if err != nil { + return fmt.Errorf("failed to register tool: %v", err) + } + + err = s.mcpServer.RegisterTool( + "document_symbols", + "List all symbols (functions, methods, classes, etc.) in a document in a hierarchical structure.", + func(args DocumentSymbolsArgs) (*mcp_golang.ToolResponse, error) { + text, err := tools.GetDocumentSymbols(s.ctx, s.lspClient, args.FilePath, args.ShowLineNumbers) + if err != nil { + return nil, fmt.Errorf("Failed to get document symbols: %v", err) + } + return mcp_golang.NewToolResponse(mcp_golang.NewTextContent(text)), nil + }, + ) + if err != nil { + return fmt.Errorf("failed to register tool: %v", err) + } + return nil }