feat(openai): VLLM OpenAI compatible server
- Dockerfile +25 -1
- openai/README.md +26 -0
- openai/__init__.py +0 -0
- openai/api_server.py +643 -0
Dockerfile
CHANGED
@@ -12,4 +12,28 @@ RUN pip install --no-cache-dir -r requirements.txt --extra-index-url https://dow
 COPY --chown=user . /app
 
 EXPOSE 7860
-
+
+#CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
+
+CMD [
+  "python3",
+  "-u",
+  "openai/api_server.py",
+  "--model",
+  "meta-llama/Llama-3.2-3B-Instruct",
+  "--revision",
+  "0cb88a4f764b7a12671c53f0838cd831a0843b95",
+  "--host",
+  "0.0.0.0",
+  "--port",
+  "7860",
+  "--max-num-batched-tokens",
+  "32768",
+  "--max-model-len",
+  "32768",
+  "--dtype",
+  "half",
+  "--enforce-eager",
+  "--gpu-memory-utilization",
+  "0.85"
+]
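For reference, a minimal client sketch (not part of this commit) against the server started by this CMD. It assumes the Space is running and reachable at `http://localhost:7860` with no `--api-key` configured, and uses the `/v1/models` and `/v1/chat/completions` routes defined in `openai/api_server.py` below.

```python
import requests

BASE_URL = "http://localhost:7860"  # host/port taken from the CMD above

# List the served models (the served name defaults to the --model value).
print(requests.get(f"{BASE_URL}/v1/models").json())

# Send a request to the OpenAI-compatible chat completions route.
resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "meta-llama/Llama-3.2-3B-Instruct",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 64,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```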
openai/README.md
ADDED
@@ -0,0 +1,26 @@
# VLLM OpenAI Compatible API Server

> References: https://huggingface.co/spaces/sofianhw/ai/tree/c6527a750644a849b6705bb6fe2fcea4e54a8196

This `api_server.py` file is an exact copy of https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/entrypoints/openai/api_server.py

* The `HUGGING_FACE_HUB_TOKEN` must be set at runtime.

## Documentation about config

* https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/utils.py#L1207-L1221

```shell
"serve,chat,complete",
"facebook/opt-12B",
'--config', 'config.yaml',
'-tp', '2'
```

The YAML config is equivalent to the argument flag params. For better documentation, consider passing the flag params defined here:
https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/entrypoints/openai/cli_args.py#L77-L237

Other arguments are the same as for the LLM class, such as `--max-model-len`, `--dtype`, or `--otlp-traces-endpoint`:
* https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/config.py#L1061-L1086
* https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/engine/arg_utils.py#L221-L913
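The flag/config equivalence can be checked with the same parser the server itself uses. A minimal sketch (not part of this commit), assuming `vllm==0.6.4.post1` is installed locally; flag names follow the `cli_args.py` reference linked above.

```python
# Parse server flags with the same parser api_server.py uses, then map them
# to engine arguments the way the OpenAI entrypoint does.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser

parser = make_arg_parser(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "meta-llama/Llama-3.2-3B-Instruct",
    "--max-model-len", "32768",
    "--dtype", "half",
])

# The OpenAI entrypoint builds its engine from exactly these parsed args.
engine_args = AsyncEngineArgs.from_cli_args(args)
print(engine_args.model, engine_args.max_model_len, engine_args.dtype)
```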
openai/__init__.py
ADDED
File without changes
openai/api_server.py
ADDED
@@ -0,0 +1,643 @@
import asyncio
import importlib
import inspect
import multiprocessing
import os
import re
import signal
import socket
import tempfile
import uuid
from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Optional, Set, Tuple

import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.datastructures import State
from starlette.routing import Mount
from typing_extensions import assert_never

import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              CompletionRequest,
                                              CompletionResponse,
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              TokenizeRequest,
                                              TokenizeResponse,
                                              UnloadLoraAdapterRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
                        is_valid_ipv6_address)
from vllm.version import __version__ as VLLM_VERSION

if envs.VLLM_USE_V1:
    from vllm.v1.engine.async_llm import AsyncLLMEngine  # type: ignore
else:
    from vllm.engine.async_llm_engine import AsyncLLMEngine  # type: ignore

TIMEOUT_KEEP_ALIVE = 5  # seconds

prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')

_running_tasks: Set[asyncio.Task] = set()


@asynccontextmanager
async def lifespan(app: FastAPI):
    try:
        if app.state.log_stats:
            engine_client: EngineClient = app.state.engine_client

            async def _force_log():
                while True:
                    await asyncio.sleep(10.)
                    await engine_client.do_log_stats()

            task = asyncio.create_task(_force_log())
            _running_tasks.add(task)
            task.add_done_callback(_running_tasks.remove)
        else:
            task = None
        try:
            yield
        finally:
            if task is not None:
                task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state


@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[EngineClient]:

    # Context manager to handle engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    engine_args = AsyncEngineArgs.from_cli_args(args)

    async with build_async_engine_client_from_engine_args(
            engine_args, args.disable_frontend_multiprocessing) as engine:
        yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[EngineClient]:
    """
    Create EngineClient, either:
        - in-process using the AsyncLLMEngine Directly
        - multiprocess using AsyncLLMEngine RPC

    Returns the Client or None if the creation failed.
    """

    # Fall back
    # TODO: fill out feature matrix.
    if (MQLLMEngineClient.is_unsupported_config(engine_args)
            or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):

        engine_config = engine_args.create_engine_config()
        uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
                           "uses_ray", False)

        build_engine = partial(AsyncLLMEngine.from_engine_args,
                               engine_args=engine_args,
                               engine_config=engine_config,
                               usage_context=UsageContext.OPENAI_API_SERVER)
        if uses_ray:
            # Must run in main thread with ray for its signal handlers to work
            engine_client = build_engine()
        else:
            engine_client = await asyncio.get_running_loop().run_in_executor(
                None, build_engine)

        yield engine_client
        if hasattr(engine_client, "shutdown"):
            engine_client.shutdown()
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    else:
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
            # Make TemporaryDirectory for prometheus multiprocessing
            # Note: global TemporaryDirectory will be automatically
            #   cleaned up upon exit.
            global prometheus_multiproc_dir
            prometheus_multiproc_dir = tempfile.TemporaryDirectory()
            os.environ[
                "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
        else:
            logger.warning(
                "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
                "This directory must be wiped between vLLM runs or "
                "you will find inaccurate metrics. Unset the variable "
                "and vLLM will properly handle cleanup.")

        # Select random path for IPC.
        ipc_path = get_open_zmq_ipc_path()
        logger.info("Multiprocessing frontend to use %s for IPC Path.",
                    ipc_path)

        # Start RPCServer in separate process (holds the LLMEngine).
        # the current process might have CUDA context,
        # so we need to spawn a new process
        context = multiprocessing.get_context("spawn")

        # The Process can raise an exception during startup, which may
        # not actually result in an exitcode being reported. As a result
        # we use a shared variable to communicate the information.
        engine_alive = multiprocessing.Value('b', True, lock=False)
        engine_process = context.Process(target=run_mp_engine,
                                         args=(engine_args,
                                               UsageContext.OPENAI_API_SERVER,
                                               ipc_path, engine_alive))
        engine_process.start()
        engine_pid = engine_process.pid
        assert engine_pid is not None, "Engine process failed to start."
        logger.info("Started engine process with PID %d", engine_pid)

        # Build RPCClient, which conforms to EngineClient Protocol.
        engine_config = engine_args.create_engine_config()
        build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
                               engine_pid)
        mq_engine_client = await asyncio.get_running_loop().run_in_executor(
            None, build_client)
        try:
            while True:
                try:
                    await mq_engine_client.setup()
                    break
                except TimeoutError:
                    if (not engine_process.is_alive()
                            or not engine_alive.value):
                        raise RuntimeError(
                            "Engine process failed to start. See stack "
                            "trace for the root cause.") from None

            yield mq_engine_client  # type: ignore[misc]
        finally:
            # Ensure rpc server process was terminated
            engine_process.terminate()

            # Close all open connections to the backend
            mq_engine_client.close()

            # Wait for engine process to join
            engine_process.join(4)
            if engine_process.exitcode is None:
                # Kill if taking longer than 5 seconds to stop
                engine_process.kill()

            # Lazy import for prometheus multiprocessing.
            # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
            # before prometheus_client is imported.
            # See https://prometheus.github.io/client_python/multiprocess/
            from prometheus_client import multiprocess
            multiprocess.mark_process_dead(engine_process.pid)


router = APIRouter()


def mount_metrics(app: FastAPI):
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    prometheus_multiproc_dir_path)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
    else:
        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)


def base(request: Request) -> OpenAIServing:
    # Reuse the existing instance
    return tokenization(request)


def chat(request: Request) -> Optional[OpenAIServingChat]:
    return request.app.state.openai_serving_chat


def completion(request: Request) -> Optional[OpenAIServingCompletion]:
    return request.app.state.openai_serving_completion


def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
    return request.app.state.openai_serving_embedding


def tokenization(request: Request) -> OpenAIServingTokenization:
    return request.app.state.openai_serving_tokenization


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


@router.get("/health")
async def health(raw_request: Request) -> Response:
    """Health check."""
    await engine_client(raw_request).check_health()
    return Response(status_code=200)


@router.post("/tokenize")
async def tokenize(request: TokenizeRequest, raw_request: Request):
    handler = tokenization(raw_request)

    generator = await handler.create_tokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, TokenizeResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    handler = tokenization(raw_request)

    generator = await handler.create_detokenize(request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, DetokenizeResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    handler = base(raw_request)

    models = await handler.show_available_models()
    return JSONResponse(content=models.model_dump())


@router.get("/version")
async def show_version():
    ver = {"version": VLLM_VERSION}
    return JSONResponse(content=ver)


@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    handler = chat(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Chat Completions API")

    generator = await handler.create_chat_completion(request, raw_request)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)

    elif isinstance(generator, ChatCompletionResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    handler = completion(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Completions API")

    generator = await handler.create_completion(request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, CompletionResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    handler = embedding(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Embeddings API")

    generator = await handler.create_embedding(request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, EmbeddingResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!")

    @router.post("/start_profile")
    async def start_profile(raw_request: Request):
        logger.info("Starting profiler...")
        await engine_client(raw_request).start_profile()
        logger.info("Profiler started.")
        return Response(status_code=200)

    @router.post("/stop_profile")
    async def stop_profile(raw_request: Request):
        logger.info("Stopping profiler...")
        await engine_client(raw_request).stop_profile()
        logger.info("Profiler stopped.")
        return Response(status_code=200)


if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    logger.warning(
        "Lora dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!")

    @router.post("/v1/load_lora_adapter")
    async def load_lora_adapter(request: LoadLoraAdapterRequest,
                                raw_request: Request):
        for route in [chat, completion, embedding]:
            handler = route(raw_request)
            if handler is not None:
                response = await handler.load_lora_adapter(request)
                if isinstance(response, ErrorResponse):
                    return JSONResponse(content=response.model_dump(),
                                        status_code=response.code)

        return Response(status_code=200, content=response)

    @router.post("/v1/unload_lora_adapter")
    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
                                  raw_request: Request):
        for route in [chat, completion, embedding]:
            handler = route(raw_request)
            if handler is not None:
                response = await handler.unload_lora_adapter(request)
                if isinstance(response, ErrorResponse):
                    return JSONResponse(content=response.model_dump(),
                                        status_code=response.code)

        return Response(status_code=200, content=response)


def build_app(args: Namespace) -> FastAPI:
    if args.disable_fastapi_docs:
        app = FastAPI(openapi_url=None,
                      docs_url=None,
                      redoc_url=None,
                      lifespan=lifespan)
    else:
        app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        chat = app.state.openai_serving_chat
        err = chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            root_path = "" if args.root_path is None else args.root_path
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            if request.headers.get("Authorization") != "Bearer " + token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    @app.middleware("http")
    async def add_request_id(request: Request, call_next):
        request_id = request.headers.get("X-Request-Id") or uuid.uuid4().hex
        response = await call_next(request)
        response.headers["X-Request-Id"] = request_id
        return response

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


def init_app_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    state: State,
    args: Namespace,
) -> None:
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats

    state.openai_serving_chat = OpenAIServingChat(
        engine_client,
        model_config,
        base_model_paths,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
    ) if model_config.task == "generate" else None
    state.openai_serving_completion = OpenAIServingCompletion(
        engine_client,
        model_config,
        base_model_paths,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    ) if model_config.task == "generate" else None
    state.openai_serving_embedding = OpenAIServingEmbedding(
        engine_client,
        model_config,
        base_model_paths,
        request_logger=request_logger,
        chat_template=args.chat_template,
    ) if model_config.task == "embedding" else None
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        model_config,
        base_model_paths,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )


def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
    family = socket.AF_INET
    if is_valid_ipv6_address(addr[0]):
        family = socket.AF_INET6

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(addr)

    return sock


async def run_server(args, **uvicorn_kwargs) -> None:
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    valide_tool_parses = ToolParserManager.tool_parsers.keys()
    if args.enable_auto_tool_choice \
        and args.tool_call_parser not in valide_tool_parses:
        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
                       f"(chose from {{ {','.join(valide_tool_parses)} }})")

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    sock_addr = (args.host or "", args.port)
    sock = create_server_socket(sock_addr)

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    async with build_async_engine_client(args) as engine_client:
        app = build_app(args)

        model_config = await engine_client.get_model_config()
        init_app_state(engine_client, model_config, app.state, args)

        shutdown_task = await serve_http(
            app,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task

    sock.close()


if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))
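For completeness, a minimal sketch (not part of the commit) of launching the same entrypoint programmatically rather than through the Dockerfile CMD. It mirrors the `__main__` block above; the `api_server` import path is an assumption about how `openai/api_server.py` ends up on `sys.path`, and the flags shown are a subset of those in the CMD.

```python
# Programmatic launch of the OpenAI-compatible server, mirroring __main__.
import uvloop
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
from vllm.utils import FlexibleArgumentParser

from api_server import run_server  # hypothetical import of the file above

parser = make_arg_parser(
    FlexibleArgumentParser(description="vLLM OpenAI-Compatible RESTful API server."))
args = parser.parse_args([
    "--model", "meta-llama/Llama-3.2-3B-Instruct",
    "--host", "0.0.0.0",
    "--port", "7860",
    "--dtype", "half",
    "--enforce-eager",
])
validate_parsed_serve_args(args)

# Blocks until the server is shut down.
uvloop.run(run_server(args))
```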