Skip to content

Commit 4072af9

Browse files
authored
fix: add ctx to annotations (#200)
1 parent 420662a commit 4072af9

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

src/deepset_mcp/mcp/tool_factory.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,15 @@ async def client_wrapper_with_context(*args: Any, **kwargs: Any) -> Any:
193193
ctx_param = inspect.Parameter(name="ctx", kind=inspect.Parameter.KEYWORD_ONLY, annotation=Context)
194194
new_params.append(ctx_param)
195195
client_wrapper_with_context.__signature__ = original_sig.replace(parameters=new_params) # type: ignore
196+
197+
# Remove client from docstring
196198
client_wrapper_with_context.__doc__ = remove_params_from_docstring(base_func.__doc__, {"client"})
197199

200+
# Remove client from annotations and add ctx
201+
new_annotations = {k: v for k, v in base_func.__annotations__.items() if k != "client"}
202+
new_annotations["ctx"] = Context
203+
client_wrapper_with_context.__annotations__ = new_annotations
204+
198205
return client_wrapper_with_context
199206
else:
200207

@@ -214,6 +221,10 @@ async def client_wrapper_without_context(*args: Any, **kwargs: Any) -> Any:
214221
# Remove client from docstring
215222
client_wrapper_without_context.__doc__ = remove_params_from_docstring(base_func.__doc__, {"client"})
216223

224+
# Remove client from annotations
225+
new_annotations = {k: v for k, v in base_func.__annotations__.items() if k != "client"}
226+
client_wrapper_without_context.__annotations__ = new_annotations
227+
217228
return client_wrapper_without_context
218229

219230

test/unit/test_tool_factory.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from unittest.mock import MagicMock, patch
1010

1111
import pytest
12+
from mcp.server.fastmcp import Context
1213

1314
from deepset_mcp.api.protocols import AsyncClientProtocol
1415
from deepset_mcp.mcp.tool_factory import (
@@ -265,6 +266,11 @@ async def sample_func(client: AsyncClientProtocol, a: int) -> str:
265266
assert ":param client:" not in result.__doc__
266267
assert ":param a:" in result.__doc__
267268

269+
# Check annotations were updated
270+
assert "client" not in result.__annotations__
271+
assert "ctx" in result.__annotations__
272+
assert result.__annotations__["ctx"] == Context
273+
268274
def test_client_signature_updated_without_context(self) -> None:
269275
"""Test that client parameter is removed without ctx."""
270276

@@ -289,6 +295,10 @@ async def sample_func(client: AsyncClientProtocol, a: int) -> str:
289295
assert result.__doc__ is not None
290296
assert ":param client:" not in result.__doc__
291297

298+
# Check annotations were updated
299+
assert "client" not in result.__annotations__
300+
assert "ctx" not in result.__annotations__
301+
292302
@pytest.mark.asyncio
293303
async def test_client_context_missing_raises_error(self) -> None:
294304
"""Test that missing context raises ValueError."""

0 commit comments

Comments
 (0)