@@ -5,14 +5,19 @@ import (
55 "fmt"
66 "net/url"
77 "os"
8+ "slices"
89 "strings"
910
1011 "github.com/isaacphi/mcp-language-server/internal/lsp"
1112 "github.com/isaacphi/mcp-language-server/internal/protocol"
1213)
1314
14- // Gets the full code block surrounding the start of the input location
15- func GetFullDefinition (ctx context.Context , client * lsp.Client , startLocation protocol.Location ) (string , protocol.Location , protocol.DocumentSymbolResult , error ) {
15+ type match struct {
16+ Symbol protocol.DocumentSymbolResult
17+ Range protocol.Range
18+ }
19+
20+ func identifyOverlappingSymbols (ctx context.Context , client * lsp.Client , startLocation protocol.Location ) ([]match , error ) {
1621 symParams := protocol.DocumentSymbolParams {
1722 TextDocument : protocol.TextDocumentIdentifier {
1823 URI : startLocation .URI ,
@@ -22,45 +27,59 @@ func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation pr
2227 // Get all symbols in document
2328 symResult , err := client .DocumentSymbol (ctx , symParams )
2429 if err != nil {
25- return "" , protocol. Location {}, nil , fmt .Errorf ("failed to get document symbols: %w" , err )
30+ return nil , fmt .Errorf ("failed to get document symbols: %w" , err )
2631 }
2732
2833 symbols , err := symResult .Results ()
2934 if err != nil {
30- return "" , protocol. Location {}, nil , fmt .Errorf ("failed to process document symbols: %w" , err )
35+ return nil , fmt .Errorf ("failed to process document symbols: %w" , err )
3136 }
3237
33- var symbolRange protocol.Range
34- var symbol protocol.DocumentSymbolResult
35- found := false
36-
3738 // Search for symbol at startLocation
38- var searchSymbols func (symbols []protocol.DocumentSymbolResult ) bool
39- searchSymbols = func (symbols []protocol.DocumentSymbolResult ) bool {
39+ // - multiple symbols might match (for example, a C++ namespace) so find
40+ // all of the matching symbols and use the smallest one (or the first one
41+ // if there is a tie)
42+ var matchingSymbols []match
43+
44+ var searchSymbols func (symbols []protocol.DocumentSymbolResult )
45+ searchSymbols = func (symbols []protocol.DocumentSymbolResult ) {
4046 for _ , sym := range symbols {
4147 if containsPosition (sym .GetRange (), startLocation .Range .Start ) {
42- symbol = sym
43- symbolRange = sym .GetRange ()
44- found = true
45- return true
48+ matchingSymbols = append (matchingSymbols , match {sym , sym .GetRange ()})
4649 }
50+
4751 // Handle nested symbols if it's a DocumentSymbol
4852 if ds , ok := sym .(* protocol.DocumentSymbol ); ok && len (ds .Children ) > 0 {
4953 childSymbols := make ([]protocol.DocumentSymbolResult , len (ds .Children ))
5054 for i := range ds .Children {
5155 childSymbols [i ] = & ds .Children [i ]
5256 }
53- if searchSymbols (childSymbols ) {
54- return true
55- }
57+ searchSymbols (childSymbols )
5658 }
5759 }
58- return false
5960 }
6061
61- found = searchSymbols (symbols )
62+ searchSymbols (symbols )
63+ return matchingSymbols , nil
64+ }
65+
66+ // Gets the full code block surrounding the start of the input location
67+ func GetFullDefinition (ctx context.Context , client * lsp.Client , startLocation protocol.Location ) (string , protocol.Location , protocol.DocumentSymbolResult , error ) {
68+
69+ matchingSymbols , err := identifyOverlappingSymbols (ctx , client , startLocation )
70+ if err != nil {
71+ return "" , protocol.Location {}, nil , err
72+ }
73+
74+ // Identify the smallest overlapping symbol
75+ slices .SortStableFunc (matchingSymbols , func (a , b match ) int {
76+ return int (a .Range .End .Line - a .Range .Start .Line ) - int (b .Range .End .Line - b .Range .Start .Line )
77+ })
78+
79+ if len (matchingSymbols ) > 0 {
80+ symbol := matchingSymbols [0 ].Symbol
81+ symbolRange := matchingSymbols [0 ].Range
6282
63- if found {
6483 // Convert URI to filesystem path
6584 filePath , err := url .PathUnescape (strings .TrimPrefix (string (startLocation .URI ), "file://" ))
6685 if err != nil {
@@ -125,16 +144,13 @@ func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation pr
125144 }
126145 }
127146
128- // Update location with new range
129- startLocation .Range = symbolRange
130-
131147 // Return the text within the range
132148 if int (symbolRange .End .Line ) >= len (lines ) {
133149 return "" , protocol.Location {}, nil , fmt .Errorf ("end line out of range" )
134150 }
135151
136152 selectedLines := lines [symbolRange .Start .Line : symbolRange .End .Line + 1 ]
137- return strings .Join (selectedLines , "\n " ), startLocation , symbol , nil
153+ return strings .Join (selectedLines , "\n " ), protocol. Location { URI : startLocation . URI , Range : symbolRange } , symbol , nil
138154 }
139155
140156 return "" , protocol.Location {}, nil , fmt .Errorf ("symbol not found" )
@@ -148,8 +164,8 @@ func GetLineRangesToDisplay(ctx context.Context, client *lsp.Client, locations [
148164 // For each location, get its container and add relevant lines
149165 for _ , loc := range locations {
150166 // Use GetFullDefinition to find container
151- _ , containerLoc , _ , err := GetFullDefinition (ctx , client , loc )
152- if err != nil {
167+ matchingSymbols , _ := identifyOverlappingSymbols (ctx , client , loc )
168+ if len ( matchingSymbols ) == 0 {
153169 // If container not found, just use the location's line
154170 refLine := int (loc .Range .Start .Line )
155171 linesToShow [refLine ] = true
@@ -163,9 +179,11 @@ func GetLineRangesToDisplay(ctx context.Context, client *lsp.Client, locations [
163179 continue
164180 }
165181
182+ containerRange := matchingSymbols [0 ].Range
183+
166184 // Add container start and end lines
167- containerStart := int (containerLoc . Range .Start .Line )
168- containerEnd := int (containerLoc . Range .End .Line )
185+ containerStart := int (containerRange .Start .Line )
186+ containerEnd := int (containerRange .End .Line )
169187 linesToShow [containerStart ] = true
170188 // linesToShow[containerEnd] = true
171189
0 commit comments