diff --git a/configs/config.yaml b/configs/config.yaml
index b105ac75..613f537e 100644
--- a/configs/config.yaml
+++ b/configs/config.yaml
@@ -25,6 +25,7 @@ vllm:
   max_retries: 3   # Number of retries for API calls
   retry_delay: 1.0 # Initial delay between retries (seconds)
   sleep_time: 0.1  # Small delay in seconds between batches to avoid rate limits
+  http_request_timeout: 180  # HTTP request timeout in seconds (3 minutes)
 
 # API endpoint configuration
 api-endpoint:
diff --git a/synthetic_data_kit/models/llm_client.py b/synthetic_data_kit/models/llm_client.py
index 4f964100..dd09dcdd 100644
--- a/synthetic_data_kit/models/llm_client.py
+++ b/synthetic_data_kit/models/llm_client.py
@@ -11,6 +11,7 @@
 import os
 import logging
 import asyncio
+import aiohttp
 from pathlib import Path
 
 from synthetic_data_kit.utils.config import load_config, get_vllm_config, get_openai_config, get_llm_provider
@@ -36,7 +37,8 @@ def __init__(self,
                  api_key: Optional[str] = None,
                  model_name: Optional[str] = None,
                  max_retries: Optional[int] = None,
-                 retry_delay: Optional[float] = None):
+                 retry_delay: Optional[float] = None,
+                 http_request_timeout: Optional[int] = None):
         """Initialize an LLM client that supports multiple providers
 
         Args:
@@ -92,6 +94,7 @@ def __init__(self,
             self.max_retries = max_retries or vllm_config.get('max_retries')
             self.retry_delay = retry_delay or vllm_config.get('retry_delay')
             self.sleep_time = vllm_config.get('sleep_time', 0.1)
+            self.http_request_timeout = http_request_timeout or vllm_config.get('http_request_timeout', 180)
 
             # No client to initialize for vLLM as we use requests directly
             # Verify server is running
@@ -304,7 +307,7 @@ def _vllm_chat_completion(self,
                 f"{self.api_base}/chat/completions",
                 headers={"Content-Type": "application/json"},
                 data=json.dumps(data),
-                timeout=180 # Increased timeout to 180 seconds
+                timeout=self.http_request_timeout # Configurable HTTP request timeout
             )
 
             if verbose:
@@ -500,12 +503,6 @@ def _openai_batch_completion(self,
             if verbose:
                 logger.info(f"Processing batch {i//batch_size + 1}/{(len(message_batches) + batch_size - 1) // batch_size} with {len(batch_chunk)} requests")
 
-            # Import asyncio here to avoid issues if not available
-            try:
-                import asyncio
-            except ImportError:
-                raise ImportError("The 'asyncio' package is required for batch processing. Please ensure you're using Python 3.7+.")
-
             # Define async batch processing function
             async def process_batch():
                 tasks = []
@@ -534,13 +531,13 @@ async def process_batch():
         return results
 
     def _vllm_batch_completion(self,
-                             message_batches: List[List[Dict[str, str]]],
-                             temperature: float,
-                             max_tokens: int,
-                             top_p: float,
-                             batch_size: int,
-                             verbose: bool) -> List[str]:
-        """Process multiple message sets in batches using vLLM's API"""
+                               message_batches: List[List[Dict[str, str]]],
+                               temperature: float,
+                               max_tokens: int,
+                               top_p: float,
+                               batch_size: int,
+                               verbose: bool) -> List[str]:
+        """Process multiple message sets in true batches using vLLM's API with concurrent requests"""
         results = []
 
         # Process message batches in chunks to avoid overloading the server
@@ -549,49 +546,124 @@ def _vllm_batch_completion(self,
             if verbose:
                 logger.info(f"Processing batch {i//batch_size + 1}/{(len(message_batches) + batch_size - 1) // batch_size} with {len(batch_chunk)} requests")
 
-            # Create batch request payload for VLLM
-            batch_requests = []
-            for messages in batch_chunk:
-                batch_requests.append({
-                    "model": self.model,
-                    "messages": messages,
-                    "temperature": temperature,
-                    "max_tokens": max_tokens,
-                    "top_p": top_p
-                })
-
-            try:
-                # For now, we run these in parallel with multiple requests
-                batch_results = []
-                for request_data in batch_requests:
-                    # Only print if verbose mode is enabled
-                    if verbose:
-                        logger.info(f"Sending batch request to vLLM model {self.model}...")
-
-                    response = requests.post(
-                        f"{self.api_base}/chat/completions",
-                        headers={"Content-Type": "application/json"},
-                        data=json.dumps(request_data),
-                        timeout=180 # Increased timeout for batch processing
-                    )
-
-                    if verbose:
-                        logger.info(f"Received response with status code: {response.status_code}")
-
-                    response.raise_for_status()
-                    content = response.json()["choices"][0]["message"]["content"]
-                    batch_results.append(content)
-
-                results.extend(batch_results)
-
-            except (requests.exceptions.RequestException, KeyError, IndexError) as e:
-                raise Exception(f"Failed to process vLLM batch: {str(e)}")
+            # Run the async batch processing
+            batch_results = asyncio.run(self._process_vllm_batch_async(
+                batch_chunk, temperature, max_tokens, top_p, verbose, batch_size
+            ))
+            results.extend(batch_results)
 
-            # Small delay between batches
+            # Small delay between batches to avoid rate limits
             if i + batch_size < len(message_batches):
                 time.sleep(self.sleep_time)
 
         return results
+
+    async def _process_vllm_batch_async(self,
+                                        batch_chunk: List[List[Dict[str, str]]],
+                                        temperature: float,
+                                        max_tokens: int,
+                                        top_p: float,
+                                        verbose: bool,
+                                        batch_size: int) -> List[str]:
+        """Process a batch of requests asynchronously using aiohttp"""
+
+        async def process_single_request(session: aiohttp.ClientSession,
+                                         messages: List[Dict[str, str]],
+                                         semaphore: asyncio.Semaphore,
+                                         http_request_timeout: int) -> str:
+            """Process a single request with retry logic"""
+            data = {
+                "model": self.model,
+                "messages": messages,
+                "temperature": temperature,
+                "max_tokens": max_tokens,
+                "top_p": top_p
+            }
+
+            async with semaphore:  # Limit concurrent requests
+                for attempt in range(self.max_retries):
+                    try:
+                        if verbose and attempt == 0:  # Only log on first attempt
+                            logger.info(f"Sending async request to vLLM model {self.model}...")
+
+                        async with session.post(
+                            f"{self.api_base}/chat/completions",
+                            headers={"Content-Type": "application/json"},
+                            data=json.dumps(data),
+                            timeout=aiohttp.ClientTimeout(total=http_request_timeout)  # Configurable per-request timeout
+                        ) as response:
+
+                            if verbose and attempt == 0:
+                                logger.info(f"Received response with status code: {response.status}")
+
+                            response.raise_for_status()
+                            response_json = await response.json()
+
+                            try:
+                                return response_json["choices"][0]["message"]["content"]
+                            except (KeyError, IndexError) as e:
+                                raise ValueError(f"Invalid response format: {e}")
+
+                    except asyncio.TimeoutError:
+                        error_msg = f"Request timeout on attempt {attempt + 1}/{self.max_retries}"
+                        if verbose:
+                            logger.warning(error_msg)
+                        if attempt == self.max_retries - 1:
+                            return f"ERROR: {error_msg}"
+
+                    except aiohttp.ClientError as e:
+                        error_msg = f"HTTP error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
+                        if verbose:
+                            logger.warning(error_msg)
+                        if attempt == self.max_retries - 1:
+                            return f"ERROR: {error_msg}"
+
+                    except Exception as e:
+                        error_msg = f"Unexpected error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
+                        if verbose:
+                            logger.warning(error_msg)
+                        if attempt == self.max_retries - 1:
+                            return f"ERROR: {error_msg}"
+
+                    # Linearly increasing backoff between retries
+                    if attempt < self.max_retries - 1:
+                        await asyncio.sleep(self.retry_delay * (attempt + 1))
+
+        # Create semaphore to limit concurrent connections (prevent overwhelming the server)
+        max_concurrent = min(batch_size, 1024)  # Cap at 1024 concurrent requests
+        semaphore = asyncio.Semaphore(max_concurrent)
+
+        # Create aiohttp session with connection pooling
+        connector = aiohttp.TCPConnector(
+            limit=max_concurrent * 2,       # Total connection pool size
+            limit_per_host=max_concurrent,  # Connections per host
+            ttl_dns_cache=300,              # DNS cache TTL (seconds)
+            use_dns_cache=True,
+        )
+
+        timeout = aiohttp.ClientTimeout(total=300)  # Session-level default; the per-request timeout above takes precedence
+
+        async with aiohttp.ClientSession(
+            connector=connector,
+            timeout=timeout,
+            headers={"Content-Type": "application/json"}
+        ) as session:
+            # Create tasks for all requests in the batch
+            tasks = []
+            for messages in batch_chunk:
+                task = process_single_request(session, messages, semaphore, self.http_request_timeout)
+                tasks.append(task)
+
+            if verbose:
+                logger.info(f"Starting {len(tasks)} concurrent requests...")
+
+            # Execute all requests concurrently
+            results = await asyncio.gather(*tasks, return_exceptions=False)
+
+            if verbose:
+                logger.info(f"Completed {len(results)} requests")
+
+            return results
 
     @classmethod
     def from_config(cls, config_path: Path) -> 'LLMClient':
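
Reviewer note: a minimal usage sketch of the new setting, not part of the patch. It assumes the provider resolves to vllm and that a vLLM server is reachable at the configured api_base; the 60-second override, sampling parameters, and messages are made-up example values, while LLMClient, http_request_timeout, and _vllm_batch_completion are the real names from this diff.

    from synthetic_data_kit.models.llm_client import LLMClient

    # An explicit constructor argument wins; otherwise the value comes from
    # configs/config.yaml (vllm.http_request_timeout), falling back to 180 s.
    client = LLMClient(http_request_timeout=60)  # hypothetical 60 s override
    assert client.http_request_timeout == 60

    # Batches now fan out through aiohttp with up to min(batch_size, 1024)
    # concurrent requests, each carrying aiohttp.ClientTimeout(total=60) here.
    message_batches = [
        [{"role": "user", "content": f"Summarize chunk {i}"}] for i in range(4)
    ]
    replies = client._vllm_batch_completion(
        message_batches,
        temperature=0.7,   # example sampling values
        max_tokens=512,
        top_p=0.95,
        batch_size=4,      # all four requests run concurrently
        verbose=True,
    )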