yusufs committed on
Commit 147b3a2 · 1 Parent(s): b44271e

feat(openai): VLLM OpenAI compatible server

Files changed (4)
  1. Dockerfile +25 -1
  2. openai/README.md +26 -0
  3. openai/__init__.py +0 -0
  4. openai/api_server.py +643 -0
Dockerfile CHANGED
@@ -12,4 +12,28 @@ RUN pip install --no-cache-dir -r requirements.txt --extra-index-url https://dow
 COPY --chown=user . /app
 
 EXPOSE 7860
-CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
+
+#CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
+
+CMD [ \
+    "python3", \
+    "-u", \
+    "openai/api_server.py", \
+    "--model", \
+    "meta-llama/Llama-3.2-3B-Instruct", \
+    "--revision", \
+    "0cb88a4f764b7a12671c53f0838cd831a0843b95", \
+    "--host", \
+    "0.0.0.0", \
+    "--port", \
+    "7860", \
+    "--max-num-batched-tokens", \
+    "32768", \
+    "--max-model-len", \
+    "32768", \
+    "--dtype", \
+    "half", \
+    "--enforce-eager", \
+    "--gpu-memory-utilization", \
+    "0.85" \
+]
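
Once this image is built and running, the container exposes the vLLM OpenAI-compatible endpoints on port 7860. A minimal smoke test, assuming the container is reachable on `localhost:7860` and the `requests` package is installed locally:

```python
import requests

BASE_URL = "http://localhost:7860"  # adjust if the container port is mapped elsewhere

# /health returns 200 once the engine has finished loading the model weights.
print("health:", requests.get(f"{BASE_URL}/health", timeout=30).status_code)

# /version reports the vLLM version the server was built against.
print("version:", requests.get(f"{BASE_URL}/version", timeout=30).json())

# /v1/models lists the served model name(s), here meta-llama/Llama-3.2-3B-Instruct.
print("models:", requests.get(f"{BASE_URL}/v1/models", timeout=30).json())
```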
openai/README.md ADDED
@@ -0,0 +1,26 @@
+# vLLM OpenAI-Compatible API Server
+
+> References: https://huggingface.co/spaces/sofianhw/ai/tree/c6527a750644a849b6705bb6fe2fcea4e54a8196
+
+This `api_server.py` file is an exact copy of https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/entrypoints/openai/api_server.py
+
+* The `HUGGING_FACE_HUB_TOKEN` environment variable must be set at runtime.
+
+## Configuration documentation
+
+* https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/utils.py#L1207-L1221
+
+```shell
+"serve,chat,complete",
+"facebook/opt-12B",
+'--config', 'config.yaml',
+'-tp', '2'
+```
+
+The YAML config is equivalent to the argument flags. For better documentation, consider passing the flags defined here instead:
+https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/entrypoints/openai/cli_args.py#L77-L237
+
+Other arguments are the same as for the `LLM` class, such as `--max-model-len`, `--dtype`, or `--otlp-traces-endpoint`:
+* https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/config.py#L1061-L1086
+* https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/engine/arg_utils.py#L221-L913
+
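
Since the server speaks the OpenAI API, the official `openai` Python client can talk to it directly. A short sketch, assuming the server from the Dockerfile above is listening on `localhost:7860` and serving `meta-llama/Llama-3.2-3B-Instruct` (adjust the base URL and model name to your deployment):

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:7860/v1",
    api_key="EMPTY",  # only validated when the server is started with --api-key / VLLM_API_KEY
)

completion = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=64,
)
print(completion.choices[0].message.content)
```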
openai/__init__.py ADDED
(empty file)
openai/api_server.py ADDED
@@ -0,0 +1,643 @@
+import asyncio
+import importlib
+import inspect
+import multiprocessing
+import os
+import re
+import signal
+import socket
+import tempfile
+import uuid
+from argparse import Namespace
+from contextlib import asynccontextmanager
+from functools import partial
+from http import HTTPStatus
+from typing import AsyncIterator, Optional, Set, Tuple
+
+import uvloop
+from fastapi import APIRouter, FastAPI, Request
+from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse, Response, StreamingResponse
+from starlette.datastructures import State
+from starlette.routing import Mount
+from typing_extensions import assert_never
+
+import vllm.envs as envs
+from vllm.config import ModelConfig
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.multiprocessing.client import MQLLMEngineClient
+from vllm.engine.multiprocessing.engine import run_mp_engine
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.launcher import serve_http
+from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.cli_args import (make_arg_parser,
+                                              validate_parsed_serve_args)
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              ChatCompletionResponse,
+                                              CompletionRequest,
+                                              CompletionResponse,
+                                              DetokenizeRequest,
+                                              DetokenizeResponse,
+                                              EmbeddingRequest,
+                                              EmbeddingResponse, ErrorResponse,
+                                              LoadLoraAdapterRequest,
+                                              TokenizeRequest,
+                                              TokenizeResponse,
+                                              UnloadLoraAdapterRequest)
+# yapf: enable
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.entrypoints.openai.serving_tokenization import (
+    OpenAIServingTokenization)
+from vllm.entrypoints.openai.tool_parsers import ToolParserManager
+from vllm.logger import init_logger
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
+                        is_valid_ipv6_address)
+from vllm.version import __version__ as VLLM_VERSION
+
+if envs.VLLM_USE_V1:
+    from vllm.v1.engine.async_llm import AsyncLLMEngine  # type: ignore
+else:
+    from vllm.engine.async_llm_engine import AsyncLLMEngine  # type: ignore
+
+TIMEOUT_KEEP_ALIVE = 5  # seconds
+
+prometheus_multiproc_dir: tempfile.TemporaryDirectory
+
+# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
+logger = init_logger('vllm.entrypoints.openai.api_server')
+
+_running_tasks: Set[asyncio.Task] = set()
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    try:
+        if app.state.log_stats:
+            engine_client: EngineClient = app.state.engine_client
+
+            async def _force_log():
+                while True:
+                    await asyncio.sleep(10.)
+                    await engine_client.do_log_stats()
+
+            task = asyncio.create_task(_force_log())
+            _running_tasks.add(task)
+            task.add_done_callback(_running_tasks.remove)
+        else:
+            task = None
+        try:
+            yield
+        finally:
+            if task is not None:
+                task.cancel()
+    finally:
+        # Ensure app state including engine ref is gc'd
+        del app.state
+
+
+@asynccontextmanager
+async def build_async_engine_client(
+        args: Namespace) -> AsyncIterator[EngineClient]:
+
+    # Context manager to handle engine_client lifecycle
+    # Ensures everything is shutdown and cleaned up on error/exit
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+
+    async with build_async_engine_client_from_engine_args(
+            engine_args, args.disable_frontend_multiprocessing) as engine:
+        yield engine
+
+
+@asynccontextmanager
+async def build_async_engine_client_from_engine_args(
+    engine_args: AsyncEngineArgs,
+    disable_frontend_multiprocessing: bool = False,
+) -> AsyncIterator[EngineClient]:
+    """
+    Create EngineClient, either:
+        - in-process using the AsyncLLMEngine Directly
+        - multiprocess using AsyncLLMEngine RPC
+
+    Returns the Client or None if the creation failed.
+    """
+
+    # Fall back
+    # TODO: fill out feature matrix.
+    if (MQLLMEngineClient.is_unsupported_config(engine_args)
+            or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
+
+        engine_config = engine_args.create_engine_config()
+        uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
+                           "uses_ray", False)
+
+        build_engine = partial(AsyncLLMEngine.from_engine_args,
+                               engine_args=engine_args,
+                               engine_config=engine_config,
+                               usage_context=UsageContext.OPENAI_API_SERVER)
+        if uses_ray:
+            # Must run in main thread with ray for its signal handlers to work
+            engine_client = build_engine()
+        else:
+            engine_client = await asyncio.get_running_loop().run_in_executor(
+                None, build_engine)
+
+        yield engine_client
+        if hasattr(engine_client, "shutdown"):
+            engine_client.shutdown()
+        return
+
+    # Otherwise, use the multiprocessing AsyncLLMEngine.
+    else:
+        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
+            # Make TemporaryDirectory for prometheus multiprocessing
+            # Note: global TemporaryDirectory will be automatically
+            #   cleaned up upon exit.
+            global prometheus_multiproc_dir
+            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+            os.environ[
+                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+        else:
+            logger.warning(
+                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
+                "This directory must be wiped between vLLM runs or "
+                "you will find inaccurate metrics. Unset the variable "
+                "and vLLM will properly handle cleanup.")
+
+        # Select random path for IPC.
+        ipc_path = get_open_zmq_ipc_path()
+        logger.info("Multiprocessing frontend to use %s for IPC Path.",
+                    ipc_path)
+
+        # Start RPCServer in separate process (holds the LLMEngine).
+        # the current process might have CUDA context,
+        # so we need to spawn a new process
+        context = multiprocessing.get_context("spawn")
+
+        # The Process can raise an exception during startup, which may
+        # not actually result in an exitcode being reported. As a result
+        # we use a shared variable to communicate the information.
+        engine_alive = multiprocessing.Value('b', True, lock=False)
+        engine_process = context.Process(target=run_mp_engine,
+                                         args=(engine_args,
+                                               UsageContext.OPENAI_API_SERVER,
+                                               ipc_path, engine_alive))
+        engine_process.start()
+        engine_pid = engine_process.pid
+        assert engine_pid is not None, "Engine process failed to start."
+        logger.info("Started engine process with PID %d", engine_pid)
+
+        # Build RPCClient, which conforms to EngineClient Protocol.
+        engine_config = engine_args.create_engine_config()
+        build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
+                               engine_pid)
+        mq_engine_client = await asyncio.get_running_loop().run_in_executor(
+            None, build_client)
+        try:
+            while True:
+                try:
+                    await mq_engine_client.setup()
+                    break
+                except TimeoutError:
+                    if (not engine_process.is_alive()
+                            or not engine_alive.value):
+                        raise RuntimeError(
+                            "Engine process failed to start. See stack "
+                            "trace for the root cause.") from None
+
+            yield mq_engine_client  # type: ignore[misc]
+        finally:
+            # Ensure rpc server process was terminated
+            engine_process.terminate()
+
+            # Close all open connections to the backend
+            mq_engine_client.close()
+
+            # Wait for engine process to join
+            engine_process.join(4)
+            if engine_process.exitcode is None:
+                # Kill if taking longer than 5 seconds to stop
+                engine_process.kill()
+
+            # Lazy import for prometheus multiprocessing.
+            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
+            # before prometheus_client is imported.
+            # See https://prometheus.github.io/client_python/multiprocess/
+            from prometheus_client import multiprocess
+            multiprocess.mark_process_dead(engine_process.pid)
+
+
+router = APIRouter()
+
+
+def mount_metrics(app: FastAPI):
+    # Lazy import for prometheus multiprocessing.
+    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
+    # before prometheus_client is imported.
+    # See https://prometheus.github.io/client_python/multiprocess/
+    from prometheus_client import (CollectorRegistry, make_asgi_app,
+                                   multiprocess)
+
+    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
+    if prometheus_multiproc_dir_path is not None:
+        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
+                    prometheus_multiproc_dir_path)
+        registry = CollectorRegistry()
+        multiprocess.MultiProcessCollector(registry)
+
+        # Add prometheus asgi middleware to route /metrics requests
+        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
+    else:
+        # Add prometheus asgi middleware to route /metrics requests
+        metrics_route = Mount("/metrics", make_asgi_app())
+
+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
+    app.routes.append(metrics_route)
+
+
+def base(request: Request) -> OpenAIServing:
+    # Reuse the existing instance
+    return tokenization(request)
+
+
+def chat(request: Request) -> Optional[OpenAIServingChat]:
+    return request.app.state.openai_serving_chat
+
+
+def completion(request: Request) -> Optional[OpenAIServingCompletion]:
+    return request.app.state.openai_serving_completion
+
+
+def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
+    return request.app.state.openai_serving_embedding
+
+
+def tokenization(request: Request) -> OpenAIServingTokenization:
+    return request.app.state.openai_serving_tokenization
+
+
+def engine_client(request: Request) -> EngineClient:
+    return request.app.state.engine_client
+
+
+@router.get("/health")
+async def health(raw_request: Request) -> Response:
+    """Health check."""
+    await engine_client(raw_request).check_health()
+    return Response(status_code=200)
+
+
+@router.post("/tokenize")
+async def tokenize(request: TokenizeRequest, raw_request: Request):
+    handler = tokenization(raw_request)
+
+    generator = await handler.create_tokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    elif isinstance(generator, TokenizeResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    assert_never(generator)
+
+
+@router.post("/detokenize")
+async def detokenize(request: DetokenizeRequest, raw_request: Request):
+    handler = tokenization(raw_request)
+
+    generator = await handler.create_detokenize(request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    elif isinstance(generator, DetokenizeResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    assert_never(generator)
+
+
+@router.get("/v1/models")
+async def show_available_models(raw_request: Request):
+    handler = base(raw_request)
+
+    models = await handler.show_available_models()
+    return JSONResponse(content=models.model_dump())
+
+
+@router.get("/version")
+async def show_version():
+    ver = {"version": VLLM_VERSION}
+    return JSONResponse(content=ver)
+
+
+@router.post("/v1/chat/completions")
+async def create_chat_completion(request: ChatCompletionRequest,
+                                 raw_request: Request):
+    handler = chat(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Chat Completions API")
+
+    generator = await handler.create_chat_completion(request, raw_request)
+
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+
+    elif isinstance(generator, ChatCompletionResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    return StreamingResponse(content=generator, media_type="text/event-stream")
+
+
+@router.post("/v1/completions")
+async def create_completion(request: CompletionRequest, raw_request: Request):
+    handler = completion(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Completions API")
+
+    generator = await handler.create_completion(request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    elif isinstance(generator, CompletionResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    return StreamingResponse(content=generator, media_type="text/event-stream")
+
+
+@router.post("/v1/embeddings")
+async def create_embedding(request: EmbeddingRequest, raw_request: Request):
+    handler = embedding(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Embeddings API")
+
+    generator = await handler.create_embedding(request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    elif isinstance(generator, EmbeddingResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    assert_never(generator)
+
+
+if envs.VLLM_TORCH_PROFILER_DIR:
+    logger.warning(
+        "Torch Profiler is enabled in the API server. This should ONLY be "
+        "used for local development!")
+
+    @router.post("/start_profile")
+    async def start_profile(raw_request: Request):
+        logger.info("Starting profiler...")
+        await engine_client(raw_request).start_profile()
+        logger.info("Profiler started.")
+        return Response(status_code=200)
+
+    @router.post("/stop_profile")
+    async def stop_profile(raw_request: Request):
+        logger.info("Stopping profiler...")
+        await engine_client(raw_request).stop_profile()
+        logger.info("Profiler stopped.")
+        return Response(status_code=200)
+
+
+if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
+    logger.warning(
+        "Lora dynamic loading & unloading is enabled in the API server. "
+        "This should ONLY be used for local development!")
+
+    @router.post("/v1/load_lora_adapter")
+    async def load_lora_adapter(request: LoadLoraAdapterRequest,
+                                raw_request: Request):
+        for route in [chat, completion, embedding]:
+            handler = route(raw_request)
+            if handler is not None:
+                response = await handler.load_lora_adapter(request)
+                if isinstance(response, ErrorResponse):
+                    return JSONResponse(content=response.model_dump(),
+                                        status_code=response.code)
+
+        return Response(status_code=200, content=response)
+
+    @router.post("/v1/unload_lora_adapter")
+    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
+                                  raw_request: Request):
+        for route in [chat, completion, embedding]:
+            handler = route(raw_request)
+            if handler is not None:
+                response = await handler.unload_lora_adapter(request)
+                if isinstance(response, ErrorResponse):
+                    return JSONResponse(content=response.model_dump(),
+                                        status_code=response.code)
+
+        return Response(status_code=200, content=response)
+
+
+def build_app(args: Namespace) -> FastAPI:
+    if args.disable_fastapi_docs:
+        app = FastAPI(openapi_url=None,
+                      docs_url=None,
+                      redoc_url=None,
+                      lifespan=lifespan)
+    else:
+        app = FastAPI(lifespan=lifespan)
+    app.include_router(router)
+    app.root_path = args.root_path
+
+    mount_metrics(app)
+
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=args.allowed_origins,
+        allow_credentials=args.allow_credentials,
+        allow_methods=args.allowed_methods,
+        allow_headers=args.allowed_headers,
+    )
+
+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(_, exc):
+        chat = app.state.openai_serving_chat
+        err = chat.create_error_response(message=str(exc))
+        return JSONResponse(err.model_dump(),
+                            status_code=HTTPStatus.BAD_REQUEST)
+
+    if token := envs.VLLM_API_KEY or args.api_key:
+
+        @app.middleware("http")
+        async def authentication(request: Request, call_next):
+            root_path = "" if args.root_path is None else args.root_path
+            if request.method == "OPTIONS":
+                return await call_next(request)
+            if not request.url.path.startswith(f"{root_path}/v1"):
+                return await call_next(request)
+            if request.headers.get("Authorization") != "Bearer " + token:
+                return JSONResponse(content={"error": "Unauthorized"},
+                                    status_code=401)
+            return await call_next(request)
+
+    @app.middleware("http")
+    async def add_request_id(request: Request, call_next):
+        request_id = request.headers.get("X-Request-Id") or uuid.uuid4().hex
+        response = await call_next(request)
+        response.headers["X-Request-Id"] = request_id
+        return response
+
+    for middleware in args.middleware:
+        module_path, object_name = middleware.rsplit(".", 1)
+        imported = getattr(importlib.import_module(module_path), object_name)
+        if inspect.isclass(imported):
+            app.add_middleware(imported)
+        elif inspect.iscoroutinefunction(imported):
+            app.middleware("http")(imported)
+        else:
+            raise ValueError(f"Invalid middleware {middleware}. "
+                             f"Must be a function or a class.")
+
+    return app
+
+
+def init_app_state(
+    engine_client: EngineClient,
+    model_config: ModelConfig,
+    state: State,
+    args: Namespace,
+) -> None:
+    if args.served_model_name is not None:
+        served_model_names = args.served_model_name
+    else:
+        served_model_names = [args.model]
+
+    if args.disable_log_requests:
+        request_logger = None
+    else:
+        request_logger = RequestLogger(max_log_len=args.max_log_len)
+
+    base_model_paths = [
+        BaseModelPath(name=name, model_path=args.model)
+        for name in served_model_names
+    ]
+
+    state.engine_client = engine_client
+    state.log_stats = not args.disable_log_stats
+
+    state.openai_serving_chat = OpenAIServingChat(
+        engine_client,
+        model_config,
+        base_model_paths,
+        args.response_role,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+        enable_auto_tools=args.enable_auto_tool_choice,
+        tool_parser=args.tool_call_parser,
+        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+    ) if model_config.task == "generate" else None
+    state.openai_serving_completion = OpenAIServingCompletion(
+        engine_client,
+        model_config,
+        base_model_paths,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+    ) if model_config.task == "generate" else None
+    state.openai_serving_embedding = OpenAIServingEmbedding(
+        engine_client,
+        model_config,
+        base_model_paths,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+    ) if model_config.task == "embedding" else None
+    state.openai_serving_tokenization = OpenAIServingTokenization(
+        engine_client,
+        model_config,
+        base_model_paths,
+        lora_modules=args.lora_modules,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+    )
+
+
+def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
+    family = socket.AF_INET
+    if is_valid_ipv6_address(addr[0]):
+        family = socket.AF_INET6
+
+    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    sock.bind(addr)
+
+    return sock
+
+
+async def run_server(args, **uvicorn_kwargs) -> None:
+    logger.info("vLLM API server version %s", VLLM_VERSION)
+    logger.info("args: %s", args)
+
+    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
+        ToolParserManager.import_tool_parser(args.tool_parser_plugin)
+
+    valide_tool_parses = ToolParserManager.tool_parsers.keys()
+    if args.enable_auto_tool_choice \
+            and args.tool_call_parser not in valide_tool_parses:
+        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
+                       f"(chose from {{ {','.join(valide_tool_parses)} }})")
+
+    # workaround to make sure that we bind the port before the engine is set up.
+    # This avoids race conditions with ray.
+    # see https://github.com/vllm-project/vllm/issues/8204
+    sock_addr = (args.host or "", args.port)
+    sock = create_server_socket(sock_addr)
+
+    def signal_handler(*_) -> None:
+        # Interrupt server on sigterm while initializing
+        raise KeyboardInterrupt("terminated")
+
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    async with build_async_engine_client(args) as engine_client:
+        app = build_app(args)
+
+        model_config = await engine_client.get_model_config()
+        init_app_state(engine_client, model_config, app.state, args)
+
+        shutdown_task = await serve_http(
+            app,
+            host=args.host,
+            port=args.port,
+            log_level=args.uvicorn_log_level,
+            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+            ssl_keyfile=args.ssl_keyfile,
+            ssl_certfile=args.ssl_certfile,
+            ssl_ca_certs=args.ssl_ca_certs,
+            ssl_cert_reqs=args.ssl_cert_reqs,
+            **uvicorn_kwargs,
+        )
+
+    # NB: Await server shutdown only after the backend context is exited
+    await shutdown_task
+
+    sock.close()
+
+
+if __name__ == "__main__":
+    # NOTE(simon):
+    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
+    parser = FlexibleArgumentParser(
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args()
+    validate_parsed_serve_args(args)
+
+    uvloop.run(run_server(args))
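
The `authentication` middleware registered in `build_app()` only guards paths under `/v1`, and only when an API key is configured via `--api-key` or `VLLM_API_KEY`; routes such as `/health` and `/tokenize` stay open. A hedged illustration with `requests`, assuming the server runs on `localhost:7860` with `--api-key secret-key` (the `/tokenize` payload shape follows vLLM's `TokenizeRequest` and may differ between versions):

```python
import requests

BASE_URL = "http://localhost:7860"
AUTH = {"Authorization": "Bearer secret-key"}

# /v1 routes without the Bearer token are rejected by the middleware.
print(requests.get(f"{BASE_URL}/v1/models").status_code)  # expected: 401

# The same request with the configured key passes through.
print(requests.get(f"{BASE_URL}/v1/models", headers=AUTH).status_code)  # expected: 200

# /tokenize sits outside the /v1 prefix, so no token is required.
resp = requests.post(f"{BASE_URL}/tokenize",
                     json={"model": "meta-llama/Llama-3.2-3B-Instruct",
                           "prompt": "Hello world"})
print(resp.json())  # token ids plus token count
```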