Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/mcp_scan/MCPScanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from httpx import HTTPStatusError

from mcp_scan.direct_scanner import direct_scan, is_direct_scan
from mcp_scan.mcp_client import check_server_with_timeout, scan_mcp_config_file
from mcp_scan.mcp_client import check_server, scan_mcp_config_file
from mcp_scan.models import (
Issue,
RemoteServer,
Expand Down Expand Up @@ -207,9 +207,7 @@ async def scan_server(self, server: ServerScanResult) -> ServerScanResult:
logger.info("Scanning server: %s", server.name)
result = server.clone()
try:
result.signature = await check_server_with_timeout(
server.server, self.server_timeout, self.suppress_mcpserver_io
)
result.signature = await check_server(server.server, self.server_timeout, self.suppress_mcpserver_io)
logger.debug(
"Server %s has %d prompts, %d resources, %d resouce templates, %d tools",
server.name,
Expand Down
76 changes: 53 additions & 23 deletions src/mcp_scan/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import os
import shutil
import subprocess
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncContextManager # noqa: UP035
from typing import AsyncContextManager, Literal # noqa: UP035
from urllib.parse import urlparse

import pyjson5
from mcp import ClientSession, StdioServerParameters
Expand Down Expand Up @@ -90,7 +92,7 @@ def get_client(
raise ValueError(f"Invalid server config: {server_config}")


async def check_server(
async def _check_server_pass(
server_config: StdioServer | RemoteServer | StaticToolsServer,
timeout: int,
suppress_mcpserver_io: bool,
Expand Down Expand Up @@ -166,37 +168,65 @@ async def _check_server(verbose: bool) -> ServerSignature:
return await _check_server(verbose=not suppress_mcpserver_io)


async def check_server_with_timeout(
async def check_server(
server_config: StdioServer | RemoteServer | StaticToolsServer,
timeout: int,
suppress_mcpserver_io: bool,
) -> ServerSignature:
logger.debug("Checking server with timeout: %s seconds", timeout)
try_sse = False

if isinstance(server_config, RemoteServer) and server_config.type is None:
server_config.type = "http"
logger.debug("Remote server with no type, trying http")
try_sse = True

try:
result = await asyncio.wait_for(check_server(server_config, timeout, suppress_mcpserver_io), timeout)
if not isinstance(server_config, RemoteServer):
result = await asyncio.wait_for(_check_server_pass(server_config, timeout, suppress_mcpserver_io), timeout)
logger.debug("Server check completed within timeout")
return result
except asyncio.TimeoutError:
if not try_sse:
raise
else:
logger.debug(f"Remote server with url: {server_config.url}, type: {server_config.type or 'none'}")
strategy: list[tuple[Literal["sse", "http"], str]] = []
url_path = urlparse(server_config.url).path
has_sse_in_url = url_path.endswith("/sse")
if has_sse_in_url:
url_with_sse = server_config.url
url_without_sse = server_config.url.replace("/sse", "")
else:
logger.debug("Scan with HTTP failed, retrying with SSE")
url_with_sse = server_config.url + "/sse"
url_without_sse = server_config.url

server_config.type = "sse"
logger.debug("Remote server with no type, trying sse")
try:
result = await asyncio.wait_for(check_server(server_config, timeout, suppress_mcpserver_io), timeout)
logger.debug("Server check completed within timeout")
return result
except asyncio.TimeoutError:
raise
if server_config.type == "http" or server_config.type is None:
strategy.append(("http", url_without_sse))
strategy.append(("http", url_with_sse))
strategy.append(("sse", url_with_sse))
strategy.append(("sse", url_without_sse))
else:
strategy.append(("sse", url_with_sse))
strategy.append(("sse", url_without_sse))
strategy.append(("http", url_without_sse))
strategy.append(("http", url_with_sse))

exceptions: list[Exception] = []
for protocol, url in strategy:
try:
server_config.type = protocol
server_config.url = url
logger.debug(f"Trying {protocol} with url: {url}")
result = await asyncio.wait_for(
_check_server_pass(server_config, timeout, suppress_mcpserver_io), timeout
)
logger.debug("Server check completed within timeout")
return result
except asyncio.TimeoutError as e:
logger.debug("Server check timed out")
exceptions.append(e)
continue
except Exception as e:
logger.debug("Server check failed")
exceptions.append(e)
continue

# if python 3.11 or higher, use ExceptionGroup
if sys.version_info >= (3, 11):
raise ExceptionGroup("Could not connect to remote server", exceptions) # noqa: F821
else:
raise Exception("Could not connect to remote server")


async def scan_mcp_config_file(path: str) -> MCPConfig:
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/test_config_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from pytest_lazy_fixtures import lf

from mcp_scan.mcp_client import check_server, check_server_with_timeout, scan_mcp_config_file
from mcp_scan.mcp_client import check_server, scan_mcp_config_file
from mcp_scan.models import StdioServer, UnknownMCPConfig


Expand Down Expand Up @@ -111,7 +111,7 @@ async def test_math_server():
path = "tests/mcp_servers/configs_files/math_config.json"
servers = (await scan_mcp_config_file(path)).get_servers()
for name, server in servers.items():
signature = await check_server_with_timeout(server, 5, False)
signature = await check_server(server, 5, False)
if name == "Math":
assert len(signature.prompts) == 1
assert len(signature.resources) == 0
Expand All @@ -129,7 +129,7 @@ async def test_all_server():
path = "tests/mcp_servers/configs_files/all_config.json"
servers = (await scan_mcp_config_file(path)).get_servers()
for name, server in servers.items():
signature = await check_server_with_timeout(server, 5, False)
signature = await check_server(server, 5, False)
if name == "Math":
assert len(signature.prompts) == 1
assert len(signature.resources) == 0
Expand All @@ -152,7 +152,7 @@ async def test_weather_server():
path = "tests/mcp_servers/configs_files/weather_config.json"
servers = (await scan_mcp_config_file(path)).get_servers()
for name, server in servers.items():
signature = await check_server_with_timeout(server, 5, False)
signature = await check_server(server, 5, False)
if name == "Weather":
assert {t.name for t in signature.tools} == {"weather"}
assert {p.name for p in signature.prompts} == {"good_morning"}
Expand Down
14 changes: 6 additions & 8 deletions tests/unit/test_control_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ async def test_get_servers_from_path_sets_parse_error_and_uploads_payload():
@pytest.mark.asyncio
async def test_scan_server_sets_http_status_error_and_uploads_payload():
"""
Patch MCPScanner to return a server, then make check_server_with_timeout raise HTTPStatusError and
Patch MCPScanner to return a server, then make check_server raise HTTPStatusError and
ensure the server-level error message "server returned HTTP status code" is included on upload.
"""

Expand All @@ -296,7 +296,7 @@ def get_servers(self):
patch.object(sys.modules["mcp_scan.MCPScanner"], "scan_mcp_config_file", return_value=DummyCfg()),
patch.object(
sys.modules["mcp_scan.MCPScanner"],
"check_server_with_timeout",
"check_server",
side_effect=httpx.HTTPStatusError("bad", request=None, response=None),
),
patch("mcp_scan.upload.get_user_info") as mock_get_user_info,
Expand Down Expand Up @@ -329,7 +329,7 @@ def get_servers(self):
@pytest.mark.asyncio
async def test_scan_server_sets_could_not_start_error_and_uploads_payload():
"""
Patch MCPScanner to return a server, then make check_server_with_timeout raise a generic Exception and
Patch MCPScanner to return a server, then make check_server raise a generic Exception and
ensure the server-level error message "could not start server" is included on upload.
"""

Expand All @@ -339,9 +339,7 @@ def get_servers(self):

with (
patch.object(sys.modules["mcp_scan.MCPScanner"], "scan_mcp_config_file", return_value=DummyCfg()),
patch.object(
sys.modules["mcp_scan.MCPScanner"], "check_server_with_timeout", side_effect=Exception("spawn failed")
),
patch.object(sys.modules["mcp_scan.MCPScanner"], "check_server", side_effect=Exception("spawn failed")),
patch("mcp_scan.upload.get_user_info") as mock_get_user_info,
):
mock_get_user_info.return_value = ScanUserInfo()
Expand Down Expand Up @@ -491,7 +489,7 @@ def get_servers(self):

with (
patch.object(sys.modules["mcp_scan.MCPScanner"], "scan_mcp_config_file", return_value=DummyCfg()),
patch.object(sys.modules["mcp_scan.MCPScanner"], "check_server_with_timeout", return_value=None),
patch.object(sys.modules["mcp_scan.MCPScanner"], "check_server", return_value=None),
):
async with MCPScanner(files=["/dummy/path"]) as scanner:
result = await scanner.scan_path("/dummy/path", inspect_only=True)
Expand Down Expand Up @@ -527,7 +525,7 @@ def get_servers(self):

with (
patch.object(sys.modules["mcp_scan.MCPScanner"], "scan_mcp_config_file", return_value=DummyCfg()),
patch.object(sys.modules["mcp_scan.MCPScanner"], "check_server_with_timeout", return_value=None),
patch.object(sys.modules["mcp_scan.MCPScanner"], "check_server", return_value=None),
):
async with MCPScanner(files=["/dummy/path"]) as scanner:
result = await scanner.scan_path("/dummy/path", inspect_only=True)
Expand Down
28 changes: 23 additions & 5 deletions tests/unit/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)
from pytest_lazy_fixtures import lf

from mcp_scan.mcp_client import check_server, check_server_with_timeout, scan_mcp_config_file
from mcp_scan.mcp_client import _check_server_pass, check_server, scan_mcp_config_file
from mcp_scan.models import StdioServer


Expand Down Expand Up @@ -91,7 +91,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
# Test function with mocks
with patch("mcp_scan.mcp_client.ClientSession", MockClientSession):
server = StdioServer(command="mcp", args=["run", "some_file.py"])
signature = await check_server(server, 2, True)
signature = await _check_server_pass(server, 2, True)

# Verify the results
assert len(signature.prompts) == 2
Expand All @@ -104,7 +104,7 @@ async def test_math_server():
path = "tests/mcp_servers/configs_files/math_config.json"
servers = (await scan_mcp_config_file(path)).get_servers()
for name, server in servers.items():
signature = await check_server_with_timeout(server, 5, False)
signature = await check_server(server, 5, False)
if name == "Math":
assert len(signature.prompts) == 1
assert len(signature.resources) == 0
Expand All @@ -122,7 +122,7 @@ async def test_all_server():
path = "tests/mcp_servers/configs_files/all_config.json"
servers = (await scan_mcp_config_file(path)).get_servers()
for name, server in servers.items():
signature = await check_server_with_timeout(server, 5, False)
signature = await check_server(server, 5, False)
if name == "Math":
assert len(signature.prompts) == 1
assert len(signature.resources) == 0
Expand All @@ -145,9 +145,27 @@ async def test_weather_server():
path = "tests/mcp_servers/configs_files/weather_config.json"
servers = (await scan_mcp_config_file(path)).get_servers()
for name, server in servers.items():
signature = await check_server_with_timeout(server, 5, False)
signature = await check_server(server, 5, False)
if name == "Weather":
assert {t.name for t in signature.tools} == {"weather"}
assert {p.name for p in signature.prompts} == {"good_morning"}
assert {r.name for r in signature.resources} == {"weathers"}
assert {rt.name for rt in signature.resource_templates} == {"weather_description"}


@pytest.fixture
def remote_mcp_server_just_url():
return """
{
"mcpServers": {
"remote": {
"url": "http://localhost:8000"
}
}
}
"""


@pytest.mark.asyncio
async def test_parse_server():
pass