Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mcp_scan/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,6 @@ async def run_scan_inspect(mode="scan", args=None):
result = await scanner.inspect()
else:
raise ValueError(f"Unknown mode: {mode}, expected 'scan' or 'inspect'")

# upload scan result to control servers if specified
if hasattr(args, "control_servers") and args.control_servers:
for server_config in args.control_servers:
Expand Down Expand Up @@ -820,6 +819,7 @@ async def print_scan_inspect(mode="scan", args=None):
args.print_errors,
args.full_toxic_flows if hasattr(args, "full_toxic_flows") else False,
mode == "inspect",
args.verbose,
)


Expand Down
5 changes: 2 additions & 3 deletions src/mcp_scan/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,8 @@ def clone(self) -> "ScanError":
class Issue(BaseModel):
code: str
message: str
reference: tuple[int, int] | None = Field(
default=None,
description="The index of the tool the issue references. None if it is global",
reference: None | tuple[int, int | None] = Field(
description="The index of the tool the issue references. (server_index, entity_index) if it is a entity issue, (server_index, None) if it is a server issue, None if it is a global issue",
)
extra_data: dict[str, Any] | None = Field(
default=None,
Expand Down
89 changes: 55 additions & 34 deletions src/mcp_scan/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@
MAX_ENTITY_NAME_LENGTH = 25
MAX_ENTITY_NAME_TOXIC_FLOW_LENGTH = 30

ISSUE_COLOR_MAP = {
"successful": "[green]",
"issue": "[red]",
"analysis_error": "[gray62]",
"warning": "[yellow]",
"whitelisted": "[blue]",
"inspect_mode": "[white]",
}


def format_exception(e: Exception | str | None) -> tuple[str, rTraceback | None]:
if e is None:
Expand Down Expand Up @@ -51,10 +60,14 @@ def format_path_line(path: str, status: str | None, operation: str = "Scanning")
return Text.from_markup(text)


def format_servers_line(server: str, status: str | None = None) -> Text:
def format_servers_line(server: str, status: str | None = None, issues: list[Issue] | None = None) -> Text:
text = f"[bold]{server}[/bold]"
gap = 27
text += " " * (max(0, gap - len(text)))
if status:
text += f" [gray62]{status}[/gray62]"
if issues:
text += " " + format_issues(issues)
return Text.from_markup(text)


Expand All @@ -64,6 +77,33 @@ def append_status(status: str, new_status: str) -> str:
return f"{new_status}, {status}"


def format_issues(issues: list[Issue]) -> str:
status_text = " ".join(
[
ISSUE_COLOR_MAP["analysis_error"]
+ rf"\[{issue.code}]: {issue.message}"
+ ISSUE_COLOR_MAP["analysis_error"].replace("[", "[/")
for issue in issues
if issue.code.startswith("X")
]
+ [
ISSUE_COLOR_MAP["issue"]
+ rf"\[{issue.code}]: {issue.message}"
+ ISSUE_COLOR_MAP["issue"].replace("[", "[/")
for issue in issues
if issue.code.startswith("E")
]
+ [
ISSUE_COLOR_MAP["warning"]
+ rf"\[{issue.code}]: {issue.message}"
+ ISSUE_COLOR_MAP["warning"].replace("[", "[/")
for issue in issues
if issue.code.startswith("W")
]
)
return status_text


def format_entity_line(entity: Entity, issues: list[Issue], inspect_mode: bool = False) -> Text:
# is_verified = verified.value
# if is_verified is not None and changed.value is not None:
Expand All @@ -79,15 +119,7 @@ def format_entity_line(entity: Entity, issues: list[Issue], inspect_mode: bool =
else:
status = "successful"

color_map = {
"successful": "[green]",
"issue": "[red]",
"analysis_error": "[gray62]",
"warning": "[yellow]",
"whitelisted": "[blue]",
"inspect_mode": "[white]",
}
color = color_map[status] if not inspect_mode else color_map["inspect_mode"]
color = ISSUE_COLOR_MAP[status] if not inspect_mode else ISSUE_COLOR_MAP["inspect_mode"]
icon_map = {
"successful": ":white_heavy_check_mark:",
"issue": ":cross_mark:",
Expand Down Expand Up @@ -115,25 +147,7 @@ def format_entity_line(entity: Entity, issues: list[Issue], inspect_mode: bool =
# res. temp.
type = type + " " * (len("res. temp.") - len(type))

status_text = " ".join(
[
color_map["analysis_error"]
+ rf"\[{issue.code}]: {issue.message}"
+ color_map["analysis_error"].replace("[", "[/")
for issue in issues
if issue.code.startswith("X")
]
+ [
color_map["issue"] + rf"\[{issue.code}]: {issue.message}" + color_map["issue"].replace("[", "[/")
for issue in issues
if issue.code.startswith("E")
]
+ [
color_map["warning"] + rf"\[{issue.code}]: {issue.message}" + color_map["warning"].replace("[", "[/")
for issue in issues
if issue.code.startswith("W")
]
)
status_text = format_issues(issues)
text = f"{type} {color}[bold]{name}[/bold] {icon} {status_text}"

if include_description:
Expand Down Expand Up @@ -245,16 +259,19 @@ def print_scan_path_result(
path_print_tree = Tree("│")
server_tracebacks = []
for server_idx, server in enumerate(result.servers or []):
server_issues = [issue for issue in result.issues if issue.reference == (server_idx, None)]
if server.error is not None:
err_status, traceback = format_error(server.error)
path_print_tree.add(format_servers_line(server.name or "", err_status))
server_print = path_print_tree.add(
format_servers_line(server.name or "", issues=server_issues, status=err_status)
)
if traceback is not None:
server_tracebacks.append((server, traceback))
else:
server_print = path_print_tree.add(format_servers_line(server.name or ""))
for entity_idx, entity in enumerate(server.entities):
issues = [issue for issue in result.issues if issue.reference == (server_idx, entity_idx)]
server_print.add(format_entity_line(entity, issues, inspect_mode))
server_print = path_print_tree.add(format_servers_line(server.name or "", None, server_issues))
for entity_idx, entity in enumerate(server.entities):
issues = [issue for issue in result.issues if issue.reference == (server_idx, entity_idx)]
server_print.add(format_entity_line(entity, issues, inspect_mode))

if result.servers is not None and len(result.servers) > 0:
rich.print(path_print_tree)
Expand All @@ -278,7 +295,11 @@ def print_scan_result(
print_errors: bool = False,
full_toxic_flows: bool = False,
inspect_mode: bool = False,
internal_issues: bool = False,
) -> None:
if not internal_issues:
for res in result:
res.issues = [issue for issue in res.issues if issue.code not in ["W003", "W004", "W005", "W006"]]
for i, path_result in enumerate(result):
print_scan_path_result(path_result, print_errors, full_toxic_flows, inspect_mode)
if i < len(result) - 1:
Expand Down
7 changes: 7 additions & 0 deletions src/mcp_scan/verify_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,13 @@ async def analyze_machine(
):
sent_scan_path_result.issues = response_scan_path_result.issues
sent_scan_path_result.labels = response_scan_path_result.labels
for server_given, server_received in zip(
sent_scan_path_result.servers or [],
response_scan_path_result.servers or [],
strict=True,
):
if server_given.signature is None:
server_given.signature = server_received.signature
return scan_paths # Success - exit the function

except aiohttp.ClientResponseError as e:
Expand Down
6 changes: 3 additions & 3 deletions tests/e2e/test_full_scan_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ def test_scan(self, path, server_names):

issue_set = {issue["code"] for issue in issues}
if set(server_names) == {"Weather", "Math"}:
allowed_issue_sets = [{"W001", "TF001", "TF002"}, {"W001", "TF002"}, {"W001"}]
allowed_issue_sets = [{"W001", "W003", "TF001", "TF002"}, {"W001", "W003", "TF002"}, {"W001", "W003"}]
elif set(server_names) == {"Weather"}:
allowed_issue_sets = [{"TF001"}, set()]
allowed_issue_sets = [{"W003", "TF001"}, set()]
elif set(server_names) == {"Math"}:
allowed_issue_sets = [{"W001"}, {"W001", "TF002"}]
allowed_issue_sets = [{"W001", "W003"}, {"W001", "W003", "TF002"}]
else:
raise ValueError(f"Invalid server names: {server_names}")
# call list for better error message
Expand Down