Skip to content

Commit b243e2b

Browse files
dhsifssxbingW
authored andcommitted
refactor: tools
1 parent 34fc8fa commit b243e2b

File tree

9 files changed

+254
-149
lines changed

9 files changed

+254
-149
lines changed

mcp_server/README.md

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,32 +21,27 @@ This tool used to say hello to someone
2121

2222
```python
2323
from pydantic import BaseModel, Field
24-
from tools import register_tool, Tool
24+
from tools import Tool, ABCTool, tools
2525

26+
# register to global tools
27+
@tools.register
2628
# Hello describe function paramters
27-
class Hello(BaseModel):
29+
class Hello(BaseModel, ABCTool):
30+
# tools paramters
2831
name: str = Field(description="username to say hello")
2932

30-
# hello is tool logic
31-
async def hello(arguments: dict) -> str:
32-
"""
33-
Say hello to someone
34-
"""
35-
return f"Hello {arguments['name']}"
36-
37-
# register tool to global variable
38-
register_tool(
39-
Tool(
40-
name="hello",
41-
description="say hello to someone",
42-
inputSchema=Hello.model_json_schema()
43-
),
44-
hello
45-
)
46-
```
33+
# run is tool logic, must use classmethod
34+
@classmethod
35+
async def run(arguments: dict) -> str:
36+
return f"Hello {arguments['name']}"
4737

48-
2. import this tool in `tools/__init__.py`
38+
# tool description, must use classmethod
39+
@classmethod
40+
def tool(self) -> Tool:
41+
return Tool(
42+
name="hello",
43+
description="say hello to someone",
44+
inputSchema=self.model_json_schema()
45+
)
4946

50-
```python
51-
from . import hello
52-
```
47+
```

mcp_server/config/config.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,21 @@
11
import os
2+
import logging
23

34
class Config:
45
SAFELINE_ADDRESS: str
56
SAFELINE_API_TOKEN: str
67
SECRET: str
78
LISTEN_PORT: int
89
LISTEN_ADDRESS: str
10+
DEBUG: bool
911

1012
def __init__(self):
13+
set_log_level()
14+
15+
if os.getenv("MCP_SERVER_DEBUG"):
16+
self.DEBUG = True
17+
else:
18+
self.DEBUG = False
1119
self.SAFELINE_ADDRESS = os.getenv("SAFELINE_ADDRESS")
1220
if self.SAFELINE_ADDRESS:
1321
self.SAFELINE_ADDRESS = self.SAFELINE_ADDRESS.removesuffix("/")
@@ -26,4 +34,23 @@ def __init__(self):
2634

2735
@staticmethod
2836
def from_env():
29-
return Config()
37+
return Config()
38+
39+
40+
def set_log_level():
41+
level = logging.WARN
42+
log_level = os.getenv("MCO_SERVER_LOG_LEVEL")
43+
if log_level:
44+
match log_level.lower():
45+
case "debug":
46+
level = logging.DEBUG
47+
case "info":
48+
level = logging.INFO
49+
case "warn":
50+
level = logging.WARN
51+
case "error":
52+
level = logging.ERROR
53+
case "critical":
54+
level = logging.CRITICAL
55+
56+
logging.basicConfig(level=level)

mcp_server/middleware.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from starlette.requests import HTTPConnection
2+
from starlette.responses import PlainTextResponse
3+
from starlette.types import ASGIApp, Receive, Scope, Send
4+
from config import GLOBAL_CONFIG
5+
6+
7+
class AuthenticationMiddleware:
8+
def __init__(
9+
self,
10+
app: ASGIApp,
11+
) -> None:
12+
self.app = app
13+
14+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
15+
conn = HTTPConnection(scope)
16+
if GLOBAL_CONFIG.SECRET and GLOBAL_CONFIG.SECRET != "" and conn.headers.get("Secret") != GLOBAL_CONFIG.SECRET:
17+
response = PlainTextResponse("Unauthorized", status_code=401)
18+
await response(scope, receive, send)
19+
return
20+
21+
await self.app(scope, receive, send)

mcp_server/server.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,18 @@
66
from starlette.requests import Request
77
import uvicorn
88
from starlette.responses import PlainTextResponse
9-
import tools
9+
from tools import tools
1010
from config import GLOBAL_CONFIG
11+
from middleware import AuthenticationMiddleware
12+
from starlette.middleware import Middleware
1113

1214
# Create an MCP server
1315
mcp_server = Server("SafeLine WAF mcp server")
1416
sse = mcp.server.sse.SseServerTransport("/messages/")
1517

1618
@mcp_server.list_tools()
1719
async def list_tools() -> list[Tool]:
18-
return tools.ALL_TOOLS
20+
return tools.all()
1921

2022
@mcp_server.call_tool()
2123
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
@@ -29,9 +31,6 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
2931

3032

3133
async def handle_sse(request: Request) -> None:
32-
if GLOBAL_CONFIG.SECRET and GLOBAL_CONFIG.SECRET != "" and request.headers.get("Secret") != GLOBAL_CONFIG.SECRET:
33-
return PlainTextResponse("Unauthorized", status_code=401)
34-
3534
async with sse.connect_sse(
3635
request.scope, request.receive, request._send
3736
) as [read_stream, write_stream]:
@@ -40,9 +39,9 @@ async def handle_sse(request: Request) -> None:
4039
)
4140

4241
def main():
43-
starlette_app = Starlette(debug=True,routes=[
42+
starlette_app = Starlette(debug=GLOBAL_CONFIG.DEBUG,routes=[
4443
Route("/sse", endpoint=handle_sse),
45-
Mount("/messages/", app=sse.handle_post_message),
44+
Mount("/messages/", app=sse.handle_post_message, middleware=[Middleware(AuthenticationMiddleware)]),
4645
])
4746

4847
uvicorn.run(starlette_app, host=GLOBAL_CONFIG.LISTEN_ADDRESS, port=GLOBAL_CONFIG.LISTEN_PORT)

mcp_server/tools/__init__.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,49 @@
11
from mcp.types import Tool
2-
from typing import Callable
2+
from abc import ABC, abstractmethod
3+
import os
4+
import importlib
5+
import logging
6+
class ABCTool(ABC):
7+
@classmethod
8+
@abstractmethod
9+
async def run(self, arguments:dict) -> str:
10+
pass
311

4-
ALL_TOOLS = []
5-
TOOL_FUNC_MAP = {}
12+
@classmethod
13+
@abstractmethod
14+
def tool(self) -> Tool:
15+
pass
616

7-
def register_tool(tool: Tool, func: Callable):
8-
ALL_TOOLS.append(tool)
9-
TOOL_FUNC_MAP[tool.name] = func
17+
class ToolRegister:
18+
_dict: dict[str, ABCTool] = {}
19+
20+
@classmethod
21+
def register(self, tool: ABCTool) -> ABCTool:
22+
tool_name = tool.tool().name
23+
logging.info(f"Registering tool: {tool_name}")
24+
if tool_name in self._dict:
25+
raise ValueError(f"Tool {tool_name} already registered")
26+
27+
self._dict[tool_name] = tool
28+
return tool
1029

11-
async def run(name:str, arguments:dict) -> str:
12-
if name not in TOOL_FUNC_MAP:
13-
return f"Unknown tool: {name}"
30+
def all(self) -> list[Tool]:
31+
return [tool.tool() for tool in self._dict.values()]
1432

15-
return await TOOL_FUNC_MAP[name](arguments)
33+
async def run(self, name: str, arguments: dict) -> str:
34+
if name not in self._dict:
35+
raise ValueError(f"Unknown tool: {name}")
36+
37+
return await self._dict[name].run(arguments)
38+
39+
def import_all_tools():
40+
for module in os.listdir(os.path.dirname(__file__)):
41+
if module == "__init__.py" or len(module) < 3 or not module.endswith(".py"):
42+
continue
43+
44+
module_name = module[:-3]
45+
importlib.import_module(f".{module_name}", package=__name__)
46+
47+
tools = ToolRegister()
1648

17-
from . import create_black_custom_rule, create_http_application
49+
import_all_tools()

mcp_server/tools/create_black_custom_rule.py

Lines changed: 0 additions & 60 deletions
This file was deleted.
Lines changed: 37 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,44 @@
11
from pydantic import BaseModel, Field
22
from utils.request import post_slce_api
3-
from tools import register_tool, Tool
3+
from tools import Tool, ABCTool, tools
44
from urllib.parse import urlparse
5-
6-
class CreateHttpApplication(BaseModel):
5+
@tools.register
6+
class CreateHttpApplication(BaseModel, ABCTool):
77
domain: str = Field(default="",description="application domain, if empty, match all domain")
88
port: int = Field(description="application listen port, must between 1 and 65535")
99
upstream: str = Field(description="application proxy address, must be a valid url")
1010

11-
12-
async def create_http_application(arguments:dict) -> str:
13-
"""
14-
Create a new HTTP application.
15-
16-
Args:
17-
domain: application domain
18-
port: application listen port
19-
upstream: application proxy address
20-
"""
21-
22-
port = arguments["port"]
23-
upstream = arguments["upstream"]
24-
domain = arguments["domain"]
25-
26-
if port is None or port < 1 or port > 65535:
27-
return "invalid port"
28-
29-
parsed_upstream = urlparse(upstream)
30-
if parsed_upstream.scheme != "https" and parsed_upstream.scheme != "http":
31-
return "invalid upstream scheme"
32-
if parsed_upstream.hostname == "":
33-
return "invalid upstream host"
34-
35-
return await post_slce_api("/api/open/site",{
36-
"server_names": [domain],
37-
"ports": [ str(port) ],
38-
"upstreams": [ upstream ],
39-
"type": 0,
40-
"static_default": 1,
41-
"health_check": True,
42-
"load_balance": {
43-
"balance_type": 1
44-
}
45-
})
46-
47-
register_tool(
48-
Tool(
49-
name="create_http_application",
50-
description="在雷池 WAF 上创建一个站点应用",
51-
inputSchema=CreateHttpApplication.model_json_schema()
52-
),
53-
create_http_application
54-
)
11+
@classmethod
12+
async def run(self, arguments:dict) -> str:
13+
port = arguments["port"]
14+
upstream = arguments["upstream"]
15+
domain = arguments["domain"]
16+
17+
if port is None or port < 1 or port > 65535:
18+
return "invalid port"
19+
20+
parsed_upstream = urlparse(upstream)
21+
if parsed_upstream.scheme != "https" and parsed_upstream.scheme != "http":
22+
return "invalid upstream scheme"
23+
if parsed_upstream.hostname == "":
24+
return "invalid upstream host"
25+
26+
return await post_slce_api("/api/open/site",{
27+
"server_names": [domain],
28+
"ports": [ str(port) ],
29+
"upstreams": [ upstream ],
30+
"type": 0,
31+
"static_default": 1,
32+
"health_check": True,
33+
"load_balance": {
34+
"balance_type": 1
35+
}
36+
})
37+
38+
@classmethod
39+
def tool(self) -> Tool:
40+
return Tool(
41+
name="create_http_application",
42+
description="在雷池 WAF 上创建一个站点应用",
43+
inputSchema=self.model_json_schema()
44+
)

0 commit comments

Comments
 (0)