diff --git a/src/mcp_scan/MCPScanner.py b/src/mcp_scan/MCPScanner.py index 9175245..6d33313 100644 --- a/src/mcp_scan/MCPScanner.py +++ b/src/mcp_scan/MCPScanner.py @@ -12,7 +12,7 @@ from httpx import HTTPStatusError from mcp_scan.direct_scanner import direct_scan, is_direct_scan -from mcp_scan.mcp_client import check_server_with_timeout, scan_mcp_config_file +from mcp_scan.mcp_client import check_server, scan_mcp_config_file from mcp_scan.models import ( Issue, RemoteServer, @@ -207,9 +207,7 @@ async def scan_server(self, server: ServerScanResult) -> ServerScanResult: logger.info("Scanning server: %s", server.name) result = server.clone() try: - result.signature = await check_server_with_timeout( - server.server, self.server_timeout, self.suppress_mcpserver_io - ) + result.signature = await check_server(server.server, self.server_timeout, self.suppress_mcpserver_io) logger.debug( "Server %s has %d prompts, %d resources, %d resouce templates, %d tools", server.name, diff --git a/src/mcp_scan/mcp_client.py b/src/mcp_scan/mcp_client.py index f05e55d..65a67a7 100644 --- a/src/mcp_scan/mcp_client.py +++ b/src/mcp_scan/mcp_client.py @@ -3,9 +3,11 @@ import os import shutil import subprocess +import sys from contextlib import asynccontextmanager from pathlib import Path -from typing import AsyncContextManager # noqa: UP035 +from typing import AsyncContextManager, Literal # noqa: UP035 +from urllib.parse import urlparse import pyjson5 from mcp import ClientSession, StdioServerParameters @@ -90,7 +92,7 @@ def get_client( raise ValueError(f"Invalid server config: {server_config}") -async def check_server( +async def _check_server_pass( server_config: StdioServer | RemoteServer | StaticToolsServer, timeout: int, suppress_mcpserver_io: bool, @@ -166,37 +168,65 @@ async def _check_server(verbose: bool) -> ServerSignature: return await _check_server(verbose=not suppress_mcpserver_io) -async def check_server_with_timeout( +async def check_server( server_config: StdioServer | RemoteServer | StaticToolsServer, timeout: int, suppress_mcpserver_io: bool, ) -> ServerSignature: logger.debug("Checking server with timeout: %s seconds", timeout) - try_sse = False - if isinstance(server_config, RemoteServer) and server_config.type is None: - server_config.type = "http" - logger.debug("Remote server with no type, trying http") - try_sse = True - - try: - result = await asyncio.wait_for(check_server(server_config, timeout, suppress_mcpserver_io), timeout) + if not isinstance(server_config, RemoteServer): + result = await asyncio.wait_for(_check_server_pass(server_config, timeout, suppress_mcpserver_io), timeout) logger.debug("Server check completed within timeout") return result - except asyncio.TimeoutError: - if not try_sse: - raise + else: + logger.debug(f"Remote server with url: {server_config.url}, type: {server_config.type or 'none'}") + strategy: list[tuple[Literal["sse", "http"], str]] = [] + url_path = urlparse(server_config.url).path + has_sse_in_url = url_path.endswith("/sse") + if has_sse_in_url: + url_with_sse = server_config.url + url_without_sse = server_config.url.replace("/sse", "") else: - logger.debug("Scan with HTTP failed, retrying with SSE") + url_with_sse = server_config.url + "/sse" + url_without_sse = server_config.url - server_config.type = "sse" - logger.debug("Remote server with no type, trying sse") - try: - result = await asyncio.wait_for(check_server(server_config, timeout, suppress_mcpserver_io), timeout) - logger.debug("Server check completed within timeout") - return result - except asyncio.TimeoutError: - raise + if server_config.type == "http" or server_config.type is None: + strategy.append(("http", url_without_sse)) + strategy.append(("http", url_with_sse)) + strategy.append(("sse", url_with_sse)) + strategy.append(("sse", url_without_sse)) + else: + strategy.append(("sse", url_with_sse)) + strategy.append(("sse", url_without_sse)) + strategy.append(("http", url_without_sse)) + strategy.append(("http", url_with_sse)) + + exceptions: list[Exception] = [] + for protocol, url in strategy: + try: + server_config.type = protocol + server_config.url = url + logger.debug(f"Trying {protocol} with url: {url}") + result = await asyncio.wait_for( + _check_server_pass(server_config, timeout, suppress_mcpserver_io), timeout + ) + logger.debug("Server check completed within timeout") + return result + except asyncio.TimeoutError as e: + logger.debug("Server check timed out") + exceptions.append(e) + continue + except Exception as e: + logger.debug("Server check failed") + exceptions.append(e) + continue + + # if python 3.11 or higher, use ExceptionGroup + if sys.version_info >= (3, 11): + raise ExceptionGroup("Could not connect to remote server", exceptions) # noqa: F821 + else: + raise Exception("Could not connect to remote server.") from exceptions[0] async def scan_mcp_config_file(path: str) -> MCPConfig: diff --git a/tests/unit/test_config_scan.py b/tests/unit/test_config_scan.py index 97e47b4..708bbb4 100644 --- a/tests/unit/test_config_scan.py +++ b/tests/unit/test_config_scan.py @@ -16,7 +16,7 @@ ) from pytest_lazy_fixtures import lf -from mcp_scan.mcp_client import check_server, check_server_with_timeout, scan_mcp_config_file +from mcp_scan.mcp_client import check_server, scan_mcp_config_file from mcp_scan.models import StdioServer, UnknownMCPConfig @@ -111,7 +111,7 @@ async def test_math_server(): path = "tests/mcp_servers/configs_files/math_config.json" servers = (await scan_mcp_config_file(path)).get_servers() for name, server in servers.items(): - signature = await check_server_with_timeout(server, 5, False) + signature = await check_server(server, 5, False) if name == "Math": assert len(signature.prompts) == 1 assert len(signature.resources) == 0 @@ -129,7 +129,7 @@ async def test_all_server(): path = "tests/mcp_servers/configs_files/all_config.json" servers = (await scan_mcp_config_file(path)).get_servers() for name, server in servers.items(): - signature = await check_server_with_timeout(server, 5, False) + signature = await check_server(server, 5, False) if name == "Math": assert len(signature.prompts) == 1 assert len(signature.resources) == 0 @@ -152,7 +152,7 @@ async def test_weather_server(): path = "tests/mcp_servers/configs_files/weather_config.json" servers = (await scan_mcp_config_file(path)).get_servers() for name, server in servers.items(): - signature = await check_server_with_timeout(server, 5, False) + signature = await check_server(server, 5, False) if name == "Weather": assert {t.name for t in signature.tools} == {"weather"} assert {p.name for p in signature.prompts} == {"good_morning"} diff --git a/tests/unit/test_control_server.py b/tests/unit/test_control_server.py index 9bef229..b003400 100644 --- a/tests/unit/test_control_server.py +++ b/tests/unit/test_control_server.py @@ -284,7 +284,7 @@ async def test_get_servers_from_path_sets_parse_error_and_uploads_payload(): @pytest.mark.asyncio async def test_scan_server_sets_http_status_error_and_uploads_payload(): """ - Patch MCPScanner to return a server, then make check_server_with_timeout raise HTTPStatusError and + Patch MCPScanner to return a server, then make check_server raise HTTPStatusError and ensure the server-level error message "server returned HTTP status code" is included on upload. """ @@ -296,7 +296,7 @@ def get_servers(self): patch.object(sys.modules["mcp_scan.MCPScanner"], "scan_mcp_config_file", return_value=DummyCfg()), patch.object( sys.modules["mcp_scan.MCPScanner"], - "check_server_with_timeout", + "check_server", side_effect=httpx.HTTPStatusError("bad", request=None, response=None), ), patch("mcp_scan.upload.get_user_info") as mock_get_user_info, @@ -329,7 +329,7 @@ def get_servers(self): @pytest.mark.asyncio async def test_scan_server_sets_could_not_start_error_and_uploads_payload(): """ - Patch MCPScanner to return a server, then make check_server_with_timeout raise a generic Exception and + Patch MCPScanner to return a server, then make check_server raise a generic Exception and ensure the server-level error message "could not start server" is included on upload. """ @@ -339,9 +339,7 @@ def get_servers(self): with ( patch.object(sys.modules["mcp_scan.MCPScanner"], "scan_mcp_config_file", return_value=DummyCfg()), - patch.object( - sys.modules["mcp_scan.MCPScanner"], "check_server_with_timeout", side_effect=Exception("spawn failed") - ), + patch.object(sys.modules["mcp_scan.MCPScanner"], "check_server", side_effect=Exception("spawn failed")), patch("mcp_scan.upload.get_user_info") as mock_get_user_info, ): mock_get_user_info.return_value = ScanUserInfo() @@ -491,7 +489,7 @@ def get_servers(self): with ( patch.object(sys.modules["mcp_scan.MCPScanner"], "scan_mcp_config_file", return_value=DummyCfg()), - patch.object(sys.modules["mcp_scan.MCPScanner"], "check_server_with_timeout", return_value=None), + patch.object(sys.modules["mcp_scan.MCPScanner"], "check_server", return_value=None), ): async with MCPScanner(files=["/dummy/path"]) as scanner: result = await scanner.scan_path("/dummy/path", inspect_only=True) @@ -527,7 +525,7 @@ def get_servers(self): with ( patch.object(sys.modules["mcp_scan.MCPScanner"], "scan_mcp_config_file", return_value=DummyCfg()), - patch.object(sys.modules["mcp_scan.MCPScanner"], "check_server_with_timeout", return_value=None), + patch.object(sys.modules["mcp_scan.MCPScanner"], "check_server", return_value=None), ): async with MCPScanner(files=["/dummy/path"]) as scanner: result = await scanner.scan_path("/dummy/path", inspect_only=True) diff --git a/tests/unit/test_mcp_client.py b/tests/unit/test_mcp_client.py index 3eaeb85..db23eaa 100644 --- a/tests/unit/test_mcp_client.py +++ b/tests/unit/test_mcp_client.py @@ -16,7 +16,7 @@ ) from pytest_lazy_fixtures import lf -from mcp_scan.mcp_client import check_server, check_server_with_timeout, scan_mcp_config_file +from mcp_scan.mcp_client import _check_server_pass, check_server, scan_mcp_config_file from mcp_scan.models import StdioServer @@ -91,7 +91,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): # Test function with mocks with patch("mcp_scan.mcp_client.ClientSession", MockClientSession): server = StdioServer(command="mcp", args=["run", "some_file.py"]) - signature = await check_server(server, 2, True) + signature = await _check_server_pass(server, 2, True) # Verify the results assert len(signature.prompts) == 2 @@ -104,7 +104,7 @@ async def test_math_server(): path = "tests/mcp_servers/configs_files/math_config.json" servers = (await scan_mcp_config_file(path)).get_servers() for name, server in servers.items(): - signature = await check_server_with_timeout(server, 5, False) + signature = await check_server(server, 5, False) if name == "Math": assert len(signature.prompts) == 1 assert len(signature.resources) == 0 @@ -122,7 +122,7 @@ async def test_all_server(): path = "tests/mcp_servers/configs_files/all_config.json" servers = (await scan_mcp_config_file(path)).get_servers() for name, server in servers.items(): - signature = await check_server_with_timeout(server, 5, False) + signature = await check_server(server, 5, False) if name == "Math": assert len(signature.prompts) == 1 assert len(signature.resources) == 0 @@ -145,9 +145,27 @@ async def test_weather_server(): path = "tests/mcp_servers/configs_files/weather_config.json" servers = (await scan_mcp_config_file(path)).get_servers() for name, server in servers.items(): - signature = await check_server_with_timeout(server, 5, False) + signature = await check_server(server, 5, False) if name == "Weather": assert {t.name for t in signature.tools} == {"weather"} assert {p.name for p in signature.prompts} == {"good_morning"} assert {r.name for r in signature.resources} == {"weathers"} assert {rt.name for rt in signature.resource_templates} == {"weather_description"} + + +@pytest.fixture +def remote_mcp_server_just_url(): + return """ + { + "mcpServers": { + "remote": { + "url": "http://localhost:8000" + } + } + } + """ + + +@pytest.mark.asyncio +async def test_parse_server(): + pass