Fedir Zadniprovskyi committed on
Commit
bf48682
1 Parent(s): 8f3dcc9

feat: dependency injection

Browse files

The main purpose of this change is to allow modifying the configuration
for testing. This change does lead to some ugly code where the `get_config`
function gets called directly (marked `# HACK`) in places that FastAPI's
dependency injection does not reach.

Taskfile.yaml CHANGED
@@ -1,6 +1,6 @@
1
  version: "3"
2
  tasks:
3
- server: uvicorn --host 0.0.0.0 faster_whisper_server.main:app {{.CLI_ARGS}}
4
  test:
5
  cmds:
6
  - pytest -o log_cli=true -o log_cli_level=DEBUG {{.CLI_ARGS}}
 
1
  version: "3"
2
  tasks:
3
+ server: uvicorn --factory --host 0.0.0.0 faster_whisper_server.main:create_app {{.CLI_ARGS}}
4
  test:
5
  cmds:
6
  - pytest -o log_cli=true -o log_cli_level=DEBUG {{.CLI_ARGS}}
pyproject.toml CHANGED
@@ -75,6 +75,7 @@ ignore = [
75
  "ISC001", # recommended to disable for formatting
76
  "INP001",
77
  "PT018",
 
78
  ]
79
 
80
  [tool.ruff.lint.isort]
 
75
  "ISC001", # recommended to disable for formatting
76
  "INP001",
77
  "PT018",
78
+ "G004", # logging f string
79
  ]
80
 
81
  [tool.ruff.lint.isort]
src/faster_whisper_server/asr.py CHANGED
@@ -1,11 +1,13 @@
1
  import asyncio
 
2
  import time
3
 
4
  from faster_whisper import transcribe
5
 
6
  from faster_whisper_server.audio import Audio
7
  from faster_whisper_server.core import Segment, Transcription, Word
8
- from faster_whisper_server.logger import logger
 
9
 
10
 
11
  class FasterWhisperASR:
 
1
  import asyncio
2
+ import logging
3
  import time
4
 
5
  from faster_whisper import transcribe
6
 
7
  from faster_whisper_server.audio import Audio
8
  from faster_whisper_server.core import Segment, Transcription, Word
9
+
10
+ logger = logging.getLogger(__name__)
11
 
12
 
13
  class FasterWhisperASR:
src/faster_whisper_server/audio.py CHANGED
@@ -1,13 +1,13 @@
1
  from __future__ import annotations
2
 
3
  import asyncio
 
4
  from typing import TYPE_CHECKING, BinaryIO
5
 
6
  import numpy as np
7
  import soundfile as sf
8
 
9
  from faster_whisper_server.config import SAMPLES_PER_SECOND
10
- from faster_whisper_server.logger import logger
11
 
12
  if TYPE_CHECKING:
13
  from collections.abc import AsyncGenerator
@@ -15,6 +15,9 @@ if TYPE_CHECKING:
15
  from numpy.typing import NDArray
16
 
17
 
 
 
 
18
  def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]:
19
  audio_and_sample_rate = sf.read(
20
  file,
 
1
  from __future__ import annotations
2
 
3
  import asyncio
4
+ import logging
5
  from typing import TYPE_CHECKING, BinaryIO
6
 
7
  import numpy as np
8
  import soundfile as sf
9
 
10
  from faster_whisper_server.config import SAMPLES_PER_SECOND
 
11
 
12
  if TYPE_CHECKING:
13
  from collections.abc import AsyncGenerator
 
15
  from numpy.typing import NDArray
16
 
17
 
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
  def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]:
22
  audio_and_sample_rate = sf.read(
23
  file,
src/faster_whisper_server/config.py CHANGED
@@ -238,6 +238,3 @@ class Config(BaseSettings):
238
  f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})" # noqa: E501
239
  )
240
  return self
241
-
242
-
243
- config = Config()
 
238
  f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})" # noqa: E501
239
  )
240
  return self
 
 
 
src/faster_whisper_server/core.py CHANGED
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING
5
 
6
  from pydantic import BaseModel
7
 
8
- from faster_whisper_server.config import config
9
 
10
  if TYPE_CHECKING:
11
  from collections.abc import Iterable
@@ -113,6 +113,7 @@ class Transcription:
113
  self.words.extend(words)
114
 
115
  def _ensure_no_word_overlap(self, words: list[Word]) -> None:
 
116
  if len(self.words) > 0 and len(words) > 0:
117
  if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
118
  raise ValueError(
 
5
 
6
  from pydantic import BaseModel
7
 
8
+ from faster_whisper_server.dependencies import get_config
9
 
10
  if TYPE_CHECKING:
11
  from collections.abc import Iterable
 
113
  self.words.extend(words)
114
 
115
  def _ensure_no_word_overlap(self, words: list[Word]) -> None:
116
+ config = get_config() # HACK
117
  if len(self.words) > 0 and len(words) > 0:
118
  if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
119
  raise ValueError(
src/faster_whisper_server/dependencies.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+ from typing import Annotated
3
+
4
+ from fastapi import Depends
5
+
6
+ from faster_whisper_server.config import Config
7
+ from faster_whisper_server.model_manager import ModelManager
8
+
9
+
10
+ @lru_cache
11
+ def get_config() -> Config:
12
+ return Config()
13
+
14
+
15
+ ConfigDependency = Annotated[Config, Depends(get_config)]
16
+
17
+
18
+ @lru_cache
19
+ def get_model_manager() -> ModelManager:
20
+ config = get_config() # HACK
21
+ return ModelManager(config)
22
+
23
+
24
+ ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
src/faster_whisper_server/hf_utils.py CHANGED
@@ -1,10 +1,11 @@
1
  from collections.abc import Generator
 
2
  from pathlib import Path
3
  import typing
4
 
5
  import huggingface_hub
6
 
7
- from faster_whisper_server.logger import logger
8
 
9
  LIBRARY_NAME = "ctranslate2"
10
  TASK_NAME = "automatic-speech-recognition"
 
1
  from collections.abc import Generator
2
+ import logging
3
  from pathlib import Path
4
  import typing
5
 
6
  import huggingface_hub
7
 
8
+ logger = logging.getLogger(__name__)
9
 
10
  LIBRARY_NAME = "ctranslate2"
11
  TASK_NAME = "automatic-speech-recognition"
src/faster_whisper_server/logger.py CHANGED
@@ -1,8 +1,11 @@
1
  import logging
2
 
3
- from faster_whisper_server.config import config
4
 
5
- logging.getLogger().setLevel(logging.INFO)
6
- logger = logging.getLogger(__name__)
7
- logger.setLevel(config.log_level.upper())
8
- logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s")
 
 
 
 
1
  import logging
2
 
3
+ from faster_whisper_server.dependencies import get_config
4
 
5
+
6
+ def setup_logger() -> None:
7
+ config = get_config() # HACK
8
+ logging.getLogger().setLevel(logging.INFO)
9
+ logger = logging.getLogger(__name__)
10
+ logger.setLevel(config.log_level.upper())
11
+ logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s")
src/faster_whisper_server/main.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
  from contextlib import asynccontextmanager
 
4
  from typing import TYPE_CHECKING
5
 
6
  from fastapi import (
@@ -8,11 +9,8 @@ from fastapi import (
8
  )
9
  from fastapi.middleware.cors import CORSMiddleware
10
 
11
- from faster_whisper_server.config import (
12
- config,
13
- )
14
- from faster_whisper_server.logger import logger
15
- from faster_whisper_server.model_manager import model_manager
16
  from faster_whisper_server.routers.list_models import (
17
  router as list_models_router,
18
  )
@@ -27,34 +25,42 @@ if TYPE_CHECKING:
27
  from collections.abc import AsyncGenerator
28
 
29
 
30
- logger.debug(f"Config: {config}")
 
 
 
 
 
 
31
 
 
32
 
33
- @asynccontextmanager
34
- async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
35
- for model_name in config.preload_models:
36
- model_manager.load_model(model_name)
37
- yield
38
 
 
39
 
40
- app = FastAPI(lifespan=lifespan)
 
 
41
 
42
- app.include_router(stt_router)
43
- app.include_router(list_models_router)
44
- app.include_router(misc_router)
 
 
 
 
 
45
 
46
- if config.allow_origins is not None:
47
- app.add_middleware(
48
- CORSMiddleware,
49
- allow_origins=config.allow_origins,
50
- allow_credentials=True,
51
- allow_methods=["*"],
52
- allow_headers=["*"],
53
- )
54
 
55
- if config.enable_ui:
56
- import gradio as gr
57
 
58
- from faster_whisper_server.gradio_app import create_gradio_demo
59
 
60
- app = gr.mount_gradio_app(app, create_gradio_demo(config), path="/")
 
1
  from __future__ import annotations
2
 
3
  from contextlib import asynccontextmanager
4
+ import logging
5
  from typing import TYPE_CHECKING
6
 
7
  from fastapi import (
 
9
  )
10
  from fastapi.middleware.cors import CORSMiddleware
11
 
12
+ from faster_whisper_server.dependencies import get_config, get_model_manager
13
+ from faster_whisper_server.logger import setup_logger
 
 
 
14
  from faster_whisper_server.routers.list_models import (
15
  router as list_models_router,
16
  )
 
25
  from collections.abc import AsyncGenerator
26
 
27
 
28
+ def create_app() -> FastAPI:
29
+ setup_logger()
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+ config = get_config() # HACK
34
+ logger.debug(f"Config: {config}")
35
 
36
+ model_manager = get_model_manager() # HACK
37
 
38
+ @asynccontextmanager
39
+ async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
40
+ for model_name in config.preload_models:
41
+ model_manager.load_model(model_name)
42
+ yield
43
 
44
+ app = FastAPI(lifespan=lifespan)
45
 
46
+ app.include_router(stt_router)
47
+ app.include_router(list_models_router)
48
+ app.include_router(misc_router)
49
 
50
+ if config.allow_origins is not None:
51
+ app.add_middleware(
52
+ CORSMiddleware,
53
+ allow_origins=config.allow_origins,
54
+ allow_credentials=True,
55
+ allow_methods=["*"],
56
+ allow_headers=["*"],
57
+ )
58
 
59
+ if config.enable_ui:
60
+ import gradio as gr
 
 
 
 
 
 
61
 
62
+ from faster_whisper_server.gradio_app import create_gradio_demo
 
63
 
64
+ app = gr.mount_gradio_app(app, create_gradio_demo(config), path="/")
65
 
66
+ return app
src/faster_whisper_server/model_manager.py CHANGED
@@ -2,27 +2,34 @@ from __future__ import annotations
2
 
3
  from collections import OrderedDict
4
  import gc
 
5
  import time
 
6
 
7
  from faster_whisper import WhisperModel
8
 
9
- from faster_whisper_server.config import (
10
- config,
11
- )
12
- from faster_whisper_server.logger import logger
 
 
13
 
14
 
15
  class ModelManager:
16
- def __init__(self) -> None:
 
17
  self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
18
 
19
  def load_model(self, model_name: str) -> WhisperModel:
20
  if model_name in self.loaded_models:
21
  logger.debug(f"{model_name} model already loaded")
22
  return self.loaded_models[model_name]
23
- if len(self.loaded_models) >= config.max_models:
24
  oldest_model_name = next(iter(self.loaded_models))
25
- logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
 
 
26
  del self.loaded_models[oldest_model_name]
27
  gc.collect()
28
  logger.debug(f"Loading {model_name}...")
@@ -30,17 +37,14 @@ class ModelManager:
30
  # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
31
  whisper = WhisperModel(
32
  model_name,
33
- device=config.whisper.inference_device,
34
- device_index=config.whisper.device_index,
35
- compute_type=config.whisper.compute_type,
36
- cpu_threads=config.whisper.cpu_threads,
37
- num_workers=config.whisper.num_workers,
38
  )
39
  logger.info(
40
- f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference." # noqa: E501
41
  )
42
  self.loaded_models[model_name] = whisper
43
  return whisper
44
-
45
-
46
- model_manager = ModelManager()
 
2
 
3
  from collections import OrderedDict
4
  import gc
5
+ import logging
6
  import time
7
+ from typing import TYPE_CHECKING
8
 
9
  from faster_whisper import WhisperModel
10
 
11
+ if TYPE_CHECKING:
12
+ from faster_whisper_server.config import (
13
+ Config,
14
+ )
15
+
16
+ logger = logging.getLogger(__name__)
17
 
18
 
19
  class ModelManager:
20
+ def __init__(self, config: Config) -> None:
21
+ self.config = config
22
  self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
23
 
24
  def load_model(self, model_name: str) -> WhisperModel:
25
  if model_name in self.loaded_models:
26
  logger.debug(f"{model_name} model already loaded")
27
  return self.loaded_models[model_name]
28
+ if len(self.loaded_models) >= self.config.max_models:
29
  oldest_model_name = next(iter(self.loaded_models))
30
+ logger.info(
31
+ f"Max models ({self.config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
32
+ )
33
  del self.loaded_models[oldest_model_name]
34
  gc.collect()
35
  logger.debug(f"Loading {model_name}...")
 
37
  # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
38
  whisper = WhisperModel(
39
  model_name,
40
+ device=self.config.whisper.inference_device,
41
+ device_index=self.config.whisper.device_index,
42
+ compute_type=self.config.whisper.compute_type,
43
+ cpu_threads=self.config.whisper.cpu_threads,
44
+ num_workers=self.config.whisper.num_workers,
45
  )
46
  logger.info(
47
+ f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {self.config.whisper.inference_device}({self.config.whisper.compute_type}) will be used for inference." # noqa: E501
48
  )
49
  self.loaded_models[model_name] = whisper
50
  return whisper
 
 
 
src/faster_whisper_server/routers/misc.py CHANGED
@@ -6,11 +6,12 @@ from fastapi import (
6
  APIRouter,
7
  Response,
8
  )
9
- from faster_whisper_server import hf_utils
10
- from faster_whisper_server.model_manager import model_manager
11
  import huggingface_hub
12
  from huggingface_hub.hf_api import RepositoryNotFoundError
13
 
 
 
 
14
  router = APIRouter()
15
 
16
 
@@ -31,12 +32,14 @@ def pull_model(model_name: str) -> Response:
31
 
32
 
33
  @router.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
34
- def get_running_models() -> dict[str, list[str]]:
 
 
35
  return {"models": list(model_manager.loaded_models.keys())}
36
 
37
 
38
  @router.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
39
- def load_model_route(model_name: str) -> Response:
40
  if model_name in model_manager.loaded_models:
41
  return Response(status_code=409, content="Model already loaded")
42
  model_manager.load_model(model_name)
@@ -44,7 +47,7 @@ def load_model_route(model_name: str) -> Response:
44
 
45
 
46
  @router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
47
- def stop_running_model(model_name: str) -> Response:
48
  model = model_manager.loaded_models.get(model_name)
49
  if model is not None:
50
  del model_manager.loaded_models[model_name]
 
6
  APIRouter,
7
  Response,
8
  )
 
 
9
  import huggingface_hub
10
  from huggingface_hub.hf_api import RepositoryNotFoundError
11
 
12
+ from faster_whisper_server import hf_utils
13
+ from faster_whisper_server.dependencies import ModelManagerDependency # noqa: TCH001
14
+
15
  router = APIRouter()
16
 
17
 
 
32
 
33
 
34
  @router.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
35
+ def get_running_models(
36
+ model_manager: ModelManagerDependency,
37
+ ) -> dict[str, list[str]]:
38
  return {"models": list(model_manager.loaded_models.keys())}
39
 
40
 
41
  @router.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
42
+ def load_model_route(model_manager: ModelManagerDependency, model_name: str) -> Response:
43
  if model_name in model_manager.loaded_models:
44
  return Response(status_code=409, content="Model already loaded")
45
  model_manager.load_model(model_name)
 
47
 
48
 
49
  @router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
50
+ def stop_running_model(model_manager: ModelManagerDependency, model_name: str) -> Response:
51
  model = model_manager.loaded_models.get(model_name)
52
  if model is not None:
53
  del model_manager.loaded_models[model_name]
src/faster_whisper_server/routers/stt.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
 
3
  import asyncio
4
  from io import BytesIO
 
5
  from typing import TYPE_CHECKING, Annotated, Literal
6
 
7
  from fastapi import (
@@ -16,6 +17,8 @@ from fastapi import (
16
  from fastapi.responses import StreamingResponse
17
  from fastapi.websockets import WebSocketState
18
  from faster_whisper.vad import VadOptions, get_speech_timestamps
 
 
19
  from faster_whisper_server.asr import FasterWhisperASR
20
  from faster_whisper_server.audio import AudioStream, audio_samples_from_file
21
  from faster_whisper_server.config import (
@@ -23,17 +26,14 @@ from faster_whisper_server.config import (
23
  Language,
24
  ResponseFormat,
25
  Task,
26
- config,
27
  )
28
  from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
29
- from faster_whisper_server.logger import logger
30
- from faster_whisper_server.model_manager import model_manager
31
  from faster_whisper_server.server_models import (
32
  TranscriptionJsonResponse,
33
  TranscriptionVerboseJsonResponse,
34
  )
35
  from faster_whisper_server.transcriber import audio_transcriber
36
- from pydantic import AfterValidator
37
 
38
  if TYPE_CHECKING:
39
  from collections.abc import Generator, Iterable
@@ -41,6 +41,8 @@ if TYPE_CHECKING:
41
  from faster_whisper.transcribe import TranscriptionInfo
42
 
43
 
 
 
44
  router = APIRouter()
45
 
46
 
@@ -103,6 +105,7 @@ def handle_default_openai_model(model_name: str) -> str:
103
 
104
  For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
105
  """
 
106
  if model_name == "whisper-1":
107
  logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
108
  return config.whisper.model
@@ -117,13 +120,19 @@ ModelName = Annotated[str, AfterValidator(handle_default_openai_model)]
117
  response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
118
  )
119
  def translate_file(
 
 
120
  file: Annotated[UploadFile, Form()],
121
- model: Annotated[ModelName, Form()] = config.whisper.model,
122
  prompt: Annotated[str | None, Form()] = None,
123
- response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
124
  temperature: Annotated[float, Form()] = 0.0,
125
  stream: Annotated[bool, Form()] = False,
126
  ) -> Response | StreamingResponse:
 
 
 
 
127
  whisper = model_manager.load_model(model)
128
  segments, transcription_info = whisper.transcribe(
129
  file.file,
@@ -147,11 +156,13 @@ def translate_file(
147
  response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
148
  )
149
  def transcribe_file(
 
 
150
  file: Annotated[UploadFile, Form()],
151
- model: Annotated[ModelName, Form()] = config.whisper.model,
152
- language: Annotated[Language | None, Form()] = config.default_language,
153
  prompt: Annotated[str | None, Form()] = None,
154
- response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
155
  temperature: Annotated[float, Form()] = 0.0,
156
  timestamp_granularities: Annotated[
157
  list[Literal["segment", "word"]],
@@ -160,6 +171,12 @@ def transcribe_file(
160
  stream: Annotated[bool, Form()] = False,
161
  hotwords: Annotated[str | None, Form()] = None,
162
  ) -> Response | StreamingResponse:
 
 
 
 
 
 
163
  whisper = model_manager.load_model(model)
164
  segments, transcription_info = whisper.transcribe(
165
  file.file,
@@ -180,6 +197,7 @@ def transcribe_file(
180
 
181
 
182
  async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
 
183
  try:
184
  while True:
185
  bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
@@ -211,12 +229,20 @@ async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
211
 
212
  @router.websocket("/v1/audio/transcriptions")
213
  async def transcribe_stream(
 
 
214
  ws: WebSocket,
215
- model: Annotated[ModelName, Query()] = config.whisper.model,
216
- language: Annotated[Language | None, Query()] = config.default_language,
217
- response_format: Annotated[ResponseFormat, Query()] = config.default_response_format,
218
  temperature: Annotated[float, Query()] = 0.0,
219
  ) -> None:
 
 
 
 
 
 
220
  await ws.accept()
221
  transcribe_opts = {
222
  "language": language,
@@ -229,7 +255,7 @@ async def transcribe_stream(
229
  audio_stream = AudioStream()
230
  async with asyncio.TaskGroup() as tg:
231
  tg.create_task(audio_receiver(ws, audio_stream))
232
- async for transcription in audio_transcriber(asr, audio_stream):
233
  logger.debug(f"Sending transcription: {transcription.text}")
234
  if ws.client_state == WebSocketState.DISCONNECTED:
235
  break
 
2
 
3
  import asyncio
4
  from io import BytesIO
5
+ import logging
6
  from typing import TYPE_CHECKING, Annotated, Literal
7
 
8
  from fastapi import (
 
17
  from fastapi.responses import StreamingResponse
18
  from fastapi.websockets import WebSocketState
19
  from faster_whisper.vad import VadOptions, get_speech_timestamps
20
+ from pydantic import AfterValidator
21
+
22
  from faster_whisper_server.asr import FasterWhisperASR
23
  from faster_whisper_server.audio import AudioStream, audio_samples_from_file
24
  from faster_whisper_server.config import (
 
26
  Language,
27
  ResponseFormat,
28
  Task,
 
29
  )
30
  from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
31
+ from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
 
32
  from faster_whisper_server.server_models import (
33
  TranscriptionJsonResponse,
34
  TranscriptionVerboseJsonResponse,
35
  )
36
  from faster_whisper_server.transcriber import audio_transcriber
 
37
 
38
  if TYPE_CHECKING:
39
  from collections.abc import Generator, Iterable
 
41
  from faster_whisper.transcribe import TranscriptionInfo
42
 
43
 
44
+ logger = logging.getLogger(__name__)
45
+
46
  router = APIRouter()
47
 
48
 
 
105
 
106
  For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
107
  """
108
+ config = get_config() # HACK
109
  if model_name == "whisper-1":
110
  logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
111
  return config.whisper.model
 
120
  response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
121
  )
122
  def translate_file(
123
+ config: ConfigDependency,
124
+ model_manager: ModelManagerDependency,
125
  file: Annotated[UploadFile, Form()],
126
+ model: Annotated[ModelName | None, Form()] = None,
127
  prompt: Annotated[str | None, Form()] = None,
128
+ response_format: Annotated[ResponseFormat | None, Form()] = None,
129
  temperature: Annotated[float, Form()] = 0.0,
130
  stream: Annotated[bool, Form()] = False,
131
  ) -> Response | StreamingResponse:
132
+ if model is None:
133
+ model = config.whisper.model
134
+ if response_format is None:
135
+ response_format = config.default_response_format
136
  whisper = model_manager.load_model(model)
137
  segments, transcription_info = whisper.transcribe(
138
  file.file,
 
156
  response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
157
  )
158
  def transcribe_file(
159
+ config: ConfigDependency,
160
+ model_manager: ModelManagerDependency,
161
  file: Annotated[UploadFile, Form()],
162
+ model: Annotated[ModelName | None, Form()] = None,
163
+ language: Annotated[Language | None, Form()] = None,
164
  prompt: Annotated[str | None, Form()] = None,
165
+ response_format: Annotated[ResponseFormat | None, Form()] = None,
166
  temperature: Annotated[float, Form()] = 0.0,
167
  timestamp_granularities: Annotated[
168
  list[Literal["segment", "word"]],
 
171
  stream: Annotated[bool, Form()] = False,
172
  hotwords: Annotated[str | None, Form()] = None,
173
  ) -> Response | StreamingResponse:
174
+ if model is None:
175
+ model = config.whisper.model
176
+ if language is None:
177
+ language = config.default_language
178
+ if response_format is None:
179
+ response_format = config.default_response_format
180
  whisper = model_manager.load_model(model)
181
  segments, transcription_info = whisper.transcribe(
182
  file.file,
 
197
 
198
 
199
  async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
200
+ config = get_config() # HACK
201
  try:
202
  while True:
203
  bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
 
229
 
230
  @router.websocket("/v1/audio/transcriptions")
231
  async def transcribe_stream(
232
+ config: ConfigDependency,
233
+ model_manager: ModelManagerDependency,
234
  ws: WebSocket,
235
+ model: Annotated[ModelName | None, Query()] = None,
236
+ language: Annotated[Language | None, Query()] = None,
237
+ response_format: Annotated[ResponseFormat | None, Query()] = None,
238
  temperature: Annotated[float, Query()] = 0.0,
239
  ) -> None:
240
+ if model is None:
241
+ model = config.whisper.model
242
+ if language is None:
243
+ language = config.default_language
244
+ if response_format is None:
245
+ response_format = config.default_response_format
246
  await ws.accept()
247
  transcribe_opts = {
248
  "language": language,
 
255
  audio_stream = AudioStream()
256
  async with asyncio.TaskGroup() as tg:
257
  tg.create_task(audio_receiver(ws, audio_stream))
258
+ async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
259
  logger.debug(f"Sending transcription: {transcription.text}")
260
  if ws.client_state == WebSocketState.DISCONNECTED:
261
  break
src/faster_whisper_server/transcriber.py CHANGED
@@ -1,17 +1,18 @@
1
  from __future__ import annotations
2
 
 
3
  from typing import TYPE_CHECKING
4
 
5
  from faster_whisper_server.audio import Audio, AudioStream
6
- from faster_whisper_server.config import config
7
  from faster_whisper_server.core import Transcription, Word, common_prefix, to_full_sentences, word_to_text
8
- from faster_whisper_server.logger import logger
9
 
10
  if TYPE_CHECKING:
11
  from collections.abc import AsyncGenerator
12
 
13
  from faster_whisper_server.asr import FasterWhisperASR
14
 
 
 
15
 
16
  class LocalAgreement:
17
  def __init__(self) -> None:
@@ -47,11 +48,12 @@ def prompt(confirmed: Transcription) -> str | None:
47
  async def audio_transcriber(
48
  asr: FasterWhisperASR,
49
  audio_stream: AudioStream,
 
50
  ) -> AsyncGenerator[Transcription, None]:
51
  local_agreement = LocalAgreement()
52
  full_audio = Audio()
53
  confirmed = Transcription()
54
- async for chunk in audio_stream.chunks(config.min_duration):
55
  full_audio.extend(chunk)
56
  audio = full_audio.after(needs_audio_after(confirmed))
57
  transcription, _ = await asr.transcribe(audio, prompt(confirmed))
 
1
  from __future__ import annotations
2
 
3
+ import logging
4
  from typing import TYPE_CHECKING
5
 
6
  from faster_whisper_server.audio import Audio, AudioStream
 
7
  from faster_whisper_server.core import Transcription, Word, common_prefix, to_full_sentences, word_to_text
 
8
 
9
  if TYPE_CHECKING:
10
  from collections.abc import AsyncGenerator
11
 
12
  from faster_whisper_server.asr import FasterWhisperASR
13
 
14
+ logger = logging.getLogger(__name__)
15
+
16
 
17
  class LocalAgreement:
18
  def __init__(self) -> None:
 
48
  async def audio_transcriber(
49
  asr: FasterWhisperASR,
50
  audio_stream: AudioStream,
51
+ min_duration: float,
52
  ) -> AsyncGenerator[Transcription, None]:
53
  local_agreement = LocalAgreement()
54
  full_audio = Audio()
55
  confirmed = Transcription()
56
+ async for chunk in audio_stream.chunks(min_duration):
57
  full_audio.extend(chunk)
58
  audio = full_audio.after(needs_audio_after(confirmed))
59
  transcription, _ = await asr.transcribe(audio, prompt(confirmed))
tests/conftest.py CHANGED
@@ -1,7 +1,9 @@
1
  from collections.abc import AsyncGenerator, Generator
2
  import logging
 
3
 
4
  from fastapi.testclient import TestClient
 
5
  from httpx import ASGITransport, AsyncClient
6
  from openai import OpenAI
7
  import pytest
@@ -18,17 +20,15 @@ def pytest_configure() -> None:
18
 
19
  @pytest.fixture()
20
  def client() -> Generator[TestClient, None, None]:
21
- from faster_whisper_server.main import app
22
-
23
- with TestClient(app) as client:
24
  yield client
25
 
26
 
27
  @pytest_asyncio.fixture()
28
  async def aclient() -> AsyncGenerator[AsyncClient, None]:
29
- from faster_whisper_server.main import app
30
-
31
- async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as aclient:
32
  yield aclient
33
 
34
 
 
1
  from collections.abc import AsyncGenerator, Generator
2
  import logging
3
+ import os
4
 
5
  from fastapi.testclient import TestClient
6
+ from faster_whisper_server.main import create_app
7
  from httpx import ASGITransport, AsyncClient
8
  from openai import OpenAI
9
  import pytest
 
20
 
21
  @pytest.fixture()
22
  def client() -> Generator[TestClient, None, None]:
23
+ os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
24
+ with TestClient(create_app()) as client:
 
25
  yield client
26
 
27
 
28
  @pytest_asyncio.fixture()
29
  async def aclient() -> AsyncGenerator[AsyncClient, None]:
30
+ os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
31
+ async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
 
32
  yield aclient
33
 
34