|
 #
 
 import functools
-import json
 import os
 import signal
-import subprocess
-import sys
-import time
 from collections.abc import Sequence
-from typing import Any, Callable, Optional
+from typing import Callable
 
-import httpx
-import openai
-import requests
 import torch
 import torch.nn.functional as F
 from typing_extensions import ParamSpec
-from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.entrypoints.cli.serve import ServeSubcommand
-from vllm.model_executor.model_loader import get_model_loader
-from vllm.utils import FlexibleArgumentParser, get_open_port
 
 _P = ParamSpec("_P")
 
@@ -115,152 +104,3 @@ def check_embeddings_close(
                     f"\n{name_1}:\t{embeddings_1[:16]!r}")
 
         assert sim >= 1 - tol, fail_msg
-
-
-class RemoteOpenAIServer:
-    DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key
-
-    def _start_server(self, model: str, vllm_serve_args: list[str],
-                      env_dict: Optional[dict[str, str]]) -> None:
-        """Subclasses override this method to customize server process launch
-        """
-        env = os.environ.copy()
-        # the current process might initialize npu,
-        # to be safe, we should use spawn method
-        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
-        if env_dict is not None:
-            env.update(env_dict)
-        self.proc: subprocess.Popen = subprocess.Popen(
-            ["vllm", "serve", model, *vllm_serve_args],
-            env=env,
-            stdout=sys.stdout,
-            stderr=sys.stderr,
-        )
-
-    def __init__(self,
-                 model: str,
-                 vllm_serve_args: list[str],
-                 *,
-                 env_dict: Optional[dict[str, str]] = None,
-                 seed: Optional[int] = 0,
-                 auto_port: bool = True,
-                 max_wait_seconds: Optional[float] = None,
-                 override_hf_configs: Optional[dict[str, Any]] = None) -> None:
-        if auto_port:
-            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
-                raise ValueError("You have manually specified the port "
-                                 "when `auto_port=True`.")
-
-            # No need for a port if using unix sockets
-            if "--uds" not in vllm_serve_args:
-                # Don't mutate the input args
-                vllm_serve_args = vllm_serve_args + [
-                    "--port", str(get_open_port())
-                ]
-        if seed is not None:
-            if "--seed" in vllm_serve_args:
-                raise ValueError("You have manually specified the seed "
-                                 f"when `seed={seed}`.")
-
-            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
-
-        if override_hf_configs is not None:
-            vllm_serve_args = vllm_serve_args + [
-                "--hf-overrides",
-                json.dumps(override_hf_configs)
-            ]
-
-        parser = FlexibleArgumentParser(
-            description="vLLM's remote OpenAI server.")
-        subparsers = parser.add_subparsers(required=False, dest="subparser")
-        parser = ServeSubcommand().subparser_init(subparsers)
-        args = parser.parse_args(["--model", model, *vllm_serve_args])
-        self.uds = args.uds
-        if args.uds:
-            self.host = None
-            self.port = None
-        else:
-            self.host = str(args.host or 'localhost')
-            self.port = int(args.port)
-
-        self.show_hidden_metrics = \
-            args.show_hidden_metrics_for_version is not None
-
-        # download the model before starting the server to avoid timeout
-        is_local = os.path.isdir(model)
-        if not is_local:
-            engine_args = AsyncEngineArgs.from_cli_args(args)
-            model_config = engine_args.create_model_config()
-            load_config = engine_args.create_load_config()
-
-            model_loader = get_model_loader(load_config)
-            model_loader.download_model(model_config)
-
-        self._start_server(model, vllm_serve_args, env_dict)
-        max_wait_seconds = max_wait_seconds or 240
-        self._wait_for_server(url=self.url_for("health"),
-                              timeout=max_wait_seconds)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        self.proc.terminate()
-        try:
-            self.proc.wait(8)
-        except subprocess.TimeoutExpired:
-            # force kill if needed
-            self.proc.kill()
-
-    def _poll(self) -> Optional[int]:
-        """Subclasses override this method to customize process polling"""
-        return self.proc.poll()
-
-    def _wait_for_server(self, *, url: str, timeout: float):
-        # run health check
-        start = time.time()
-        client = (httpx.Client(transport=httpx.HTTPTransport(
-            uds=self.uds)) if self.uds else requests)
-        while True:
-            try:
-                if client.get(url).status_code == 200:
-                    break
-            except Exception:
-                # this exception can only be raised by requests.get,
-                # which means the server is not ready yet.
-                # the stack trace is not useful, so we suppress it
-                # by using `raise from None`.
-                result = self._poll()
-                if result is not None and result != 0:
-                    raise RuntimeError("Server exited unexpectedly.") from None
-
-                time.sleep(0.5)
-                if time.time() - start > timeout:
-                    raise RuntimeError(
-                        "Server failed to start in time.") from None
-
-    @property
-    def url_root(self) -> str:
-        return (f"http://{self.uds.split('/')[-1]}"
-                if self.uds else f"http://{self.host}:{self.port}")
-
-    def url_for(self, *parts: str) -> str:
-        return self.url_root + "/" + "/".join(parts)
-
-    def get_client(self, **kwargs):
-        if "timeout" not in kwargs:
-            kwargs["timeout"] = 600
-        return openai.OpenAI(
-            base_url=self.url_for("v1"),
-            api_key=self.DUMMY_API_KEY,
-            max_retries=0,
-            **kwargs,
-        )
-
-    def get_async_client(self, **kwargs):
-        if "timeout" not in kwargs:
-            kwargs["timeout"] = 600
-        return openai.AsyncOpenAI(base_url=self.url_for("v1"),
-                                  api_key=self.DUMMY_API_KEY,
-                                  max_retries=0,
-                                  **kwargs)