Fedir Zadniprovskyi committed
Commit 35eafc3 · Parent: 23a3cae

feat: model unloading

pyproject.toml CHANGED
@@ -19,10 +19,10 @@ dependencies = [
 client = [
     "keyboard>=0.13.5",
 ]
-# NOTE: when installing `dev` group, all other groups should also be installed
 dev = [
     "anyio>=4.4.0",
     "basedpyright>=1.18.0",
+    "pytest-antilru>=2.0.0",
     "pytest-asyncio>=0.24.0",
     "pytest-xdist>=3.6.1",
     "pytest>=8.3.3",
src/faster_whisper_server/config.py CHANGED
@@ -1,7 +1,6 @@
 import enum
-from typing import Self

-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings, SettingsConfigDict

 SAMPLES_PER_SECOND = 16000
@@ -163,6 +162,12 @@ class WhisperConfig(BaseModel):
     compute_type: Quantization = Field(default=Quantization.DEFAULT)
     cpu_threads: int = 0
     num_workers: int = 1
+    ttl: int = Field(default=300, ge=-1)
+    """
+    Time in seconds until the model is unloaded if it is not being used.
+    -1: Never unload the model.
+    0: Unload the model immediately after usage.
+    """


 class Config(BaseSettings):
@@ -198,10 +203,6 @@ class Config(BaseSettings):
     """
     default_response_format: ResponseFormat = ResponseFormat.JSON
     whisper: WhisperConfig = WhisperConfig()
-    max_models: int = 1
-    """
-    Maximum number of models that can be loaded at a time.
-    """
     preload_models: list[str] = Field(
         default_factory=list,
         examples=[
@@ -210,8 +211,8 @@ class Config(BaseSettings):
         ],
     )
     """
-    List of models to preload on startup. Shouldn't be greater than `max_models`. By default, the model is first loaded on first request.
-    """  # noqa: E501
+    List of models to preload on startup. By default, the model is first loaded on first request.
+    """
     max_no_data_seconds: float = 1.0
     """
     Max duration to wait for the next audio chunk before transcription is finilized and connection is closed.
@@ -230,11 +231,3 @@ class Config(BaseSettings):
     Controls how many latest seconds of audio are being passed through VAD.
     Should be greater than `max_inactivity_seconds`
     """
-
-    @model_validator(mode="after")
-    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
-        if len(self.preload_models) > self.max_models:
-            raise ValueError(
-                f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
-            )
-        return self
 
 
 
 
 
 
 
 
src/faster_whisper_server/dependencies.py CHANGED
@@ -18,7 +18,7 @@ ConfigDependency = Annotated[Config, Depends(get_config)]
 @lru_cache
 def get_model_manager() -> ModelManager:
     config = get_config()  # HACK
-    return ModelManager(config)
+    return ModelManager(config.whisper)


 ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
src/faster_whisper_server/model_manager.py CHANGED
@@ -3,48 +3,132 @@ from __future__ import annotations
 from collections import OrderedDict
 import gc
 import logging
+import threading
 import time
 from typing import TYPE_CHECKING

 from faster_whisper import WhisperModel

 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from faster_whisper_server.config import (
-        Config,
+        WhisperConfig,
     )

 logger = logging.getLogger(__name__)

+# TODO: enable concurrent model downloads
+
+
+class SelfDisposingWhisperModel:
+    def __init__(
+        self,
+        model_id: str,
+        whisper_config: WhisperConfig,
+        *,
+        on_unload: Callable[[str], None] | None = None,
+    ) -> None:
+        self.model_id = model_id
+        self.whisper_config = whisper_config
+        self.on_unload = on_unload
+
+        self.ref_count: int = 0
+        self.rlock = threading.RLock()
+        self.expire_timer: threading.Timer | None = None
+        self.whisper: WhisperModel | None = None
+
+    def unload(self) -> None:
+        with self.rlock:
+            if self.whisper is None:
+                raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
+            if self.ref_count > 0:
+                raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
+            if self.expire_timer:
+                self.expire_timer.cancel()
+            self.whisper = None
+            # WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
+            gc.collect()
+            logger.info(f"Model {self.model_id} unloaded")
+            if self.on_unload is not None:
+                self.on_unload(self.model_id)
+
+    def _load(self) -> None:
+        with self.rlock:
+            assert self.whisper is None
+            logger.debug(f"Loading model {self.model_id}")
+            start = time.perf_counter()
+            self.whisper = WhisperModel(
+                self.model_id,
+                device=self.whisper_config.inference_device,
+                device_index=self.whisper_config.device_index,
+                compute_type=self.whisper_config.compute_type,
+                cpu_threads=self.whisper_config.cpu_threads,
+                num_workers=self.whisper_config.num_workers,
+            )
+            logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")
+
+    def _increment_ref(self) -> None:
+        with self.rlock:
+            self.ref_count += 1
+            if self.expire_timer:
+                logger.debug(f"Model was set to expire in {self.expire_timer.interval}s, cancelling")
+                self.expire_timer.cancel()
+            logger.debug(f"Incremented ref count for {self.model_id}, {self.ref_count=}")
+
+    def _decrement_ref(self) -> None:
+        with self.rlock:
+            self.ref_count -= 1
+            logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
+            if self.ref_count <= 0:
+                if self.whisper_config.ttl > 0:
+                    logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s")
+                    self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload)
+                    self.expire_timer.start()
+                elif self.whisper_config.ttl == 0:
+                    logger.info(f"Model {self.model_id} is idle, unloading immediately")
+                    self.unload()
+                else:
+                    logger.info(f"Model {self.model_id} is idle, not unloading")
+
+    def __enter__(self) -> WhisperModel:
+        with self.rlock:
+            if self.whisper is None:
+                self._load()
+            self._increment_ref()
+            assert self.whisper is not None
+            return self.whisper
+
+    def __exit__(self, *_args) -> None:  # noqa: ANN002
+        self._decrement_ref()
+

 class ModelManager:
-    def __init__(self, config: Config) -> None:
-        self.config = config
-        self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
-
-    def load_model(self, model_name: str) -> WhisperModel:
-        if model_name in self.loaded_models:
-            logger.debug(f"{model_name} model already loaded")
-            return self.loaded_models[model_name]
-        if len(self.loaded_models) >= self.config.max_models:
-            oldest_model_name = next(iter(self.loaded_models))
-            logger.info(
-                f"Max models ({self.config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
-            )
-            del self.loaded_models[oldest_model_name]
-            gc.collect()
-        logger.debug(f"Loading {model_name}...")
-        start = time.perf_counter()
-        # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
-        whisper = WhisperModel(
-            model_name,
-            device=self.config.whisper.inference_device,
-            device_index=self.config.whisper.device_index,
-            compute_type=self.config.whisper.compute_type,
-            cpu_threads=self.config.whisper.cpu_threads,
-            num_workers=self.config.whisper.num_workers,
-        )
-        logger.info(
-            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {self.config.whisper.inference_device}({self.config.whisper.compute_type}) will be used for inference."  # noqa: E501
-        )
-        self.loaded_models[model_name] = whisper
-        return whisper
+    def __init__(self, whisper_config: WhisperConfig) -> None:
+        self.whisper_config = whisper_config
+        self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
+        self._lock = threading.Lock()
+
+    def _handle_model_unload(self, model_name: str) -> None:
+        with self._lock:
+            if model_name in self.loaded_models:
+                del self.loaded_models[model_name]
+
+    def unload_model(self, model_name: str) -> None:
+        with self._lock:
+            model = self.loaded_models.get(model_name)
+            if model is None:
+                raise KeyError(f"Model {model_name} not found")
+            self.loaded_models[model_name].unload()
+
+    def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
+        with self._lock:
+            if model_name in self.loaded_models:
+                logger.debug(f"{model_name} model already loaded")
+                return self.loaded_models[model_name]
+            self.loaded_models[model_name] = SelfDisposingWhisperModel(
+                model_name,
+                self.whisper_config,
+                on_unload=self._handle_model_unload,
+            )
+            return self.loaded_models[model_name]
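
The context-manager protocol on SelfDisposingWhisperModel is what the routers below rely on: __enter__ loads the model if needed and pins it by incrementing the ref count, while __exit__ decrements the count and, depending on ttl, unloads or schedules an unload. A minimal usage sketch, assuming the classes above and defaults for the remaining WhisperConfig fields:

from faster_whisper_server.config import WhisperConfig
from faster_whisper_server.model_manager import ModelManager

manager = ModelManager(WhisperConfig(ttl=0))  # ttl=0: unload as soon as idle
with manager.load_model("Systran/faster-whisper-tiny.en") as whisper:
    # ref_count > 0 here, so manager.unload_model(...) would raise ValueError
    segments, info = whisper.transcribe("audio.wav")
# __exit__ drops the ref count to zero; with ttl=0 the model unloads immediately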
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/faster_whisper_server/routers/misc.py CHANGED
@@ -1,7 +1,5 @@
 from __future__ import annotations

-import gc
-
 from fastapi import (
     APIRouter,
     Response,
@@ -42,15 +40,19 @@ def get_running_models(
 def load_model_route(model_manager: ModelManagerDependency, model_name: str) -> Response:
     if model_name in model_manager.loaded_models:
         return Response(status_code=409, content="Model already loaded")
-    model_manager.load_model(model_name)
+    with model_manager.load_model(model_name):
+        pass
     return Response(status_code=201)


 @router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
 def stop_running_model(model_manager: ModelManagerDependency, model_name: str) -> Response:
-    model = model_manager.loaded_models.get(model_name)
-    if model is not None:
-        del model_manager.loaded_models[model_name]
-        gc.collect()
-        return Response(status_code=204)
-    return Response(status_code=404)
+    try:
+        model_manager.unload_model(model_name)
+        return Response(status_code=204)
+    except (KeyError, ValueError) as e:
+        match e:
+            case KeyError():
+                return Response(status_code=404, content="Model not found")
+            case ValueError():
+                return Response(status_code=409, content=str(e))
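
Client-side, the updated routes map onto status codes: 201 on load, 409 if the model is already loaded or still in use, 404 if it was never loaded. A hypothetical sketch against a running server (the base URL is an assumption; the tests below exercise the same routes in-process):

import httpx

base = "http://localhost:8000"
model = "Systran/faster-whisper-tiny.en"

httpx.post(f"{base}/api/ps/{model}")       # 201 Created, or 409 if already loaded
print(httpx.get(f"{base}/api/ps").json())  # {"models": [...]}
httpx.delete(f"{base}/api/ps/{model}")     # 204; 404 if not loaded; 409 if in use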
src/faster_whisper_server/routers/stt.py CHANGED
@@ -142,20 +142,20 @@ def translate_file(
         model = config.whisper.model
     if response_format is None:
         response_format = config.default_response_format
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSLATE,
-        initial_prompt=prompt,
-        temperature=temperature,
-        vad_filter=vad_filter,
-    )
-    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
-
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
+    with model_manager.load_model(model) as whisper:
+        segments, transcription_info = whisper.transcribe(
+            file.file,
+            task=Task.TRANSLATE,
+            initial_prompt=prompt,
+            temperature=temperature,
+            vad_filter=vad_filter,
+        )
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+
+        if stream:
+            return segments_to_streaming_response(segments, transcription_info, response_format)
+        else:
+            return segments_to_response(segments, transcription_info, response_format)


 # HACK: Since Form() doesn't support `alias`, we need to use a workaround.
@@ -206,23 +206,23 @@ def transcribe_file(
         logger.warning(
             "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
         )
-    whisper = model_manager.load_model(model)
-    segments, transcription_info = whisper.transcribe(
-        file.file,
-        task=Task.TRANSCRIBE,
-        language=language,
-        initial_prompt=prompt,
-        word_timestamps="word" in timestamp_granularities,
-        temperature=temperature,
-        vad_filter=vad_filter,
-        hotwords=hotwords,
-    )
-    segments = TranscriptionSegment.from_faster_whisper_segments(segments)
-
-    if stream:
-        return segments_to_streaming_response(segments, transcription_info, response_format)
-    else:
-        return segments_to_response(segments, transcription_info, response_format)
+    with model_manager.load_model(model) as whisper:
+        segments, transcription_info = whisper.transcribe(
+            file.file,
+            task=Task.TRANSCRIBE,
+            language=language,
+            initial_prompt=prompt,
+            word_timestamps="word" in timestamp_granularities,
+            temperature=temperature,
+            vad_filter=vad_filter,
+            hotwords=hotwords,
+        )
+        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
+
+        if stream:
+            return segments_to_streaming_response(segments, transcription_info, response_format)
+        else:
+            return segments_to_response(segments, transcription_info, response_format)


 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
@@ -280,24 +280,24 @@ async def transcribe_stream(
         "vad_filter": vad_filter,
         "condition_on_previous_text": False,
     }
-    whisper = model_manager.load_model(model)
-    asr = FasterWhisperASR(whisper, **transcribe_opts)
-    audio_stream = AudioStream()
-    async with asyncio.TaskGroup() as tg:
-        tg.create_task(audio_receiver(ws, audio_stream))
-        async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
-            logger.debug(f"Sending transcription: {transcription.text}")
-            if ws.client_state == WebSocketState.DISCONNECTED:
-                break
-
-            if response_format == ResponseFormat.TEXT:
-                await ws.send_text(transcription.text)
-            elif response_format == ResponseFormat.JSON:
-                await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
-            elif response_format == ResponseFormat.VERBOSE_JSON:
-                await ws.send_json(
-                    CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
-                )
+    with model_manager.load_model(model) as whisper:
+        asr = FasterWhisperASR(whisper, **transcribe_opts)
+        audio_stream = AudioStream()
+        async with asyncio.TaskGroup() as tg:
+            tg.create_task(audio_receiver(ws, audio_stream))
+            async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
+                logger.debug(f"Sending transcription: {transcription.text}")
+                if ws.client_state == WebSocketState.DISCONNECTED:
+                    break
+
+                if response_format == ResponseFormat.TEXT:
+                    await ws.send_text(transcription.text)
+                elif response_format == ResponseFormat.JSON:
+                    await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
+                elif response_format == ResponseFormat.VERBOSE_JSON:
+                    await ws.send_json(
+                        CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
+                    )

     if ws.client_state != WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
tests/model_manager_test.py ADDED
@@ -0,0 +1,122 @@
+import asyncio
+import os
+
+import anyio
+from httpx import ASGITransport, AsyncClient
+import pytest
+
+from faster_whisper_server.main import create_app
+
+
+@pytest.mark.asyncio
+async def test_model_unloaded_after_ttl() -> None:
+    ttl = 5
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+        await aclient.post(f"/api/ps/{model}")
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl + 1)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+
+@pytest.mark.asyncio
+async def test_ttl_resets_after_usage() -> None:
+    ttl = 5
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        await aclient.post(f"/api/ps/{model}")
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl - 2)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+        res = (
+            await aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        ).json()
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+        await asyncio.sleep(ttl - 2)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+        await asyncio.sleep(3)
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+        # test the model can be used again after being unloaded
+        # this just ensures the model can be loaded again after being unloaded
+        res = (
+            await aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        ).json()
+
+
+@pytest.mark.asyncio
+async def test_model_cant_be_unloaded_when_used() -> None:
+    ttl = 0
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+
+        task = asyncio.create_task(
+            aclient.post(
+                "/v1/audio/transcriptions", files={"file": ("audio.wav", data, "audio/wav")}, data={"model": model}
+            )
+        )
+        await asyncio.sleep(0.01)
+        res = await aclient.delete(f"/api/ps/{model}")
+        assert res.status_code == 409
+
+        await task
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
+
+
+@pytest.mark.asyncio
+async def test_model_cant_be_loaded_twice() -> None:
+    ttl = -1
+    model = "Systran/faster-whisper-tiny.en"
+    os.environ["ENABLE_UI"] = "false"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        res = await aclient.post(f"/api/ps/{model}")
+        assert res.status_code == 201
+        res = await aclient.post(f"/api/ps/{model}")
+        assert res.status_code == 409
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 1
+
+
+@pytest.mark.asyncio
+async def test_model_is_unloaded_after_request_when_ttl_is_zero() -> None:
+    ttl = 0
+    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
+    os.environ["WHISPER__TTL"] = str(ttl)
+    os.environ["ENABLE_UI"] = "false"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
+        async with await anyio.open_file("audio.wav", "rb") as f:
+            data = await f.read()
+        res = await aclient.post(
+            "/v1/audio/transcriptions",
+            files={"file": ("audio.wav", data, "audio/wav")},
+            data={"model": "Systran/faster-whisper-tiny.en"},
+        )
+        res = (await aclient.get("/api/ps")).json()
+        assert len(res["models"]) == 0
uv.lock CHANGED
@@ -293,6 +293,7 @@ dev = [
     { name = "anyio" },
     { name = "basedpyright" },
     { name = "pytest" },
+    { name = "pytest-antilru" },
     { name = "pytest-asyncio" },
     { name = "pytest-xdist" },
     { name = "ruff" },
@@ -322,6 +323,7 @@ requires-dist = [
     { name = "pydantic", specifier = ">=2.9.0" },
     { name = "pydantic-settings", specifier = ">=2.5.2" },
     { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.3" },
+    { name = "pytest-antilru", marker = "extra == 'dev'", specifier = ">=2.0.0" },
     { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.24.0" },
     { name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.6.1" },
     { name = "python-multipart", specifier = ">=0.0.10" },
@@ -3482,6 +3484,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/6b/77/7440a06a8ead44c7757a64362dd22df5760f9b12dc5f11b6188cd2fc27a0/pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2", size = 342341 },
 ]

+[[package]]
+name = "pytest-antilru"
+version = "2.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pytest" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/c6/01/0b5ef3f143f335b5cb1c1e8e6497769dfb48aed5a791b5dfd119151e2b15/pytest_antilru-2.0.0.tar.gz", hash = "sha256:48cff342648b6a1ce4e5398cf203966905d546b3f2bee7bb55d7cb3ec87a85fb", size = 5569 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/23/f0/fc9f5aaaf2818a7d7f795e99fcf59719dd6ec5f98005e642e1efd63ad2a4/pytest_antilru-2.0.0-py3-none-any.whl", hash = "sha256:cf1d97db0e7b17ef568c1f0bf4c89b8748053fe07546f4eb2558bebf64c1ad33", size = 6301 },
+]
+
 [[package]]
 name = "pytest-asyncio"
 version = "0.24.0"