From 21c244c3d90a077957f26ea226270fbab73f5512 Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 14 Oct 2025 13:39:09 +0200 Subject: [PATCH 1/2] Add is_failure boolean to ScanError. Update ScanPathResult so that servers can be None. --- src/mcp_scan/MCPScanner.py | 32 ++++++++++++++++++------------ src/mcp_scan/models.py | 10 +++++++--- src/mcp_scan/printer.py | 2 +- src/mcp_scan/well_known_clients.py | 3 ++- tests/unit/test_control_server.py | 11 +++++++++- 5 files changed, 39 insertions(+), 19 deletions(-) diff --git a/src/mcp_scan/MCPScanner.py b/src/mcp_scan/MCPScanner.py index e6a4822..5be9904 100644 --- a/src/mcp_scan/MCPScanner.py +++ b/src/mcp_scan/MCPScanner.py @@ -132,16 +132,19 @@ async def get_servers_from_path(self, path: str) -> ScanPathResult: except FileNotFoundError as e: error_msg = "file does not exist" logger.exception("%s: %s", error_msg, path) - result.error = ScanError(message=error_msg, exception=e) + # This is a non failing error, so we set is_failure to False. + result.error = ScanError(message=error_msg, exception=e, is_failure=False) except Exception as e: error_msg = "could not parse file" logger.exception("%s: %s", error_msg, path) - result.error = ScanError(message=error_msg, exception=e) + result.error = ScanError(message=error_msg, exception=e, is_failure=True) return result def check_server_changed(self, path_result: ScanPathResult) -> list[Issue]: logger.debug("Checking server changed: %s", path_result.path) issues: list[Issue] = [] + if path_result.servers is None: + return issues for server_idx, server in enumerate(path_result.servers): logger.debug( "Checking for changes in server %d/%d: %s", server_idx + 1, len(path_result.servers), server.name @@ -162,6 +165,8 @@ def check_server_changed(self, path_result: ScanPathResult) -> list[Issue]: def check_whitelist(self, path_result: ScanPathResult) -> list[Issue]: logger.debug("Checking whitelist for path: %s", path_result.path) issues: list[Issue] = [] + if path_result.servers is None: + return issues for server_idx, server in enumerate(path_result.servers): for entity_idx, entity in enumerate(server.entities): if self.storage_file.is_whitelisted(entity): @@ -195,11 +200,11 @@ async def scan_server(self, server: ServerScanResult) -> ServerScanResult: except HTTPStatusError as e: error_msg = "server returned HTTP status code" logger.exception("%s: %s", error_msg, server.name) - result.error = ScanError(message=error_msg, exception=e) + result.error = ScanError(message=error_msg, exception=e, is_failure=True) except Exception as e: error_msg = "could not start server" logger.exception("%s: %s", error_msg, server.name) - result.error = ScanError(message=error_msg, exception=e) + result.error = ScanError(message=error_msg, exception=e, is_failure=True) await self.emit("server_scanned", result) return result @@ -209,15 +214,16 @@ async def scan_path(self, path: str, inspect_only: bool = False) -> ScanPathResu logger.info("Scanning path: %s, inspect_only: %s", path, inspect_only) path_result = await self.get_servers_from_path(path) - for i, server in enumerate(path_result.servers): - if server.server.type == "stdio": - full_command = server.server.command + " " + " ".join(server.server.args or []) - # check if pattern is contained in full_command - if re.search(r"mcp[-_]scan.*mcp-server", full_command): - logger.info("Skipping scan of server %d/%d: %s", i + 1, len(path_result.servers), server.name) - continue - logger.debug("Scanning server %d/%d: %s", i + 1, len(path_result.servers), server.name) - path_result.servers[i] = await self.scan_server(server) + if path_result.servers is not None: + for i, server in enumerate(path_result.servers): + if server.server.type == "stdio": + full_command = server.server.command + " " + " ".join(server.server.args or []) + # check if pattern is contained in full_command + if re.search(r"mcp[-_]scan.*mcp-server", full_command): + logger.info("Skipping scan of server %d/%d: %s", i + 1, len(path_result.servers), server.name) + continue + logger.debug("Scanning server %d/%d: %s", i + 1, len(path_result.servers), server.name) + path_result.servers[i] = await self.scan_server(server) # add built-in tools if self.include_built_in: diff --git a/src/mcp_scan/models.py b/src/mcp_scan/models.py index f7bcf5c..f0b57c1 100644 --- a/src/mcp_scan/models.py +++ b/src/mcp_scan/models.py @@ -148,6 +148,7 @@ class ScanError(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) message: str | None = None exception: Exception | None = None + is_failure: bool = True @field_serializer("exception") def serialize_exception(self, exception: Exception | None, _info) -> str | None: @@ -165,6 +166,7 @@ def clone(self) -> "ScanError": return ScanError( message=self.message, exception=self.exception, + is_failure=self.is_failure, ) @@ -233,14 +235,16 @@ class ScanPathResult(BaseModel): model_config = ConfigDict() client: str | None = None path: str - servers: list[ServerScanResult] = Field(default_factory=list) + # servers is None if the MCP configuration file was missing or unparseable + # which prevented server discovery. + servers: list[ServerScanResult] | None = None issues: list[Issue] = Field(default_factory=list) labels: list[list[ScalarToolLabels]] = Field(default_factory=list) error: ScanError | None = None @property def entities(self) -> list[Entity]: - return list(chain.from_iterable(server.entities for server in self.servers)) + return list(chain.from_iterable(server.entities for server in self.servers)) if self.servers else [] def clone(self) -> "ScanPathResult": """ @@ -250,7 +254,7 @@ def clone(self) -> "ScanPathResult": output = ScanPathResult( path=self.path, client=self.client, - servers=[server.clone() for server in self.servers], + servers=[server.clone() for server in self.servers] if self.servers else None, issues=[issue.model_copy(deep=True) for issue in self.issues], labels=[[label.model_copy(deep=True) for label in labels] for labels in self.labels], error=self.error.clone() if self.error else None, diff --git a/src/mcp_scan/printer.py b/src/mcp_scan/printer.py index 77fc7c6..80fb04f 100644 --- a/src/mcp_scan/printer.py +++ b/src/mcp_scan/printer.py @@ -243,7 +243,7 @@ def print_scan_path_result( issues = [issue for issue in result.issues if issue.reference == (server_idx, entity_idx)] server_print.add(format_entity_line(entity, issues, inspect_mode)) - if len(result.servers) > 0: + if result.servers is not None and len(result.servers) > 0: rich.print(path_print_tree) # print global issues diff --git a/src/mcp_scan/well_known_clients.py b/src/mcp_scan/well_known_clients.py index 0d538f0..a987dd4 100644 --- a/src/mcp_scan/well_known_clients.py +++ b/src/mcp_scan/well_known_clients.py @@ -180,7 +180,8 @@ def get_builtin_tools(path_result: ScanPathResult) -> ScanPathResult: meta={}, ) ) - + if output.servers is None: + output.servers = [] output.servers.append( ServerScanResult(name=f"{client_display_name} (built-in)", server=server, signature=signature) ) diff --git a/tests/unit/test_control_server.py b/tests/unit/test_control_server.py index 45b2be2..e662258 100644 --- a/tests/unit/test_control_server.py +++ b/tests/unit/test_control_server.py @@ -159,7 +159,7 @@ async def test_upload_includes_scan_error_in_payload(): path_result_with_error = ScanPathResult( path="/test/path", servers=[server], - error=ScanError(message=scan_error_message, exception=Exception(exception_message)), + error=ScanError(message=scan_error_message, exception=Exception(exception_message), is_failure=True), ) with patch("mcp_scan.upload.get_user_info") as mock_get_user_info: @@ -196,6 +196,7 @@ async def test_upload_includes_scan_error_in_payload(): assert "error" in sent_result and sent_result["error"] is not None assert scan_error_message in sent_result["error"].get("message") assert exception_message in sent_result["error"].get("exception") + assert sent_result["error"]["is_failure"] is True @pytest.mark.asyncio @@ -225,8 +226,10 @@ async def test_get_servers_from_path_sets_file_not_found_error_and_uploads_paylo payload = json.loads(mock_post_method.call_args.kwargs["data"]) sent_result = payload["scan_path_results"][0] + assert sent_result["servers"] is None assert sent_result["path"] == "/nonexistent/path" assert sent_result["error"]["message"] == "file does not exist" + assert sent_result["error"]["is_failure"] is False assert "missing" in (sent_result["error"].get("exception") or "") @@ -257,8 +260,10 @@ async def test_get_servers_from_path_sets_parse_error_and_uploads_payload(): payload = json.loads(mock_post_method.call_args.kwargs["data"]) sent_result = payload["scan_path_results"][0] + assert sent_result["servers"] is None assert sent_result["path"] == "/bad/config" assert sent_result["error"]["message"] == "could not parse file" + assert sent_result["error"]["is_failure"] is True assert "parse failure" in (sent_result["error"].get("exception") or "") @@ -294,8 +299,10 @@ def get_servers(self): await upload([result], "https://control.mcp.scan", None, False) payload = json.loads(mock_post_method.call_args.kwargs["data"]) + assert payload["scan_path_results"][0]["servers"] is not None sent_result = payload["scan_path_results"][0] assert sent_result["servers"][0]["error"]["message"] == "server returned HTTP status code" + assert sent_result["servers"][0]["error"]["is_failure"] is True @pytest.mark.asyncio @@ -329,8 +336,10 @@ def get_servers(self): await upload([result], "https://control.mcp.scan", None, False) payload = json.loads(mock_post_method.call_args.kwargs["data"]) + assert payload["scan_path_results"][0]["servers"] is not None sent_result = payload["scan_path_results"][0] assert sent_result["servers"][0]["error"]["message"] == "could not start server" + assert sent_result["servers"][0]["error"]["is_failure"] is True @pytest.mark.asyncio From 4eba5fdc5a3c2e9e1b99d2708fae69ae558e6b8a Mon Sep 17 00:00:00 2001 From: Hemang Date: Tue, 14 Oct 2025 13:55:08 +0200 Subject: [PATCH 2/2] fix tests. --- src/mcp_scan/verify_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mcp_scan/verify_api.py b/src/mcp_scan/verify_api.py index 20a6c49..735a178 100644 --- a/src/mcp_scan/verify_api.py +++ b/src/mcp_scan/verify_api.py @@ -99,6 +99,8 @@ def setup_tcp_connector() -> aiohttp.TCPConnector: async def analyze_scan_path( scan_path: ScanPathResult, base_url: str, additional_headers: dict = {}, opt_out_of_identity: bool = False, verbose: bool = False ) -> ScanPathResult: + if scan_path.servers is None: + return scan_path url = base_url[:-1] if base_url.endswith("/") else base_url if "snyk.io" not in base_url: url = url + "/api/v1/public/mcp-analysis"