Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions src/mcp_scan/MCPScanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,19 @@ async def get_servers_from_path(self, path: str) -> ScanPathResult:
except FileNotFoundError as e:
error_msg = "file does not exist"
logger.exception("%s: %s", error_msg, path)
result.error = ScanError(message=error_msg, exception=e)
# This is a non failing error, so we set is_failure to False.
result.error = ScanError(message=error_msg, exception=e, is_failure=False)
except Exception as e:
error_msg = "could not parse file"
logger.exception("%s: %s", error_msg, path)
result.error = ScanError(message=error_msg, exception=e)
result.error = ScanError(message=error_msg, exception=e, is_failure=True)
return result

def check_server_changed(self, path_result: ScanPathResult) -> list[Issue]:
logger.debug("Checking server changed: %s", path_result.path)
issues: list[Issue] = []
if path_result.servers is None:
return issues
for server_idx, server in enumerate(path_result.servers):
logger.debug(
"Checking for changes in server %d/%d: %s", server_idx + 1, len(path_result.servers), server.name
Expand All @@ -162,6 +165,8 @@ def check_server_changed(self, path_result: ScanPathResult) -> list[Issue]:
def check_whitelist(self, path_result: ScanPathResult) -> list[Issue]:
logger.debug("Checking whitelist for path: %s", path_result.path)
issues: list[Issue] = []
if path_result.servers is None:
return issues
for server_idx, server in enumerate(path_result.servers):
for entity_idx, entity in enumerate(server.entities):
if self.storage_file.is_whitelisted(entity):
Expand Down Expand Up @@ -195,11 +200,11 @@ async def scan_server(self, server: ServerScanResult) -> ServerScanResult:
except HTTPStatusError as e:
error_msg = "server returned HTTP status code"
logger.exception("%s: %s", error_msg, server.name)
result.error = ScanError(message=error_msg, exception=e)
result.error = ScanError(message=error_msg, exception=e, is_failure=True)
except Exception as e:
error_msg = "could not start server"
logger.exception("%s: %s", error_msg, server.name)
result.error = ScanError(message=error_msg, exception=e)
result.error = ScanError(message=error_msg, exception=e, is_failure=True)
await self.emit("server_scanned", result)
return result

Expand All @@ -209,15 +214,16 @@ async def scan_path(self, path: str, inspect_only: bool = False) -> ScanPathResu
logger.info("Scanning path: %s, inspect_only: %s", path, inspect_only)
path_result = await self.get_servers_from_path(path)

for i, server in enumerate(path_result.servers):
if server.server.type == "stdio":
full_command = server.server.command + " " + " ".join(server.server.args or [])
# check if pattern is contained in full_command
if re.search(r"mcp[-_]scan.*mcp-server", full_command):
logger.info("Skipping scan of server %d/%d: %s", i + 1, len(path_result.servers), server.name)
continue
logger.debug("Scanning server %d/%d: %s", i + 1, len(path_result.servers), server.name)
path_result.servers[i] = await self.scan_server(server)
if path_result.servers is not None:
for i, server in enumerate(path_result.servers):
if server.server.type == "stdio":
full_command = server.server.command + " " + " ".join(server.server.args or [])
# check if pattern is contained in full_command
if re.search(r"mcp[-_]scan.*mcp-server", full_command):
logger.info("Skipping scan of server %d/%d: %s", i + 1, len(path_result.servers), server.name)
continue
logger.debug("Scanning server %d/%d: %s", i + 1, len(path_result.servers), server.name)
path_result.servers[i] = await self.scan_server(server)

# add built-in tools
if self.include_built_in:
Expand Down
10 changes: 7 additions & 3 deletions src/mcp_scan/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ class ScanError(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
message: str | None = None
exception: Exception | None = None
is_failure: bool = True

@field_serializer("exception")
def serialize_exception(self, exception: Exception | None, _info) -> str | None:
Expand All @@ -165,6 +166,7 @@ def clone(self) -> "ScanError":
return ScanError(
message=self.message,
exception=self.exception,
is_failure=self.is_failure,
)


Expand Down Expand Up @@ -233,14 +235,16 @@ class ScanPathResult(BaseModel):
model_config = ConfigDict()
client: str | None = None
path: str
servers: list[ServerScanResult] = Field(default_factory=list)
# servers is None if the MCP configuration file was missing or unparseable
# which prevented server discovery.
servers: list[ServerScanResult] | None = None
issues: list[Issue] = Field(default_factory=list)
labels: list[list[ScalarToolLabels]] = Field(default_factory=list)
error: ScanError | None = None

@property
def entities(self) -> list[Entity]:
return list(chain.from_iterable(server.entities for server in self.servers))
return list(chain.from_iterable(server.entities for server in self.servers)) if self.servers else []

def clone(self) -> "ScanPathResult":
"""
Expand All @@ -250,7 +254,7 @@ def clone(self) -> "ScanPathResult":
output = ScanPathResult(
path=self.path,
client=self.client,
servers=[server.clone() for server in self.servers],
servers=[server.clone() for server in self.servers] if self.servers else None,
issues=[issue.model_copy(deep=True) for issue in self.issues],
labels=[[label.model_copy(deep=True) for label in labels] for labels in self.labels],
error=self.error.clone() if self.error else None,
Expand Down
2 changes: 1 addition & 1 deletion src/mcp_scan/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def print_scan_path_result(
issues = [issue for issue in result.issues if issue.reference == (server_idx, entity_idx)]
server_print.add(format_entity_line(entity, issues, inspect_mode))

if len(result.servers) > 0:
if result.servers is not None and len(result.servers) > 0:
rich.print(path_print_tree)

# print global issues
Expand Down
2 changes: 2 additions & 0 deletions src/mcp_scan/verify_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def setup_tcp_connector() -> aiohttp.TCPConnector:
async def analyze_scan_path(
scan_path: ScanPathResult, base_url: str, additional_headers: dict = {}, opt_out_of_identity: bool = False, verbose: bool = False
) -> ScanPathResult:
if scan_path.servers is None:
return scan_path
url = base_url[:-1] if base_url.endswith("/") else base_url
if "snyk.io" not in base_url:
url = url + "/api/v1/public/mcp-analysis"
Expand Down
3 changes: 2 additions & 1 deletion src/mcp_scan/well_known_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ def get_builtin_tools(path_result: ScanPathResult) -> ScanPathResult:
meta={},
)
)

if output.servers is None:
output.servers = []
output.servers.append(
ServerScanResult(name=f"{client_display_name} (built-in)", server=server, signature=signature)
)
Expand Down
11 changes: 10 additions & 1 deletion tests/unit/test_control_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ async def test_upload_includes_scan_error_in_payload():
path_result_with_error = ScanPathResult(
path="/test/path",
servers=[server],
error=ScanError(message=scan_error_message, exception=Exception(exception_message)),
error=ScanError(message=scan_error_message, exception=Exception(exception_message), is_failure=True),
)

with patch("mcp_scan.upload.get_user_info") as mock_get_user_info:
Expand Down Expand Up @@ -196,6 +196,7 @@ async def test_upload_includes_scan_error_in_payload():
assert "error" in sent_result and sent_result["error"] is not None
assert scan_error_message in sent_result["error"].get("message")
assert exception_message in sent_result["error"].get("exception")
assert sent_result["error"]["is_failure"] is True


@pytest.mark.asyncio
Expand Down Expand Up @@ -225,8 +226,10 @@ async def test_get_servers_from_path_sets_file_not_found_error_and_uploads_paylo

payload = json.loads(mock_post_method.call_args.kwargs["data"])
sent_result = payload["scan_path_results"][0]
assert sent_result["servers"] is None
assert sent_result["path"] == "/nonexistent/path"
assert sent_result["error"]["message"] == "file does not exist"
assert sent_result["error"]["is_failure"] is False
assert "missing" in (sent_result["error"].get("exception") or "")


Expand Down Expand Up @@ -257,8 +260,10 @@ async def test_get_servers_from_path_sets_parse_error_and_uploads_payload():

payload = json.loads(mock_post_method.call_args.kwargs["data"])
sent_result = payload["scan_path_results"][0]
assert sent_result["servers"] is None
assert sent_result["path"] == "/bad/config"
assert sent_result["error"]["message"] == "could not parse file"
assert sent_result["error"]["is_failure"] is True
assert "parse failure" in (sent_result["error"].get("exception") or "")


Expand Down Expand Up @@ -294,8 +299,10 @@ def get_servers(self):
await upload([result], "https://control.mcp.scan", None, False)

payload = json.loads(mock_post_method.call_args.kwargs["data"])
assert payload["scan_path_results"][0]["servers"] is not None
sent_result = payload["scan_path_results"][0]
assert sent_result["servers"][0]["error"]["message"] == "server returned HTTP status code"
assert sent_result["servers"][0]["error"]["is_failure"] is True


@pytest.mark.asyncio
Expand Down Expand Up @@ -329,8 +336,10 @@ def get_servers(self):
await upload([result], "https://control.mcp.scan", None, False)

payload = json.loads(mock_post_method.call_args.kwargs["data"])
assert payload["scan_path_results"][0]["servers"] is not None
sent_result = payload["scan_path_results"][0]
assert sent_result["servers"][0]["error"]["message"] == "could not start server"
assert sent_result["servers"][0]["error"]["is_failure"] is True


@pytest.mark.asyncio
Expand Down