11 changes: 1 addition & 10 deletions src/mcp_scan/Storage.py
@@ -13,7 +13,6 @@
from mcp_scan_server.models import DEFAULT_GUARDRAIL_CONFIG, GuardrailConfigFile

from .models import Entity, ScannedEntities, ScannedEntity, entity_type_to_str, hash_entity
from .utils import upload_whitelist_entry

# Set up logger for this module
logger = logging.getLogger(__name__)
@@ -167,19 +166,11 @@ def print_whitelist(self) -> None:
rich.print(entity_type, name, self.whitelist[key])
rich.print(f"[bold]{len(whitelist_keys)} entries in whitelist[/bold]")

def add_to_whitelist(self, entity_type: str, name: str, hash: str, base_url: str | None = None) -> None:
def add_to_whitelist(self, entity_type: str, name: str, hash: str) -> None:
key = f"{entity_type}.{name}"
logger.info("Adding to whitelist: %s with hash: %s", key, hash)
self.whitelist[key] = hash
self.save()
if base_url is not None:
logger.debug("Uploading whitelist entry to base URL: %s", base_url)
with contextlib.suppress(Exception):
try:
asyncio.run(upload_whitelist_entry(name, hash, base_url))
logger.info("Successfully uploaded whitelist entry to remote server")
except Exception as e:
logger.warning("Failed to upload whitelist entry: %s", e)

def is_whitelisted(self, entity: Entity) -> bool:
hash = hash_entity(entity)
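With the remote upload removed, add_to_whitelist now only updates the local whitelist file using the "<entity_type>.<name>" key scheme shown above. A minimal standalone sketch of that local-only flow; the class name and JSON persistence here are illustrative, not the exact API of Storage.py:

    # Illustrative sketch of the local-only whitelist flow after this change.
    # "LocalWhitelist" and its JSON handling are assumptions; the real class in
    # Storage.py also tracks scanned entities and other state.
    import json
    from pathlib import Path

    class LocalWhitelist:
        def __init__(self, path: str) -> None:
            self.path = Path(path)
            self.whitelist: dict[str, str] = (
                json.loads(self.path.read_text()) if self.path.exists() else {}
            )

        def add_to_whitelist(self, entity_type: str, name: str, hash: str) -> None:
            # Same key scheme as the diff: "<entity_type>.<name>" -> hash
            self.whitelist[f"{entity_type}.{name}"] = hash
            self.path.write_text(json.dumps(self.whitelist, indent=2))

    # Usage: LocalWhitelist("whitelist.json").add_to_whitelist("tool", "add", "deadbeef")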
16 changes: 9 additions & 7 deletions src/mcp_scan/cli.py
@@ -517,12 +517,7 @@ def server(on_exit=None):
sf.print_whitelist()
sys.exit(0)
elif all(x is not None for x in [args.type, args.name, args.hash]):
sf.add_to_whitelist(
args.type,
args.name,
args.hash,
base_url=args.base_url if not args.local_only else None,
)
sf.add_to_whitelist(args.type, args.name, args.hash)
sf.print_whitelist()
sys.exit(0)
else:
@@ -580,7 +575,14 @@ async def run_scan_inspect(mode="scan", args=None):
and args.control_server
and hasattr(args, "opt_out")
):
await upload(result, args.control_server, args.control_identifier, args.opt_out, additional_headers=parse_headers(args.control_server_H))
await upload(
result,
args.control_server,
args.control_identifier,
args.opt_out,
verbose=args.verbose,
additional_headers=parse_headers(args.control_server_H),
)
return result

async def print_scan_inspect(mode="scan", args=None):
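The expanded upload call now forwards verbose plus any extra headers parsed from args.control_server_H. parse_headers itself is not part of this diff; a hypothetical sketch of what a curl-style "Name: value" parser could look like, purely as an assumption about its behaviour:

    # Hypothetical sketch only; the actual parse_headers used by the CLI is not shown here.
    def parse_headers(raw_headers: list[str] | None) -> dict[str, str]:
        headers: dict[str, str] = {}
        for item in raw_headers or []:
            name, sep, value = item.partition(":")
            if not sep:
                raise ValueError(f"Invalid header (expected 'Name: value'): {item!r}")
            headers[name.strip()] = value.strip()
        return headers

    # parse_headers(["Authorization: Bearer abc", "X-Team: infra"])
    # -> {"Authorization": "Bearer abc", "X-Team": "infra"}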
14 changes: 12 additions & 2 deletions src/mcp_scan/upload.py
@@ -7,6 +7,7 @@
from mcp_scan.identity import IdentityManager
from mcp_scan.models import ScanPathResult, ScanUserInfo, ScanPathResultsCreate
from mcp_scan.well_known_clients import get_client_from_path
from mcp_scan.verify_api import setup_aiohttp_debug_logging, setup_tcp_connector

logger = logging.getLogger(__name__)

@@ -51,7 +52,12 @@ def get_user_info(identifier: str | None = None, opt_out: bool = False) -> ScanU


async def upload(
results: list[ScanPathResult], control_server: str, identifier: str | None = None, opt_out: bool = False, additional_headers: dict = {}
results: list[ScanPathResult],
control_server: str,
identifier: str | None = None,
opt_out: bool = False,
verbose: bool = False,
additional_headers: dict | None = None,
) -> None:
"""
Upload the scan results to the control server.
@@ -80,8 +86,12 @@
scan_user_info=user_info
)

trace_configs = setup_aiohttp_debug_logging(verbose=verbose)
tcp_connector = setup_tcp_connector()
additional_headers = additional_headers or {}

try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trace_configs=trace_configs, connector=tcp_connector) as session:
headers = {"Content-Type": "application/json"}
headers.update(additional_headers)

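Besides wiring in the trace configs and connector, changing the signature from additional_headers: dict = {} to dict | None = None (with the additional_headers or {} guard above) avoids Python's shared-mutable-default pitfall. A small self-contained illustration of why that matters, unrelated to mcp-scan's own code:

    # Why a mutable default is risky: the same dict object is reused across calls.
    def bad(headers: dict = {}) -> dict:
        headers.setdefault("Content-Type", "application/json")
        return headers

    a = bad()
    a["X-Debug"] = "1"
    print(bad())  # {'Content-Type': 'application/json', 'X-Debug': '1'} - leaked state

    # The pattern used in this diff creates a fresh dict per call instead.
    def good(headers: dict | None = None) -> dict:
        headers = headers or {}
        headers.setdefault("Content-Type", "application/json")
        return headers

    print(good())  # {'Content-Type': 'application/json'} every time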
15 changes: 0 additions & 15 deletions src/mcp_scan/utils.py
@@ -1,8 +1,6 @@
import json
import os
import tempfile

import aiohttp
from lark import Lark
from rapidfuzz.distance import Levenshtein

@@ -49,19 +47,6 @@ def rebalance_command_args(command, args):
return command, args


async def upload_whitelist_entry(name: str, hash: str, base_url: str):
url = base_url + "/api/v1/public/mcp-whitelist"
headers = {"Content-Type": "application/json"}
data = {
"name": name,
"hash": hash,
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, data=json.dumps(data)) as response:
if response.status != 200:
raise Exception(f"Failed to upload whitelist entry: {response.status} - {response.text}")


class TempFile:
"""A windows compatible version of tempfile.NamedTemporaryFile."""

73 changes: 37 additions & 36 deletions src/mcp_scan/verify_api.py
@@ -17,55 +17,58 @@
identity_manager = IdentityManager()


def setup_aiohttp_debug_logging():
def setup_aiohttp_debug_logging(verbose: bool) -> list[aiohttp.TraceConfig]:
"""Setup detailed aiohttp logging and tracing for debugging purposes."""
# Enable aiohttp internal logging
aiohttp_logger = logging.getLogger('aiohttp')
aiohttp_logger.setLevel(logging.DEBUG)
aiohttp_client_logger = logging.getLogger('aiohttp.client')
aiohttp_client_logger.setLevel(logging.DEBUG)

# Create trace config for detailed aiohttp logging
trace_config = aiohttp.TraceConfig()


if not verbose:
return []

async def on_request_start(session, trace_config_ctx, params):
logger.debug("aiohttp: Starting request %s %s", params.method, params.url)

async def on_request_end(session, trace_config_ctx, params):
logger.debug("aiohttp: Request completed %s %s -> %s",
params.method, params.url, params.response.status)

async def on_connection_create_start(session, trace_config_ctx, params):
logger.debug("aiohttp: Creating connection")

async def on_connection_create_end(session, trace_config_ctx, params):
logger.debug("aiohttp: Connection created")

async def on_dns_resolvehost_start(session, trace_config_ctx, params):
logger.debug("aiohttp: Starting DNS resolution for %s", params.host)

async def on_dns_resolvehost_end(session, trace_config_ctx, params):
logger.debug("aiohttp: DNS resolution completed for %s", params.host)

async def on_connection_queued_start(session, trace_config_ctx, params):
logger.debug("aiohttp: Connection queued")

async def on_connection_queued_end(session, trace_config_ctx, params):
logger.debug("aiohttp: Connection dequeued")

async def on_request_exception(session, trace_config_ctx, params):
logger.error("aiohttp: Request exception for %s %s: %s",
params.method, params.url, params.exception)
# Check if it's an SSL-related exception
if hasattr(params.exception, '__class__'):
exc_name = params.exception.__class__.__name__
if 'ssl' in exc_name.lower() or 'certificate' in str(params.exception).lower():
logger.error("aiohttp: SSL/Certificate error detected: %s", params.exception)

async def on_request_redirect(session, trace_config_ctx, params):
logger.debug("aiohttp: Request redirected from %s %s to %s",
params.method, params.url, params.response.headers.get('Location', 'unknown'))

trace_config.on_request_start.append(on_request_start)
trace_config.on_request_end.append(on_request_end)
trace_config.on_connection_create_start.append(on_connection_create_start)
@@ -76,8 +79,21 @@ async def on_request_redirect(session, trace_config_ctx, params):
trace_config.on_connection_queued_end.append(on_connection_queued_end)
trace_config.on_request_exception.append(on_request_exception)
trace_config.on_request_redirect.append(on_request_redirect)

return trace_config

return [trace_config]


def setup_tcp_connector() -> aiohttp.TCPConnector:
"""
Setup a TCP connector with a default SSL context and cleanup enabled.
"""
ssl_context = ssl.create_default_context(cafile=certifi.where())
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
connector = aiohttp.TCPConnector(
ssl=ssl_context,
enable_cleanup_closed=True
)
return connector


async def analyze_scan_path(
Expand Down Expand Up @@ -106,28 +122,13 @@ async def analyze_scan_path(

# Server signatures do not contain any information about the user setup. Only about the server itself.
try:
# Setup debugging if verbose mode is enabled
trace_configs = []
if verbose:
trace_config = setup_aiohttp_debug_logging()
trace_configs.append(trace_config)
trace_configs = setup_aiohttp_debug_logging(verbose=verbose)
tcp_connector = setup_tcp_connector()

# explicitly creating the ssl context sidesepts SSL issues
ssl_context = ssl.create_default_context(cafile=certifi.where())

if verbose:
logger.debug("aiohttp: SSL context created - verify_mode=%s, check_hostname=%s",
ssl_context.verify_mode, ssl_context.check_hostname)

connector = aiohttp.TCPConnector(
ssl=ssl_context,
enable_cleanup_closed=True
)

if verbose:
logger.debug("aiohttp: TCPConnector created")
async with aiohttp.ClientSession(connector=connector, trace_configs=trace_configs) as session:

async with aiohttp.ClientSession(connector=tcp_connector, trace_configs=trace_configs) as session:
async with session.post(url, headers=headers, data=payload.model_dump_json()) as response:
if response.status == 200:
results = AnalysisServerResponse.model_validate_json(await response.read())
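Taken together, the two new helpers are meant to be composed into an aiohttp.ClientSession, which is what both upload() and analyze_scan_path() now do. A minimal sketch of that pattern in isolation, assuming this branch of mcp-scan is importable; the URL is a placeholder, not an mcp-scan endpoint:

    # Minimal sketch: combine the trace configs and TCP connector the way
    # upload() and analyze_scan_path() do in this diff.
    import asyncio
    import aiohttp

    from mcp_scan.verify_api import setup_aiohttp_debug_logging, setup_tcp_connector

    async def probe(verbose: bool = True) -> int:
        trace_configs = setup_aiohttp_debug_logging(verbose=verbose)
        connector = setup_tcp_connector()
        async with aiohttp.ClientSession(trace_configs=trace_configs, connector=connector) as session:
            # Placeholder URL for illustration only.
            async with session.get("https://example.invalid/health") as response:
                return response.status

    # asyncio.run(probe())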