Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/vllm_router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def initialize_all(app: FastAPI, args):
prefill_model_labels=args.prefill_model_labels,
decode_model_labels=args.decode_model_labels,
kv_aware_threshold=args.kv_aware_threshold,
max_instance_failover_reroute_attempts=args.max_instance_failover_reroute_attempts,
)

# Initialize feature gates
Expand Down
7 changes: 7 additions & 0 deletions src/vllm_router/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,13 @@ def parse_args():
help="The threshold for kv-aware routing.",
)

parser.add_argument(
"--max-instance-failover-reroute-attempts",
type=int,
default=0,
help="Number of reroute attempts per failed request",
)

args = parser.parse_args()
args = load_initial_config_from_config_file_if_required(parser, args)

Expand Down
20 changes: 15 additions & 5 deletions src/vllm_router/routers/routing_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def _qps_routing(
ret = url
return ret

def set_request_migration(self, max_instance_failover_reroute_attempts):
self.max_instance_failover_reroute_attempts = (
max_instance_failover_reroute_attempts
)

def _update_hash_ring(self, endpoints: List["EndpointInfo"]):
"""
Update the hash ring with the current list of endpoints.
Expand Down Expand Up @@ -466,10 +471,10 @@ def initialize_routing_logic(
) -> RoutingInterface:
if routing_logic == RoutingLogic.ROUND_ROBIN:
logger.info("Initializing round-robin routing logic")
return RoundRobinRouter()
router = RoundRobinRouter()
elif routing_logic == RoutingLogic.SESSION_BASED:
logger.info(f"Initializing session-based routing logic with kwargs: {kwargs}")
return SessionRouter(kwargs.get("session_key"))
router = SessionRouter(kwargs.get("session_key"))
elif routing_logic == RoutingLogic.KVAWARE:
logger.info("Initializing kvaware routing logic")
router = KvawareRouter(
Expand All @@ -478,17 +483,22 @@ def initialize_routing_logic(
kwargs.get("kv_aware_threshold"),
)
router.start_kv_manager()
return router
elif routing_logic == RoutingLogic.PREFIXAWARE:
logger.info("Initializing prefix-aware routing logic")
return PrefixAwareRouter()
router = PrefixAwareRouter()
elif routing_logic == RoutingLogic.DISAGGREGATED_PREFILL:
logger.info("Initializing disaggregated prefill routing logic")
return DisaggregatedPrefillRouter(
router = DisaggregatedPrefillRouter(
kwargs.get("prefill_model_labels"), kwargs.get("decode_model_labels")
)
else:
raise ValueError(f"Invalid routing logic {routing_logic}")
router.set_request_migration(
max_instance_failover_reroute_attempts=kwargs.get(
"max_instance_failover_reroute_attempts"
)
)
return router


def reconfigure_routing_logic(
Expand Down
194 changes: 115 additions & 79 deletions src/vllm_router/services/request_service/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

# --- Request Processing & Routing ---
import asyncio
import json
import os
import time
Expand Down Expand Up @@ -136,8 +137,57 @@ async def process_request(
)


def perform_service_discovery(
request, request_json, request_endpoint, requested_model, error_urls
):
service_discovery = get_service_discovery()
endpoints = service_discovery.get_endpoint_info()

aliases = getattr(service_discovery, "aliases", None)
if aliases and requested_model in aliases.keys():
requested_model = aliases[requested_model]
request_body = replace_model_in_request_body(request_json, requested_model)
update_content_length(request, request_body)

if not request_endpoint:
endpoints = list(
filter(
lambda x: requested_model in x.model_names
and not x.sleep
and x.url not in error_urls,
endpoints,
)
)
engine_stats = request.app.state.engine_stats_scraper.get_engine_stats()
request_stats = request.app.state.request_stats_monitor.get_request_stats(
time.time()
)
else:
endpoints = list(
filter(
lambda x: requested_model in x.model_names
and x.Id == request_endpoint
and not x.sleep
and x.url not in error_urls,
endpoints,
)
)
engine_stats, request_stats = None, None

if not endpoints:
return JSONResponse(
status_code=400,
content={
"error": f"Model {requested_model} not found or vLLM engine is sleeping."
},
)
return endpoints, engine_stats, request_stats


async def route_general_request(
request: Request, endpoint: str, background_tasks: BackgroundTasks
request: Request,
endpoint: str,
background_tasks: BackgroundTasks,
):
"""
Route the incoming request to the backend server and stream the response back to the client.
Expand Down Expand Up @@ -203,96 +253,82 @@ async def route_general_request(
status_code=400, detail="Request body is not JSON parsable."
)

service_discovery = get_service_discovery()
endpoints = service_discovery.get_endpoint_info()
# Perform service discovery to request path a number of times equal to reroutes + 1
error_urls = set()
for _ in range(request.app.state.router.max_instance_failover_reroute_attempts + 1):
endpoints, engine_stats, request_stats = await asyncio.to_thread(
perform_service_discovery,
request,
request_json,
request_endpoint,
requested_model,
error_urls,
)

aliases = getattr(service_discovery, "aliases", None)
if aliases and requested_model in aliases.keys():
requested_model = aliases[requested_model]
request_body = replace_model_in_request_body(request_json, requested_model)
update_content_length(request, request_body)
logger.debug(f"Routing request {request_id} for model: {requested_model}")
if request_endpoint:
server_url = endpoints[0].url
logger.debug(
f"Routing request {request_id} to engine with Id: {endpoints[0].Id}"
)

if not request_endpoint:
endpoints = list(
filter(
lambda x: requested_model in x.model_names and not x.sleep,
endpoints,
elif isinstance(request.app.state.router, KvawareRouter) or isinstance(
request.app.state.router, PrefixAwareRouter
):
server_url = await request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request, request_json
)
)
engine_stats = request.app.state.engine_stats_scraper.get_engine_stats()
request_stats = request.app.state.request_stats_monitor.get_request_stats(
time.time()
)
else:
endpoints = list(
filter(
lambda x: requested_model in x.model_names
and x.Id == request_endpoint
and not x.sleep,
endpoints,
else:
server_url = request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request
)
)

if not endpoints:
return JSONResponse(
status_code=400,
content={
"error": f"Model {requested_model} not found or vLLM engine is sleeping."
},
curr_time = time.time()
# Extract actual session ID from request headers for logging
session_key = (
getattr(request.app.state.router, "session_key", None)
if hasattr(request.app.state.router, "session_key")
else None
)

logger.debug(f"Routing request {request_id} for model: {requested_model}")
if request_endpoint:
server_url = endpoints[0].url
logger.debug(
f"Routing request {request_id} to engine with Id: {endpoints[0].Id}"
session_id = (
request.headers.get(session_key, None) if session_key is not None else None
)
session_id_display = session_id if session_id is not None else "None"

elif isinstance(request.app.state.router, KvawareRouter) or isinstance(
request.app.state.router, PrefixAwareRouter
):
server_url = await request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request, request_json
# Debug logging to help troubleshoot session ID extraction
logger.debug(
f"Debug session extraction - Router type: {type(request.app.state.router).__name__}"
)
else:
server_url = request.app.state.router.route_request(
endpoints, engine_stats, request_stats, request
logger.debug(f"Debug session extraction - Session key config: {session_key}")
logger.debug(
f"Debug session extraction - Request headers: {dict(request.headers)}"
)
logger.debug(f"Debug session extraction - Extracted session ID: {session_id}")

curr_time = time.time()
# Extract actual session ID from request headers for logging
session_key = (
getattr(request.app.state.router, "session_key", None)
if hasattr(request.app.state.router, "session_key")
else None
)
session_id = (
request.headers.get(session_key, None) if session_key is not None else None
)
session_id_display = session_id if session_id is not None else "None"

# Debug logging to help troubleshoot session ID extraction
logger.debug(
f"Debug session extraction - Router type: {type(request.app.state.router).__name__}"
)
logger.debug(f"Debug session extraction - Session key config: {session_key}")
logger.debug(f"Debug session extraction - Request headers: {dict(request.headers)}")
logger.debug(f"Debug session extraction - Extracted session ID: {session_id}")
logger.info(
f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
)
error = None
try:
stream_generator = process_request(
request,
request_body,
server_url,
request_id,
endpoint,
background_tasks,
)
headers, status = await anext(stream_generator)
headers_dict = {key: value for key, value in headers.items()}
headers_dict["X-Request-Id"] = request_id
# Break out of the loop when the request's stream is fully generated
break
except Exception as e:
error_urls.add(server_url)
error = e

logger.info(
f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
)
stream_generator = process_request(
request,
request_body,
server_url,
request_id,
endpoint,
background_tasks,
)
headers, status = await anext(stream_generator)
headers_dict = {key: value for key, value in headers.items()}
headers_dict["X-Request-Id"] = request_id
if error:
raise error
return StreamingResponse(
stream_generator,
status_code=status,
Expand Down