Skip to content

Commit c238924

Browse files
committed
to pass lint and add log support for multiple tools
1 parent 4dd383b commit c238924

16 files changed

+99
-44
lines changed

.github/workflows/check-leaks.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ permissions:
1515

1616
jobs:
1717
scan:
18-
if: github.repository_owner == 'MiroMindAI'
1918
name: gitleaks
2019
runs-on: ubuntu-latest
2120
steps:

.github/workflows/check-pr-title.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@ name: check-pr-title
22

33
on:
44
pull_request:
5-
types: [opened, synchronize]
5+
types: [opened, synchronize, edited]
66

77
jobs:
88
check-pr-title:
9-
if: github.repository_owner == 'MiroMindAI'
109
name: Check PR Title
1110
runs-on: ubuntu-latest
1211
steps:

.github/workflows/run-ruff.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ on:
66

77
jobs:
88
lint:
9-
if: github.repository_owner == 'MiroMindAI'
109
name: lint pull request
1110
runs-on: ubuntu-latest
1211
steps:

common_benchmark.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,20 @@
1919
from omegaconf import DictConfig, OmegaConf
2020

2121
from utils.eval_utils import verify_answer_for_datasets
22-
from src.logging.logger import bootstrap_logger, task_logging_context, init_logging_for_benchmark_evaluation
22+
from src.logging.logger import (
23+
bootstrap_logger,
24+
task_logging_context,
25+
init_logging_for_benchmark_evaluation,
26+
)
2327
from config import config_name, config_path
2428
from src.core.pipeline import (
2529
create_pipeline_components,
2630
execute_task_pipeline,
2731
)
32+
2833
init_logging_for_benchmark_evaluation(print_task_logs=False)
2934

35+
3036
class TaskStatus(StrEnum):
3137
PENDING = "pending"
3238
RUN_FAILED = "run_failed"
@@ -373,7 +379,7 @@ async def run_parallel_inference(
373379
async def run_with_semaphore(task):
374380
async with semaphore:
375381
with task_logging_context(task.task_id, self.get_log_dir()):
376-
result = await self.run_single_task(task)
382+
result = await self.run_single_task(task)
377383
return result
378384

379385
# Shuffle tasks to avoid order bias and improve balancing

src/logging/logger.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,18 @@
88
import logging
99
from functools import lru_cache
1010
from pathlib import Path
11-
from typing import Literal, Dict
11+
from typing import Literal
1212
from contextvars import ContextVar
1313
import hydra
1414
from rich.console import Console
1515
from rich.logging import RichHandler
1616
import asyncio
1717
import threading
1818
from contextlib import contextmanager
19+
1920
TASK_CONTEXT_VAR: ContextVar[str | None] = ContextVar("CURRENT_TASK_ID", default=None)
2021

22+
2123
class ZMQLogHandler(logging.Handler):
2224
def __init__(self, addr="tcp://127.0.0.1:6000", tool_name="unknown_tool"):
2325
super().__init__()
@@ -34,6 +36,7 @@ def emit(self, record):
3436
except Exception:
3537
self.handleError(record)
3638

39+
3740
async def zmq_log_listener(bind_addr="tcp://127.0.0.1:6000"):
3841
ctx = zmq.asyncio.Context()
3942
sock = ctx.socket(zmq.PULL)
@@ -47,23 +50,30 @@ async def zmq_log_listener(bind_addr="tcp://127.0.0.1:6000"):
4750
task_id, tool_name, msg = raw.split("||", 2)
4851

4952
record = root_logger.makeRecord(
50-
name=f'[TOOL] {tool_name}',
53+
name=f"[TOOL] {tool_name}",
5154
level=logging.INFO,
52-
fn="", lno=0, msg=msg, args=(),
53-
exc_info=None
55+
fn="",
56+
lno=0,
57+
msg=msg,
58+
args=(),
59+
exc_info=None,
5460
)
5561
record.task_id = task_id
5662

5763
root_logger.handle(record)
5864
else:
5965
root_logger.info(raw)
6066

67+
6168
def start_zmq_listener():
6269
loop = asyncio.new_event_loop()
6370
asyncio.set_event_loop(loop)
6471
loop.run_until_complete(zmq_log_listener())
6572

66-
def setup_mcp_logging(level="INFO", addr="tcp://127.0.0.1:6000", tool_name="unknown_tool"):
73+
74+
def setup_mcp_logging(
75+
level="INFO", addr="tcp://127.0.0.1:6000", tool_name="unknown_tool"
76+
):
6777
root = logging.getLogger()
6878
root.setLevel(level)
6979

@@ -78,29 +88,36 @@ def setup_mcp_logging(level="INFO", addr="tcp://127.0.0.1:6000", tool_name="unkn
7888
for h in logger.handlers[:]:
7989
logger.removeHandler(h)
8090
h.close()
81-
logger.propagate = True # 确保冒泡到 root
91+
logger.propagate = True # Ensure bubbling to root
8292

8393
# Re-add the ZMQ handler
8494
handler = ZMQLogHandler(addr=addr, tool_name=tool_name)
85-
handler.setFormatter(logging.Formatter("[TOOL] %(asctime)s %(levelname)s: %(message)s"))
95+
handler.setFormatter(
96+
logging.Formatter("[TOOL] %(asctime)s %(levelname)s: %(message)s")
97+
)
8698
root.addHandler(handler)
8799

100+
88101
def setup_log_record_factory():
89102
old_factory = logging.getLogRecordFactory()
103+
90104
def record_factory(*args, **kwargs):
91105
record = old_factory(*args, **kwargs)
92106
record.task_id = TASK_CONTEXT_VAR.get()
93107
return record
108+
94109
logging.setLogRecordFactory(record_factory)
95110

111+
96112
class TaskFilter(logging.Filter):
97113
def __init__(self, task_id: str):
98114
super().__init__()
99115
self.task_id = task_id
100-
116+
101117
def filter(self, record: logging.LogRecord) -> bool:
102118
return getattr(record, "task_id", None) == self.task_id
103119

120+
104121
def make_task_logger(task_id: str, log_dir: Path) -> logging.Handler:
105122
log_dir.mkdir(parents=True, exist_ok=True)
106123
file_path = log_dir / f"task_{task_id}.log"
@@ -111,9 +128,10 @@ def make_task_logger(task_id: str, log_dir: Path) -> logging.Handler:
111128
logging.getLogger().addHandler(fh)
112129
return fh
113130

131+
114132
def remove_all_console_handlers():
115133
"""
116-
移除当前进程中所有 logger 上的 console handler (StreamHandler/RichHandler)
134+
Remove all console handlers (StreamHandler/RichHandler) from all loggers in the current process.
117135
"""
118136
for name, logger in logging.Logger.manager.loggerDict.items():
119137
if isinstance(logger, logging.Logger):
@@ -134,6 +152,7 @@ def remove_all_console_handlers():
134152
root_logger.removeHandler(h)
135153
h.close()
136154

155+
137156
@contextmanager
138157
def task_logging_context(task_id: str, log_dir: Path):
139158
token = TASK_CONTEXT_VAR.set(task_id)
@@ -145,21 +164,25 @@ def task_logging_context(task_id: str, log_dir: Path):
145164
logging.getLogger().removeHandler(handler)
146165
handler.close()
147166

167+
148168
def init_logging_for_benchmark_evaluation(print_task_logs=False):
149-
threading.Thread(target=start_zmq_listener, daemon=True).start() #monitoring tool logs
150-
logging.basicConfig(handlers=[])
169+
threading.Thread(
170+
target=start_zmq_listener, daemon=True
171+
).start() # monitoring tool logs
172+
logging.basicConfig(handlers=[])
151173
setup_log_record_factory()
152174
if not print_task_logs:
153-
remove_all_console_handlers()
175+
remove_all_console_handlers()
176+
154177

155178
@lru_cache
156179
def bootstrap_logger(
157180
level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | int = "INFO",
158181
logger_name: str = "miroflow",
159182
logger: logging.Logger | None = None,
160-
log_dir: str | Path | None = None, # 日志存储目录
161-
log_filename: str = "miroflow.log", # 默认日志文件名
162-
to_console: bool = True, # 是否显示到 console
183+
log_dir: str | Path | None = None, # Log storage directory
184+
log_filename: str = "miroflow.log", # Default log filename
185+
to_console: bool = True, # Whether to display to console
163186
) -> logging.Logger:
164187
"""Configure only this logger, not the root logger"""
165188
if logger is None:
@@ -173,7 +196,7 @@ def bootstrap_logger(
173196
console=Console(
174197
stderr=True,
175198
width=200,
176-
color_system=None,
199+
color_system=None,
177200
force_terminal=False,
178201
legacy_windows=False,
179202
),
@@ -191,9 +214,9 @@ def bootstrap_logger(
191214
log_dir.mkdir(parents=True, exist_ok=True)
192215
file_path = log_dir / log_filename
193216
file_handler = logging.FileHandler(file_path, encoding="utf-8")
194-
file_handler.setFormatter(logging.Formatter(
195-
"%(asctime)s [%(levelname)s] %(name)s: %(message)s"
196-
))
217+
file_handler.setFormatter(
218+
logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
219+
)
197220
logger.addHandler(file_handler)
198221

199222
logger.setLevel(level)

src/tool/manager.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,20 @@
2020

2121
R = TypeVar("R")
2222

23-
def update_server_params_with_context_var(server_params: StdioServerParameters) -> StdioServerParameters:
23+
24+
def update_server_params_with_context_var(
25+
server_params: StdioServerParameters,
26+
) -> StdioServerParameters:
2427
"""
2528
Update the server params with the context var.
2629
"""
2730
from src.logging.logger import TASK_CONTEXT_VAR
31+
2832
if TASK_CONTEXT_VAR.get() is not None:
2933
server_params.env["TASK_ID"] = TASK_CONTEXT_VAR.get()
3034
return server_params
3135

36+
3237
def with_timeout(timeout_s: float = 300.0):
3338
"""
3439
Decorator: wraps any *async* function in asyncio.wait_for().
@@ -116,7 +121,9 @@ async def _find_servers_with_tool(self, tool_name):
116121

117122
try:
118123
if isinstance(server_params, StdioServerParameters):
119-
async with stdio_client(update_server_params_with_context_var(server_params)) as (read, write):
124+
async with stdio_client(
125+
update_server_params_with_context_var(server_params)
126+
) as (read, write):
120127
async with ClientSession(
121128
read, write, sampling_callback=None
122129
) as session:
@@ -176,7 +183,9 @@ async def get_all_tool_definitions(self):
176183

177184
try:
178185
if isinstance(server_params, StdioServerParameters):
179-
async with stdio_client(update_server_params_with_context_var(server_params)) as (read, write):
186+
async with stdio_client(
187+
update_server_params_with_context_var(server_params)
188+
) as (read, write):
180189
async with ClientSession(
181190
read, write, sampling_callback=None
182191
) as session:
@@ -350,7 +359,9 @@ async def execute_tool_call(self, server_name, tool_name, arguments) -> Any:
350359
try:
351360
result_content = None
352361
if isinstance(server_params, StdioServerParameters):
353-
async with stdio_client(update_server_params_with_context_var(server_params)) as (read, write):
362+
async with stdio_client(
363+
update_server_params_with_context_var(server_params)
364+
) as (read, write):
354365
async with ClientSession(
355366
read, write, sampling_callback=None
356367
) as session:

src/tool/mcp_servers/audio_mcp_server.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import contextlib
1515
from mutagen import File as MutagenFile
1616
import asyncio
17+
from src.logging.logger import setup_mcp_logging
18+
1719

1820
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
1921
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
@@ -25,10 +27,10 @@
2527
)
2628

2729
# Initialize FastMCP server
28-
from src.logging.logger import setup_mcp_logging
2930
setup_mcp_logging(tool_name=os.path.basename(__file__))
3031
mcp = FastMCP("audio-mcp-server")
3132

33+
3234
def _get_audio_extension(url: str, content_type: str = None) -> str:
3335
"""
3436
Determine the appropriate audio file extension from URL or content type.
@@ -290,4 +292,4 @@ async def audio_question_answering(audio_path_or_url: str, question: str) -> str
290292

291293

292294
if __name__ == "__main__":
293-
mcp.run(transport="stdio",show_banner=False)
295+
mcp.run(transport="stdio", show_banner=False)

src/tool/mcp_servers/audio_mcp_server_os.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@
2525
from fastmcp import FastMCP
2626
from mutagen import File as MutagenFile
2727
from openai import OpenAI
28+
from src.logging.logger import setup_mcp_logging
2829

2930
WHISPER_API_KEY = os.environ.get("WHISPER_API_KEY")
3031
WHISPER_BASE_URL = os.environ.get("WHISPER_BASE_URL")
3132
WHISPER_MODEL_NAME = os.environ.get("WHISPER_MODEL_NAME")
3233

3334
# Initialize FastMCP server
35+
setup_mcp_logging(tool_name=os.path.basename(__file__))
3436
mcp = FastMCP("audio-mcp-server-os")
3537

3638

@@ -210,4 +212,4 @@ async def audio_transcription(audio_path_or_url: str) -> str:
210212

211213

212214
if __name__ == "__main__":
213-
mcp.run(transport="stdio")
215+
mcp.run(transport="stdio", show_banner=False)

src/tool/mcp_servers/browser_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,4 @@ async def test_persistent_session():
9797

9898

9999
if __name__ == "__main__":
100-
asyncio.run(test_persistent_session(),show_banner=False)
100+
asyncio.run(test_persistent_session(), show_banner=False)

src/tool/mcp_servers/python_server.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
# Initialize FastMCP server
1212
from src.logging.logger import setup_mcp_logging
13+
1314
setup_mcp_logging(tool_name=os.path.basename(__file__))
1415
mcp = FastMCP("e2b-python-interpreter")
1516

@@ -413,4 +414,4 @@ async def download_file_from_sandbox_to_local(
413414

414415

415416
if __name__ == "__main__":
416-
mcp.run(transport="stdio",show_banner=False)
417+
mcp.run(transport="stdio", show_banner=False)

0 commit comments

Comments
 (0)