Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion tests/v1/distributed/test_external_lb_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,31 @@
DP_SIZE = int(os.getenv("DP_SIZE", "2"))
# Default tensor parallel size to use
TP_SIZE = int(os.getenv("TP_SIZE", "1"))
# Make sure CCL worker count is set for data parallelism
os.environ["CCL_WORKER_COUNT"] = str(DP_SIZE)

import socket

Check failure on line 26 in tests/v1/distributed/test_external_lb_dp.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E402)

tests/v1/distributed/test_external_lb_dp.py:26:1: E402 Module level import not at top of file


def is_port_available(port: int, host: str = "127.0.0.1") -> bool:
# Try to bind to the port to check if it's available. This is more reliable
# than trying to connect.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.bind((host, port))
return True
except OSError:
return False


def get_unique_port(start_port=8000):
"""Find an available port"""
port = start_port
while not is_port_available(port):
port += 1 # Increment until an available port is found
if port > start_port + 100: # Limit the search range
raise RuntimeError("No available ports")
return port


class ExternalLBServerManager:
Expand All @@ -44,6 +69,14 @@

def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
"""Start all server instances for external LB mode."""

allocated_ports = []
last_port = 7999
for _ in range(self.dp_size):
port = get_unique_port(start_port=last_port + 1)
allocated_ports.append(port)
last_port = port

for rank in range(self.dp_size):
# Create server args for this specific rank
server_args = self.base_server_args.copy()
Expand All @@ -60,7 +93,7 @@
"--tensor-parallel-size",
str(self.tp_size),
"--port",
str(8000 + rank), # Different port for each rank
str(allocated_ports[rank]), # Different port for each rank
"--api-server-count",
str(self.api_server_count),
]
Expand Down
Loading