Fedir Zadniprovskyi committed · Commit 35eafc3 · Parent: 23a3cae

feat: model unloading
Files changed:

- pyproject.toml (+1 -1)
- src/faster_whisper_server/config.py (+9 -16)
- src/faster_whisper_server/dependencies.py (+1 -1)
- src/faster_whisper_server/model_manager.py (+114 -30)
- src/faster_whisper_server/routers/misc.py (+10 -8)
- src/faster_whisper_server/routers/stt.py (+48 -48)
- tests/model_manager_test.py (+122 -0)
- uv.lock (+14 -0)
pyproject.toml CHANGED

@@ -19,10 +19,10 @@ dependencies = [
 client = [
     "keyboard>=0.13.5",
 ]
-# NOTE: when installing `dev` group, all other groups should also be installed
 dev = [
     "anyio>=4.4.0",
     "basedpyright>=1.18.0",
+    "pytest-antilru>=2.0.0",
     "pytest-asyncio>=0.24.0",
     "pytest-xdist>=3.6.1",
     "pytest>=8.3.3",
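A note on the new dev dependency: `pytest-antilru` clears `functools.lru_cache` caches between tests. That matters here because `get_model_manager()` in `dependencies.py` (below) is memoized with `@lru_cache`; without cache busting, every test would reuse the `ModelManager` built from the first test's environment variables. A minimal sketch of the failure mode, with a hypothetical `get_settings()` standing in for the cached dependency:

from functools import lru_cache
import os

@lru_cache
def get_settings() -> dict:  # hypothetical stand-in for get_model_manager()
    return {"ttl": int(os.environ.get("WHISPER__TTL", "300"))}

os.environ["WHISPER__TTL"] = "5"
assert get_settings()["ttl"] == 5

os.environ["WHISPER__TTL"] = "0"
# Stale: the cached result from the first call is returned.
assert get_settings()["ttl"] == 5
# pytest-antilru resets lru_cache-decorated functions between tests,
# so each test sees a ModelManager built from its own env vars.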
src/faster_whisper_server/config.py CHANGED

@@ -1,7 +1,6 @@
 import enum
-from typing import Self
 
-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 SAMPLES_PER_SECOND = 16000
@@ -163,6 +162,12 @@ class WhisperConfig(BaseModel):
     compute_type: Quantization = Field(default=Quantization.DEFAULT)
     cpu_threads: int = 0
     num_workers: int = 1
+    ttl: int = Field(default=300, ge=-1)
+    """
+    Time in seconds until the model is unloaded if it is not being used.
+    -1: Never unload the model.
+    0: Unload the model immediately after usage.
+    """
 
 
 class Config(BaseSettings):
@@ -198,10 +203,6 @@ class Config(BaseSettings):
     """
     default_response_format: ResponseFormat = ResponseFormat.JSON
     whisper: WhisperConfig = WhisperConfig()
-    max_models: int = 1
-    """
-    Maximum number of models that can be loaded at a time.
-    """
     preload_models: list[str] = Field(
         default_factory=list,
         examples=[
@@ -210,8 +211,8 @@ class Config(BaseSettings):
         ],
     )
     """
-    List of models to preload on startup.
-    """
+    List of models to preload on startup. By default, the model is first loaded on first request.
+    """
     max_no_data_seconds: float = 1.0
     """
     Max duration to wait for the next audio chunk before transcription is finalized and connection is closed.
@@ -230,11 +231,3 @@ class Config(BaseSettings):
     Controls how many latest seconds of audio are being passed through VAD.
     Should be greater than `max_inactivity_seconds`
     """
-
-    @model_validator(mode="after")
-    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
-        if len(self.preload_models) > self.max_models:
-            raise ValueError(
-                f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
-            )
-        return self
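Since `WhisperConfig` is nested under `Config` (a pydantic-settings `BaseSettings`), the new `ttl` field can be set through the environment. A minimal sketch, assuming the `__` nested-delimiter convention that the tests below exercise via `WHISPER__TTL`:

import os

# -1 = never unload, 0 = unload right after each use, N > 0 = unload after N idle seconds
os.environ["WHISPER__TTL"] = "0"

from faster_whisper_server.config import Config

config = Config()
assert config.whisper.ttl == 0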
src/faster_whisper_server/dependencies.py CHANGED

@@ -18,7 +18,7 @@ ConfigDependency = Annotated[Config, Depends(get_config)]
 @lru_cache
 def get_model_manager() -> ModelManager:
     config = get_config()  # HACK
-    return ModelManager(config)
+    return ModelManager(config.whisper)
 
 
 ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
src/faster_whisper_server/model_manager.py CHANGED

@@ -3,48 +3,132 @@ from __future__ import annotations
 from collections import OrderedDict
 import gc
 import logging
+import threading
 import time
 from typing import TYPE_CHECKING
 
 from faster_whisper import WhisperModel
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from faster_whisper_server.config import (
-        Config,
+        WhisperConfig,
     )
 
 logger = logging.getLogger(__name__)
 
+# TODO: enable concurrent model downloads
+
+
+class SelfDisposingWhisperModel:
+    def __init__(
+        self,
+        model_id: str,
+        whisper_config: WhisperConfig,
+        *,
+        on_unload: Callable[[str], None] | None = None,
+    ) -> None:
+        self.model_id = model_id
+        self.whisper_config = whisper_config
+        self.on_unload = on_unload
+
+        self.ref_count: int = 0
+        self.rlock = threading.RLock()
+        self.expire_timer: threading.Timer | None = None
+        self.whisper: WhisperModel | None = None
+
+    def unload(self) -> None:
+        with self.rlock:
+            if self.whisper is None:
+                raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
+            if self.ref_count > 0:
+                raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
+            if self.expire_timer:
+                self.expire_timer.cancel()
+            self.whisper = None
+            # WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
+            gc.collect()
+            logger.info(f"Model {self.model_id} unloaded")
+            if self.on_unload is not None:
+                self.on_unload(self.model_id)
+
+    def _load(self) -> None:
+        with self.rlock:
+            assert self.whisper is None
+            logger.debug(f"Loading model {self.model_id}")
+            start = time.perf_counter()
+            self.whisper = WhisperModel(
+                self.model_id,
+                device=self.whisper_config.inference_device,
+                device_index=self.whisper_config.device_index,
+                compute_type=self.whisper_config.compute_type,
+                cpu_threads=self.whisper_config.cpu_threads,
+                num_workers=self.whisper_config.num_workers,
+            )
+            logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")
+
+    def _increment_ref(self) -> None:
+        with self.rlock:
+            self.ref_count += 1
+            if self.expire_timer:
+                logger.debug(f"Model was set to expire in {self.expire_timer.interval}s, cancelling")
+                self.expire_timer.cancel()
+            logger.debug(f"Incremented ref count for {self.model_id}, {self.ref_count=}")
+
+    def _decrement_ref(self) -> None:
+        with self.rlock:
+            self.ref_count -= 1
+            logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
+            if self.ref_count <= 0:
+                if self.whisper_config.ttl > 0:
+                    logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s")
+                    self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload)
+                    self.expire_timer.start()
+                elif self.whisper_config.ttl == 0:
+                    logger.info(f"Model {self.model_id} is idle, unloading immediately")
+                    self.unload()
+                else:
+                    logger.info(f"Model {self.model_id} is idle, not unloading")
+
+    def __enter__(self) -> WhisperModel:
+        with self.rlock:
+            if self.whisper is None:
+                self._load()
+            self._increment_ref()
+            assert self.whisper is not None
+            return self.whisper
+
+    def __exit__(self, *_args) -> None:  # noqa: ANN002
+        self._decrement_ref()
+
 
 class ModelManager:
-    def __init__(self, config: Config) -> None:
-        self.config = config
-        self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
+    def __init__(self, whisper_config: WhisperConfig) -> None:
+        self.whisper_config = whisper_config
+        self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
+        self._lock = threading.Lock()
 
-    def load_model(self, model_name: str) -> WhisperModel:
-        if model_name in self.loaded_models:
-            logger.debug(f"{model_name} model already loaded")
-            return self.loaded_models[model_name]
-        if len(self.loaded_models) >= self.config.max_models:
-            oldest_model_name = next(iter(self.loaded_models))
-            logger.info(
-                f"Max models ({self.config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
-            )
-            del self.loaded_models[oldest_model_name]
-        gc.collect()
-        logger.debug(f"Loading {model_name}...")
-        start = time.perf_counter()
-        # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
-        whisper = WhisperModel(
-            model_name,
-            device=self.config.whisper.inference_device,
-            device_index=self.config.whisper.device_index,
-            compute_type=self.config.whisper.compute_type,
-            cpu_threads=self.config.whisper.cpu_threads,
-            num_workers=self.config.whisper.num_workers,
-        )
-        logger.info(
-            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {self.config.whisper.inference_device}({self.config.whisper.compute_type}) will be used for inference."  # noqa: E501
-        )
-        self.loaded_models[model_name] = whisper
-        return whisper
+    def _handle_model_unload(self, model_name: str) -> None:
+        with self._lock:
+            if model_name in self.loaded_models:
+                del self.loaded_models[model_name]
+
+    def unload_model(self, model_name: str) -> None:
+        with self._lock:
+            model = self.loaded_models.get(model_name)
+            if model is None:
+                raise KeyError(f"Model {model_name} not found")
+            self.loaded_models[model_name].unload()
+
+    def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
+        with self._lock:
+            if model_name in self.loaded_models:
+                logger.debug(f"{model_name} model already loaded")
+                return self.loaded_models[model_name]
+            self.loaded_models[model_name] = SelfDisposingWhisperModel(
+                model_name,
+                self.whisper_config,
+                on_unload=self._handle_model_unload,
+            )
+            return self.loaded_models[model_name]
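The resulting API: `ModelManager.load_model()` now returns a `SelfDisposingWhisperModel` rather than a bare `WhisperModel`, and callers hold it open as a context manager. A minimal usage sketch (the model name and a `ttl` value are borrowed from the tests below; `transcribe()` is the regular faster-whisper call):

from faster_whisper_server.config import WhisperConfig
from faster_whisper_server.model_manager import ModelManager

manager = ModelManager(WhisperConfig(ttl=300))

# __enter__ loads the WhisperModel on first use and increments the ref count;
# __exit__ decrements it and, once idle, schedules unloading after `ttl` seconds.
with manager.load_model("Systran/faster-whisper-tiny.en") as whisper:
    segments, info = whisper.transcribe("audio.wav")
    print(list(segments))

The per-model `RLock` plus the manager-level `Lock` mean a request arriving during the TTL window simply cancels the pending `threading.Timer` instead of paying the load cost again.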
src/faster_whisper_server/routers/misc.py CHANGED

@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-import gc
-
 from fastapi import (
     APIRouter,
     Response,
@@ -42,15 +40,19 @@ def get_running_models(
 def load_model_route(model_manager: ModelManagerDependency, model_name: str) -> Response:
     if model_name in model_manager.loaded_models:
         return Response(status_code=409, content="Model already loaded")
-    model_manager.load_model(model_name)
+    with model_manager.load_model(model_name):
+        pass
     return Response(status_code=201)
 
 
 @router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
 def stop_running_model(model_manager: ModelManagerDependency, model_name: str) -> Response:
-    model = model_manager.loaded_models.get(model_name)
-    if model is not None:
-        del model_manager.loaded_models[model_name]
-        gc.collect()
+    try:
+        model_manager.unload_model(model_name)
         return Response(status_code=204)
-    return Response(status_code=404)
+    except (KeyError, ValueError) as e:
+        match e:
+            case KeyError():
+                return Response(status_code=404, content="Model not found")
+            case ValueError():
+                return Response(status_code=409, content=str(e))
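Taken together, the `/api/ps` routes now map the manager's exceptions onto HTTP statuses: 201 on load, 409 if already loaded or still in use, 404 if unknown, 204 on unload. A sketch of exercising them against a running server with httpx (the base URL is an assumption; the tests below use an in-process `ASGITransport` instead):

import httpx

base_url = "http://localhost:8000"  # assumed deployment address
model = "Systran/faster-whisper-tiny.en"

print(httpx.get(f"{base_url}/api/ps").json())                  # currently loaded models
print(httpx.post(f"{base_url}/api/ps/{model}").status_code)    # 201, or 409 if already loaded
print(httpx.delete(f"{base_url}/api/ps/{model}").status_code)  # 204; 404 unknown, 409 in use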
src/faster_whisper_server/routers/stt.py CHANGED

@@ -142,20 +142,20 @@ def translate_file(
     model = config.whisper.model
     if response_format is None:
         response_format = config.default_response_format
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSLATE,
-        initial_prompt=prompt,
-        temperature=temperature,
-        vad_filter=vad_filter,
-    )
-    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
-
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
+    with model_manager.load_model(model) as whisper:
+        segments, transcription_info = whisper.transcribe(
+            file.file,
+            task=Task.TRANSLATE,
+            initial_prompt=prompt,
+            temperature=temperature,
+            vad_filter=vad_filter,
+        )
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+
+        if stream:
+            return segments_to_streaming_response(segments, transcription_info, response_format)
+        else:
+            return segments_to_response(segments, transcription_info, response_format)
 
 
 # HACK: Since Form() doesn't support `alias`, we need to use a workaround.
@@ -206,23 +206,23 @@ def transcribe_file(
         logger.warning(
             "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
         )
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSCRIBE,
-        language=language,
-        initial_prompt=prompt,
-        word_timestamps="word" in timestamp_granularities,
-        temperature=temperature,
-        vad_filter=vad_filter,
-        hotwords=hotwords,
-    )
-    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
-
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
+    with model_manager.load_model(model) as whisper:
+        segments, transcription_info = whisper.transcribe(
+            file.file,
+            task=Task.TRANSCRIBE,
+            language=language,
+            initial_prompt=prompt,
+            word_timestamps="word" in timestamp_granularities,
+            temperature=temperature,
+            vad_filter=vad_filter,
+            hotwords=hotwords,
+        )
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+
+        if stream:
+            return segments_to_streaming_response(segments, transcription_info, response_format)
+        else:
+            return segments_to_response(segments, transcription_info, response_format)
 
 
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
@@ -280,24 +280,24 @@ async def transcribe_stream(
         "vad_filter": vad_filter,
         "condition_on_previous_text": False,
     }
-    whisper = model_manager.load_model(model)
-    asr = FasterWhisperASR(whisper, **transcribe_opts)
-    audio_stream = AudioStream()
-    async with asyncio.TaskGroup() as tg:
-        tg.create_task(audio_receiver(ws, audio_stream))
-        async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
-            logger.debug(f"Sending transcription: {transcription.text}")
-            if ws.client_state == WebSocketState.DISCONNECTED:
-                break
+    with model_manager.load_model(model) as whisper:
+        asr = FasterWhisperASR(whisper, **transcribe_opts)
+        audio_stream = AudioStream()
+        async with asyncio.TaskGroup() as tg:
+            tg.create_task(audio_receiver(ws, audio_stream))
+            async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
+                logger.debug(f"Sending transcription: {transcription.text}")
+                if ws.client_state == WebSocketState.DISCONNECTED:
+                    break
 
-            if response_format == ResponseFormat.TEXT:
-                await ws.send_text(transcription.text)
-            elif response_format == ResponseFormat.JSON:
-                await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                await ws.send_json(
-                    CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
-                )
+                if response_format == ResponseFormat.TEXT:
+                    await ws.send_text(transcription.text)
+                elif response_format == ResponseFormat.JSON:
+                    await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
+                elif response_format == ResponseFormat.VERBOSE_JSON:
+                    await ws.send_json(
+                        CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
+                    )
 
     if ws.client_state != WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
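Holding the model inside `with` for the whole request is what makes `ttl = 0` safe: while any handler is in the block, `ref_count > 0`, so `unload()` raises `ValueError` and the DELETE route answers 409 instead of pulling the model out from under an in-flight transcription. A sketch of that guard in isolation (note this loads the tiny English model for real):

from faster_whisper_server.config import WhisperConfig
from faster_whisper_server.model_manager import ModelManager

manager = ModelManager(WhisperConfig(ttl=0))
with manager.load_model("Systran/faster-whisper-tiny.en"):
    try:
        manager.unload_model("Systran/faster-whisper-tiny.en")
    except ValueError as e:
        print(e)  # "Model ... is still in use. self.ref_count=1"
# On exit the ref count drops to 0 and, with ttl == 0, the model unloads immediately.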
tests/model_manager_test.py ADDED

@@ -0,0 +1,122 @@
+import asyncio
+import os
+
+import anyio
+from httpx import ASGITransport, AsyncClient
+import pytest
+
+from faster_whisper_server.main import create_app
+
+
+@pytest.mark.asyncio
+async def test_model_unloaded_after_ttl() -> None:
+    ttl = 5
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+        await aclient.post(f"/api/ps/{model}")
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl + 1)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+
+@pytest.mark.asyncio
+async def test_ttl_resets_after_usage() -> None:
+    ttl = 5
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        await aclient.post(f"/api/ps/{model}")
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl - 2)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+        res = (
+            await aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        ).json()
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl - 2)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+        await asyncio.sleep(3)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+        # test the model can be used again after being unloaded
+        # this just ensures the model can be loaded again after being unloaded
+        res = (
+            await aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        ).json()
+
+
+@pytest.mark.asyncio
+async def test_model_cant_be_unloaded_when_used() -> None:
+    ttl = 0
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+
+        task = asyncio.create_task(
+            aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        )
+        await asyncio.sleep(0.01)
+        res = await aclient.delete(f"/api/ps/{model}")
+        assert res.status_code == 409
+
+        await task
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+
+@pytest.mark.asyncio
+async def test_model_cant_be_loaded_twice() -> None:
+    ttl = -1
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["ENABLE_UI"] = "false"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        res = await aclient.post(f"/api/ps/{model}")
+        assert res.status_code == 201
+        res = await aclient.post(f"/api/ps/{model}")
+        assert res.status_code == 409
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+
+@pytest.mark.asyncio
+async def test_model_is_unloaded_after_request_when_ttl_is_zero() -> None:
+    ttl = 0
+    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+        res = await aclient.post(
+            "/v1/audio/transcriptions",
+            files={"file": ("audio.wav", data, "audio/wav")},
+            data={"model": "Systran/faster-whisper-tiny.en"},
+        )
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
uv.lock CHANGED

@@ -293,6 +293,7 @@ dev = [
     { name = "anyio" },
     { name = "basedpyright" },
     { name = "pytest" },
+    { name = "pytest-antilru" },
     { name = "pytest-asyncio" },
     { name = "pytest-xdist" },
     { name = "ruff" },
@@ -322,6 +323,7 @@ requires-dist = [
     { name = "pydantic", specifier = ">=2.9.0" },
     { name = "pydantic-settings", specifier = ">=2.5.2" },
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.3" },
+    { name = "pytest-antilru", marker = "extra == 'dev'", specifier = ">=2.0.0" },
     { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24.0" },
     { name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.6.1" },
     { name = "python-multipart", specifier = ">=0.0.10" },
@@ -3482,6 +3484,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 },
 ]
 
+[[package]]
+name = "pytest-antilru"
+version = "2.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c6/01/0b5ef3f143f335b5cb1c1e8e6497769dfb48aed5a791b5dfd119151e2b15/pytest_antilru-2.0.0.tar.gz", hash = "sha256:48cff342648b6a1ce4e5398cf203966905d546b3f2bee7bb55d7cb3ec87a85fb", size = 5569 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/23/f0/fc9f5aaaf2818a7d7f795e99fcf59719dd6ec5f98005e642e1efd63ad2a4/pytest_antilru-2.0.0-py3-none-any.whl", hash = "sha256:cf1d97db0e7b17ef568c1f0bf4c89b8748053fe07546f4eb2558bebf64c1ad33", size = 6301 },
+]
+
 [[package]]
 name = "pytest-asyncio"
 version = "0.24.0"