Skip to content

Commit c3aa9a3

Browse files
authored
Merge pull request isaacphi#7 from virtuald/use-smallest-symbol
clangd: use smallest matching symbol as definition result
2 parents eaed24a + 20419e5 commit c3aa9a3

File tree

9 files changed

+85
-116
lines changed

9 files changed

+85
-116
lines changed

integrationtests/snapshots/clangd/definition/method.snap

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,7 @@ Symbol: method
44
/TEST_OUTPUT/workspace/clangd/src/consumer.cpp
55
Kind: Method
66
Container Name: TestClass
7-
Range: L7:C1 - L15:C2
7+
Range: L14:C1 - L14:C47
88

9-
7|class TestClass {
10-
8| public:
11-
9| /**
12-
10| * @brief A method that takes an integer parameter.
13-
11| *
14-
12| * @param param The integer parameter to be processed.
15-
13| */
169
14| void method(int param) { helperFunction(); }
17-
15|};
1810

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
---
2+
3+
Symbol: nsFunction2
4+
/TEST_OUTPUT/workspace/clangd/src/namespace.cpp
5+
Kind: Function
6+
Container Name: ns
7+
Range: L11:C1 - L13:C2
8+
9+
11|void nsFunction2() {
10+
12| // empty
11+
13|}
12+

integrationtests/snapshots/python/definition/method.snap

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,8 @@ Symbol: test_method
44
/TEST_OUTPUT/workspace/main.py
55
Kind: Method
66
Container Name: TestClass
7-
Range: L18:C1 - L59:C22
7+
Range: L31:C1 - L41:C26
88

9-
18|class TestClass:
10-
19| """A test class with methods and attributes."""
11-
20|
12-
21| class_variable: str = "class variable"
13-
22|
14-
23| def __init__(self, value: int = 0):
15-
24| """Initialize the TestClass.
16-
25|
17-
26| Args:
18-
27| value: The initial value
19-
28| """
20-
29| self.value: int = value
21-
30|
229
31| def test_method(self, increment: int) -> int:
2310
32| """Increment the value by the given amount.
2411
33|
@@ -30,22 +17,4 @@ Range: L18:C1 - L59:C22
3017
39| """
3118
40| self.value += increment
3219
41| return self.value
33-
42|
34-
43| @staticmethod
35-
44| def static_method(items: list[str]) -> dict[str, int]:
36-
45| """Convert a list of items to a dictionary with item counts.
37-
46|
38-
47| Args:
39-
48| items: A list of strings
40-
49|
41-
50| Returns:
42-
51| A dictionary mapping items to their counts
43-
52| """
44-
53| result: dict[str, int] = {}
45-
54| for item in items:
46-
55| if item in result:
47-
56| result[item] += 1
48-
57| else:
49-
58| result[item] = 1
50-
59| return result
5120

integrationtests/snapshots/python/definition/static-method.snap

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,8 @@ Symbol: static_method
44
/TEST_OUTPUT/workspace/main.py
55
Kind: Method
66
Container Name: TestClass
7-
Range: L18:C1 - L59:C22
7+
Range: L43:C1 - L59:C22
88

9-
18|class TestClass:
10-
19| """A test class with methods and attributes."""
11-
20|
12-
21| class_variable: str = "class variable"
13-
22|
14-
23| def __init__(self, value: int = 0):
15-
24| """Initialize the TestClass.
16-
25|
17-
26| Args:
18-
27| value: The initial value
19-
28| """
20-
29| self.value: int = value
21-
30|
22-
31| def test_method(self, increment: int) -> int:
23-
32| """Increment the value by the given amount.
24-
33|
25-
34| Args:
26-
35| increment: The amount to increment by
27-
36|
28-
37| Returns:
29-
38| The new value
30-
39| """
31-
40| self.value += increment
32-
41| return self.value
33-
42|
349
43| @staticmethod
3510
44| def static_method(items: list[str]) -> dict[str, int]:
3611
45| """Convert a list of items to a dictionary with item counts.

integrationtests/snapshots/rust/definition/method.snap

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,21 @@ Symbol: method
44
/TEST_OUTPUT/workspace/src/types.rs
55
Kind: Function
66
Container Name: TestStruct
7-
Range: L18:C1 - L30:C2
7+
Range: L27:C1 - L29:C6
88

9-
18|// Implementation for TestStruct
10-
19|impl TestStruct {
11-
20| pub fn new(name: &str, value: i32) -> Self {
12-
21| TestStruct {
13-
22| name: String::from(name),
14-
23| value,
15-
24| }
16-
25| }
17-
26|
189
27| pub fn method(&self) -> String {
1910
28| format!("{}: {}", self.name, self.value)
2011
29| }
21-
30|}
2212

2313
---
2414

2515
Symbol: method
2616
/TEST_OUTPUT/workspace/src/types.rs
2717
Kind: Function
2818
Container Name: SharedStruct
29-
Range: L54:C1 - L64:C2
19+
Range: L61:C1 - L63:C6
3020

31-
54|impl SharedStruct {
32-
55| pub fn new(name: &str) -> Self {
33-
56| SharedStruct {
34-
57| name: String::from(name),
35-
58| }
36-
59| }
37-
60|
3821
61| pub fn method(&self) -> String {
3922
62| format!("SharedStruct: {}", self.name)
4023
63| }
41-
64|}
4224

integrationtests/tests/clangd/definition/definition_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ func TestReadDefinition(t *testing.T) {
6565
expectedText: "void method(int param)",
6666
snapshotName: "method",
6767
},
68+
{
69+
name: "Namespace function",
70+
symbolName: "nsFunction2",
71+
expectedText: "void nsFunction2()",
72+
snapshotName: "nsFunction",
73+
},
6874
{
6975
name: "Struct",
7076
symbolName: "TestStruct",

integrationtests/workspaces/clangd/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ TARGET_CLEAN_PROGRAM = clean_program # Assuming this is another program to be bu
2929
# Listing them explicitly for clarity in target dependencies.
3030
OBJ_MAIN = $(OBJDIR)/main.o
3131
# Add other specific object files your 'program' executable depends on
32-
OTHER_OBJS = $(OBJDIR)/helper.o $(OBJDIR)/types.o $(OBJDIR)/consumer.o $(OBJDIR)/another_consumer.o
32+
OTHER_OBJS = $(OBJDIR)/helper.o $(OBJDIR)/types.o $(OBJDIR)/consumer.o $(OBJDIR)/another_consumer.o $(OBJDIR)/namespace.o
3333
OBJ_FOR_CLEAN_PROGRAM = $(OBJDIR)/clean.o
3434

3535

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//
2+
// A file with a namespace in it
3+
//
4+
5+
namespace ns {
6+
7+
void nsFunction1() {
8+
// empty
9+
}
10+
11+
void nsFunction2() {
12+
// empty
13+
}
14+
15+
}

internal/tools/lsp-utilities.go

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,19 @@ import (
55
"fmt"
66
"net/url"
77
"os"
8+
"slices"
89
"strings"
910

1011
"github.com/isaacphi/mcp-language-server/internal/lsp"
1112
"github.com/isaacphi/mcp-language-server/internal/protocol"
1213
)
1314

14-
// Gets the full code block surrounding the start of the input location
15-
func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation protocol.Location) (string, protocol.Location, protocol.DocumentSymbolResult, error) {
15+
type match struct {
16+
Symbol protocol.DocumentSymbolResult
17+
Range protocol.Range
18+
}
19+
20+
func identifyOverlappingSymbols(ctx context.Context, client *lsp.Client, startLocation protocol.Location) ([]match, error) {
1621
symParams := protocol.DocumentSymbolParams{
1722
TextDocument: protocol.TextDocumentIdentifier{
1823
URI: startLocation.URI,
@@ -22,45 +27,59 @@ func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation pr
2227
// Get all symbols in document
2328
symResult, err := client.DocumentSymbol(ctx, symParams)
2429
if err != nil {
25-
return "", protocol.Location{}, nil, fmt.Errorf("failed to get document symbols: %w", err)
30+
return nil, fmt.Errorf("failed to get document symbols: %w", err)
2631
}
2732

2833
symbols, err := symResult.Results()
2934
if err != nil {
30-
return "", protocol.Location{}, nil, fmt.Errorf("failed to process document symbols: %w", err)
35+
return nil, fmt.Errorf("failed to process document symbols: %w", err)
3136
}
3237

33-
var symbolRange protocol.Range
34-
var symbol protocol.DocumentSymbolResult
35-
found := false
36-
3738
// Search for symbol at startLocation
38-
var searchSymbols func(symbols []protocol.DocumentSymbolResult) bool
39-
searchSymbols = func(symbols []protocol.DocumentSymbolResult) bool {
39+
// - multiple symbols might match (for example, a C++ namespace) so find
40+
// all of the matching symbols and use the smallest one (or the first one
41+
// if there is a tie)
42+
var matchingSymbols []match
43+
44+
var searchSymbols func(symbols []protocol.DocumentSymbolResult)
45+
searchSymbols = func(symbols []protocol.DocumentSymbolResult) {
4046
for _, sym := range symbols {
4147
if containsPosition(sym.GetRange(), startLocation.Range.Start) {
42-
symbol = sym
43-
symbolRange = sym.GetRange()
44-
found = true
45-
return true
48+
matchingSymbols = append(matchingSymbols, match{sym, sym.GetRange()})
4649
}
50+
4751
// Handle nested symbols if it's a DocumentSymbol
4852
if ds, ok := sym.(*protocol.DocumentSymbol); ok && len(ds.Children) > 0 {
4953
childSymbols := make([]protocol.DocumentSymbolResult, len(ds.Children))
5054
for i := range ds.Children {
5155
childSymbols[i] = &ds.Children[i]
5256
}
53-
if searchSymbols(childSymbols) {
54-
return true
55-
}
57+
searchSymbols(childSymbols)
5658
}
5759
}
58-
return false
5960
}
6061

61-
found = searchSymbols(symbols)
62+
searchSymbols(symbols)
63+
return matchingSymbols, nil
64+
}
65+
66+
// Gets the full code block surrounding the start of the input location
67+
func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation protocol.Location) (string, protocol.Location, protocol.DocumentSymbolResult, error) {
68+
69+
matchingSymbols, err := identifyOverlappingSymbols(ctx, client, startLocation)
70+
if err != nil {
71+
return "", protocol.Location{}, nil, err
72+
}
73+
74+
// Identify the smallest overlapping symbol
75+
slices.SortStableFunc(matchingSymbols, func(a, b match) int {
76+
return int(a.Range.End.Line-a.Range.Start.Line) - int(b.Range.End.Line-b.Range.Start.Line)
77+
})
78+
79+
if len(matchingSymbols) > 0 {
80+
symbol := matchingSymbols[0].Symbol
81+
symbolRange := matchingSymbols[0].Range
6282

63-
if found {
6483
// Convert URI to filesystem path
6584
filePath, err := url.PathUnescape(strings.TrimPrefix(string(startLocation.URI), "file://"))
6685
if err != nil {
@@ -125,16 +144,13 @@ func GetFullDefinition(ctx context.Context, client *lsp.Client, startLocation pr
125144
}
126145
}
127146

128-
// Update location with new range
129-
startLocation.Range = symbolRange
130-
131147
// Return the text within the range
132148
if int(symbolRange.End.Line) >= len(lines) {
133149
return "", protocol.Location{}, nil, fmt.Errorf("end line out of range")
134150
}
135151

136152
selectedLines := lines[symbolRange.Start.Line : symbolRange.End.Line+1]
137-
return strings.Join(selectedLines, "\n"), startLocation, symbol, nil
153+
return strings.Join(selectedLines, "\n"), protocol.Location{URI: startLocation.URI, Range: symbolRange}, symbol, nil
138154
}
139155

140156
return "", protocol.Location{}, nil, fmt.Errorf("symbol not found")
@@ -148,8 +164,8 @@ func GetLineRangesToDisplay(ctx context.Context, client *lsp.Client, locations [
148164
// For each location, get its container and add relevant lines
149165
for _, loc := range locations {
150166
// Use GetFullDefinition to find container
151-
_, containerLoc, _, err := GetFullDefinition(ctx, client, loc)
152-
if err != nil {
167+
matchingSymbols, _ := identifyOverlappingSymbols(ctx, client, loc)
168+
if len(matchingSymbols) == 0 {
153169
// If container not found, just use the location's line
154170
refLine := int(loc.Range.Start.Line)
155171
linesToShow[refLine] = true
@@ -163,9 +179,11 @@ func GetLineRangesToDisplay(ctx context.Context, client *lsp.Client, locations [
163179
continue
164180
}
165181

182+
containerRange := matchingSymbols[0].Range
183+
166184
// Add container start and end lines
167-
containerStart := int(containerLoc.Range.Start.Line)
168-
containerEnd := int(containerLoc.Range.End.Line)
185+
containerStart := int(containerRange.Start.Line)
186+
containerEnd := int(containerRange.End.Line)
169187
linesToShow[containerStart] = true
170188
// linesToShow[containerEnd] = true
171189

0 commit comments

Comments
 (0)