Skip to content

Commit 0b2e557

Browse files
committed
to pass lint
1 parent 6838280 commit 0b2e557

File tree

1 file changed

+77
-9
lines changed

1 file changed

+77
-9
lines changed

src/logging/logger.py

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,101 @@
1616
import asyncio
1717
import threading
1818
from contextlib import contextmanager
19+
import socket
1920

2021
TASK_CONTEXT_VAR: ContextVar[str | None] = ContextVar("CURRENT_TASK_ID", default=None)
2122

23+
# Global variable to store the actual ZMQ address being used
24+
_ZMQ_ADDRESS: str = "tcp://127.0.0.1:6000"
25+
26+
27+
def find_available_port(start_port: int = 6000, max_attempts: int = 10) -> int:
28+
"""Find an available port starting from start_port."""
29+
for port in range(start_port, start_port + max_attempts):
30+
try:
31+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
32+
s.bind(("127.0.0.1", port))
33+
return port
34+
except OSError:
35+
continue
36+
raise RuntimeError(
37+
f"Could not find an available port in range {start_port}-{start_port + max_attempts - 1}"
38+
)
39+
40+
41+
def get_zmq_address() -> str:
42+
"""Get the current ZMQ address."""
43+
return _ZMQ_ADDRESS
44+
45+
46+
def set_zmq_address(address: str) -> None:
47+
"""Set the ZMQ address."""
48+
global _ZMQ_ADDRESS
49+
_ZMQ_ADDRESS = address
50+
51+
52+
def _extract_port_from_address(addr: str) -> int:
53+
"""Extract port number from ZMQ address."""
54+
try:
55+
return int(addr.split(":")[-1])
56+
except (ValueError, IndexError):
57+
return 6000
58+
59+
60+
def _bind_zmq_socket(sock, bind_addr: str) -> str:
61+
"""Bind ZMQ socket to an available port and return the actual address."""
62+
port = _extract_port_from_address(bind_addr)
63+
64+
try:
65+
available_port = find_available_port(port)
66+
actual_addr = f"tcp://127.0.0.1:{available_port}"
67+
sock.bind(actual_addr)
68+
return actual_addr
69+
except RuntimeError:
70+
# Fallback to random port
71+
return sock.bind_to_random_port("tcp://127.0.0.1")
72+
2273

2374
class ZMQLogHandler(logging.Handler):
24-
def __init__(self, addr="tcp://127.0.0.1:6000", tool_name="unknown_tool"):
75+
def __init__(self, addr=None, tool_name="unknown_tool"):
2576
super().__init__()
2677
ctx = zmq.Context()
2778
self.sock = ctx.socket(zmq.PUSH)
28-
self.sock.connect(addr)
79+
80+
# Use the global ZMQ address if no specific address is provided
81+
if addr is None:
82+
addr = get_zmq_address()
83+
84+
# Try to connect to the address
85+
try:
86+
self.sock.connect(addr)
87+
print(f"ZMQ handler connected to: {addr}")
88+
except zmq.error.ZMQError as e:
89+
# If connection fails, disable the handler
90+
print(f"Warning: Could not connect to ZMQ listener at {addr}: {e}")
91+
print("Disabling ZMQ logging for this handler")
92+
self.sock = None
93+
2994
self.task_id = os.environ.get("TASK_ID", "0")
3095
self.tool_name = tool_name
3196

3297
def emit(self, record):
3398
try:
34-
msg = f"{record.getMessage()}"
35-
self.sock.send_string(f"{self.task_id}||{self.tool_name}||{msg}")
99+
if self.sock is not None:
100+
msg = f"{record.getMessage()}"
101+
self.sock.send_string(f"{self.task_id}||{self.tool_name}||{msg}")
36102
except Exception:
37103
self.handleError(record)
38104

39105

40106
async def zmq_log_listener(bind_addr="tcp://127.0.0.1:6000"):
41107
ctx = zmq.asyncio.Context()
42108
sock = ctx.socket(zmq.PULL)
43-
sock.bind(bind_addr)
109+
110+
# Bind to available port
111+
actual_addr = _bind_zmq_socket(sock, bind_addr)
112+
set_zmq_address(actual_addr)
113+
print(f"ZMQ listener bound to: {actual_addr}")
44114

45115
root_logger = logging.getLogger()
46116

@@ -71,9 +141,7 @@ def start_zmq_listener():
71141
loop.run_until_complete(zmq_log_listener())
72142

73143

74-
def setup_mcp_logging(
75-
level="INFO", addr="tcp://127.0.0.1:6000", tool_name="unknown_tool"
76-
):
144+
def setup_mcp_logging(level="INFO", addr=None, tool_name="unknown_tool"):
77145
root = logging.getLogger()
78146
root.setLevel(level)
79147

@@ -90,7 +158,7 @@ def setup_mcp_logging(
90158
h.close()
91159
logger.propagate = True # Ensure bubbling to root
92160

93-
# Re-add the ZMQ handler
161+
# Re-add the ZMQ handler (will use global address if addr is None)
94162
handler = ZMQLogHandler(addr=addr, tool_name=tool_name)
95163
handler.setFormatter(
96164
logging.Formatter("[TOOL] %(asctime)s %(levelname)s: %(message)s")

0 commit comments

Comments
 (0)