Fedir Zadniprovskyi committed
Commit: bf48682
Parent(s): 8f3dcc9

feat: dependency injection
The main purpose of this change is to allow modifying the configuration
for testing. This change does lead to some ugly code where the `get_config`
function gets called in random places (see the sketch after the file list).
- Taskfile.yaml +1 -1
- pyproject.toml +1 -0
- src/faster_whisper_server/asr.py +3 -1
- src/faster_whisper_server/audio.py +4 -1
- src/faster_whisper_server/config.py +0 -3
- src/faster_whisper_server/core.py +2 -1
- src/faster_whisper_server/dependencies.py +24 -0
- src/faster_whisper_server/hf_utils.py +2 -1
- src/faster_whisper_server/logger.py +8 -5
- src/faster_whisper_server/main.py +33 -27
- src/faster_whisper_server/model_manager.py +20 -16
- src/faster_whisper_server/routers/misc.py +8 -5
- src/faster_whisper_server/routers/stt.py +39 -13
- src/faster_whisper_server/transcriber.py +5 -3
- tests/conftest.py +6 -6
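
Since the configuration is now provided through a FastAPI dependency, a test can swap it out with `dependency_overrides` instead of mutating a module-level `config` object. A minimal sketch of the pattern, not part of this commit (the `max_models=1` value is only an illustration):

from fastapi.testclient import TestClient

from faster_whisper_server.config import Config
from faster_whisper_server.dependencies import get_config
from faster_whisper_server.main import create_app

app = create_app()
# Any route that declares ConfigDependency now receives the test config.
app.dependency_overrides[get_config] = lambda: Config(max_models=1)

with TestClient(app) as client:
    print(client.get("/api/ps").json())

The caveat the commit message hints at: the direct `get_config()  # HACK` call sites bypass `dependency_overrides`, so only dependencies declared in route signatures see the override.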

Taskfile.yaml
CHANGED
@@ -1,6 +1,6 @@
 version: "3"
 tasks:
-  server: uvicorn --host 0.0.0.0 faster_whisper_server.main:app {{.CLI_ARGS}}
+  server: uvicorn --factory --host 0.0.0.0 faster_whisper_server.main:create_app {{.CLI_ARGS}}
   test:
     cmds:
       - pytest -o log_cli=true -o log_cli_level=DEBUG {{.CLI_ARGS}}
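
The new `--factory` flag tells uvicorn that `faster_whisper_server.main:create_app` is a callable that returns the application, rather than an application instance, so configuration is read when the server starts. A rough programmatic equivalent (a sketch, not from this commit):

import uvicorn

# factory=True mirrors the CLI's --factory: uvicorn calls create_app() itself
# instead of importing a prebuilt `app` object at module import time.
uvicorn.run("faster_whisper_server.main:create_app", factory=True, host="0.0.0.0")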

pyproject.toml
CHANGED
@@ -75,6 +75,7 @@ ignore = [
     "ISC001", # recommended to disable for formatting
     "INP001",
     "PT018",
+    "G004", # logging f string
 ]

 [tool.ruff.lint.isort]

src/faster_whisper_server/asr.py
CHANGED
@@ -1,11 +1,13 @@
 import asyncio
+import logging
 import time

 from faster_whisper import transcribe

 from faster_whisper_server.audio import Audio
 from faster_whisper_server.core import Segment, Transcription, Word
-from faster_whisper_server.logger import logger
+
+logger = logging.getLogger(__name__)


 class FasterWhisperASR:

src/faster_whisper_server/audio.py
CHANGED
@@ -1,13 +1,13 @@
 from __future__ import annotations

 import asyncio
+import logging
 from typing import TYPE_CHECKING, BinaryIO

 import numpy as np
 import soundfile as sf

 from faster_whisper_server.config import SAMPLES_PER_SECOND
-from faster_whisper_server.logger import logger

 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator
@@ -15,6 +15,9 @@ if TYPE_CHECKING:
     from numpy.typing import NDArray


+logger = logging.getLogger(__name__)
+
+
 def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]:
     audio_and_sample_rate = sf.read(
         file,

src/faster_whisper_server/config.py
CHANGED
@@ -238,6 +238,3 @@ class Config(BaseSettings):
                 f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
             )
         return self
-
-
-config = Config()

src/faster_whisper_server/core.py
CHANGED
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING

 from pydantic import BaseModel

-from faster_whisper_server.config import config
+from faster_whisper_server.dependencies import get_config

 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -113,6 +113,7 @@ class Transcription:
         self.words.extend(words)

     def _ensure_no_word_overlap(self, words: list[Word]) -> None:
+        config = get_config()  # HACK
         if len(self.words) > 0 and len(words) > 0:
             if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
                 raise ValueError(

src/faster_whisper_server/dependencies.py
ADDED
@@ -0,0 +1,24 @@
+from functools import lru_cache
+from typing import Annotated
+
+from fastapi import Depends
+
+from faster_whisper_server.config import Config
+from faster_whisper_server.model_manager import ModelManager
+
+
+@lru_cache
+def get_config() -> Config:
+    return Config()
+
+
+ConfigDependency = Annotated[Config, Depends(get_config)]
+
+
+@lru_cache
+def get_model_manager() -> ModelManager:
+    config = get_config()  # HACK
+    return ModelManager(config)
+
+
+ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
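
Because both providers are wrapped in `@lru_cache`, they act as process-wide singletons: FastAPI's `Depends` and the direct `get_config()  # HACK` call sites all receive the same instance. A small sketch of the consequence (`cache_clear` is standard `functools.lru_cache` API):

from faster_whisper_server.dependencies import get_config

assert get_config() is get_config()  # the same cached Config instance

# Changes to the environment are only picked up after clearing the cache.
get_config.cache_clear()
fresh_config = get_config()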

src/faster_whisper_server/hf_utils.py
CHANGED
@@ -1,10 +1,11 @@
 from collections.abc import Generator
+import logging
 from pathlib import Path
 import typing

 import huggingface_hub

-from faster_whisper_server.logger import logger
+logger = logging.getLogger(__name__)

 LIBRARY_NAME = "ctranslate2"
 TASK_NAME = "automatic-speech-recognition"

src/faster_whisper_server/logger.py
CHANGED
@@ -1,8 +1,11 @@
 import logging

-from faster_whisper_server.config import config
+from faster_whisper_server.dependencies import get_config

-logger = logging.getLogger(__name__)
-logger.setLevel(config.log_level.upper())
-
-logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s")
+
+def setup_logger() -> None:
+    config = get_config()  # HACK
+    logging.getLogger().setLevel(logging.INFO)
+    logger = logging.getLogger(__name__)
+    logger.setLevel(config.log_level.upper())
+    logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s")

src/faster_whisper_server/main.py
CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations

 from contextlib import asynccontextmanager
+import logging
 from typing import TYPE_CHECKING

 from fastapi import (
@@ -8,11 +9,8 @@ from fastapi import (
 )
 from fastapi.middleware.cors import CORSMiddleware

-from faster_whisper_server.config import (
-    config,
-)
-from faster_whisper_server.logger import logger
-from faster_whisper_server.model_manager import model_manager
+from faster_whisper_server.dependencies import get_config, get_model_manager
+from faster_whisper_server.logger import setup_logger
 from faster_whisper_server.routers.list_models import (
     router as list_models_router,
 )
@@ -27,34 +25,42 @@ if TYPE_CHECKING:
     from collections.abc import AsyncGenerator


-logger.debug(f"Config: {config}")
-
-
-@asynccontextmanager
-async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
-    for model_name in config.preload_models:
-        model_manager.load_model(model_name)
-    yield
-
-app = FastAPI(lifespan=lifespan)
-
-app.include_router(stt_router)
-app.include_router(list_models_router)
-app.include_router(misc_router)
-
-if config.allow_origins is not None:
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=config.allow_origins,
-        allow_credentials=True,
-        allow_methods=["*"],
-        allow_headers=["*"],
-    )
-
-if config.enable_ui:
-    import gradio as gr
-
-    from faster_whisper_server.gradio_app import create_gradio_demo
-
-    app = gr.mount_gradio_app(app, create_gradio_demo(config), path="/")
+def create_app() -> FastAPI:
+    setup_logger()
+
+    logger = logging.getLogger(__name__)
+
+    config = get_config()  # HACK
+    logger.debug(f"Config: {config}")
+
+    model_manager = get_model_manager()  # HACK
+
+    @asynccontextmanager
+    async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
+        for model_name in config.preload_models:
+            model_manager.load_model(model_name)
+        yield
+
+    app = FastAPI(lifespan=lifespan)
+
+    app.include_router(stt_router)
+    app.include_router(list_models_router)
+    app.include_router(misc_router)
+
+    if config.allow_origins is not None:
+        app.add_middleware(
+            CORSMiddleware,
+            allow_origins=config.allow_origins,
+            allow_credentials=True,
+            allow_methods=["*"],
+            allow_headers=["*"],
+        )
+
+    if config.enable_ui:
+        import gradio as gr
+
+        from faster_whisper_server.gradio_app import create_gradio_demo
+
+        app = gr.mount_gradio_app(app, create_gradio_demo(config), path="/")
+
+    return app

src/faster_whisper_server/model_manager.py
CHANGED
@@ -2,27 +2,34 @@ from __future__ import annotations

 from collections import OrderedDict
 import gc
+import logging
 import time
+from typing import TYPE_CHECKING

 from faster_whisper import WhisperModel

-from faster_whisper_server.config import (
-    config,
-)
-from faster_whisper_server.logger import logger
+if TYPE_CHECKING:
+    from faster_whisper_server.config import (
+        Config,
+    )
+
+logger = logging.getLogger(__name__)


 class ModelManager:
-    def __init__(self) -> None:
+    def __init__(self, config: Config) -> None:
+        self.config = config
         self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()

     def load_model(self, model_name: str) -> WhisperModel:
         if model_name in self.loaded_models:
             logger.debug(f"{model_name} model already loaded")
             return self.loaded_models[model_name]
-        if len(self.loaded_models) >= config.max_models:
+        if len(self.loaded_models) >= self.config.max_models:
             oldest_model_name = next(iter(self.loaded_models))
-            logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
+            logger.info(
+                f"Max models ({self.config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
+            )
             del self.loaded_models[oldest_model_name]
             gc.collect()
         logger.debug(f"Loading {model_name}...")
@@ -30,17 +37,14 @@ class ModelManager:
         # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
         whisper = WhisperModel(
             model_name,
-            device=config.whisper.inference_device,
-            device_index=config.whisper.device_index,
-            compute_type=config.whisper.compute_type,
-            cpu_threads=config.whisper.cpu_threads,
-            num_workers=config.whisper.num_workers,
+            device=self.config.whisper.inference_device,
+            device_index=self.config.whisper.device_index,
+            compute_type=self.config.whisper.compute_type,
+            cpu_threads=self.config.whisper.cpu_threads,
+            num_workers=self.config.whisper.num_workers,
         )
         logger.info(
-            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
+            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {self.config.whisper.inference_device}({self.config.whisper.compute_type}) will be used for inference."  # noqa: E501
         )
         self.loaded_models[model_name] = whisper
         return whisper
-
-
-model_manager = ModelManager()

src/faster_whisper_server/routers/misc.py
CHANGED
@@ -6,11 +6,12 @@ from fastapi import (
     APIRouter,
     Response,
 )
-from faster_whisper_server import hf_utils
-from faster_whisper_server.model_manager import model_manager
 import huggingface_hub
 from huggingface_hub.hf_api import RepositoryNotFoundError

+from faster_whisper_server import hf_utils
+from faster_whisper_server.dependencies import ModelManagerDependency  # noqa: TCH001
+
 router = APIRouter()

@@ -31,12 +32,14 @@ def pull_model(model_name: str) -> Response:


 @router.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
-def get_running_models() -> dict[str, list[str]]:
+def get_running_models(
+    model_manager: ModelManagerDependency,
+) -> dict[str, list[str]]:
     return {"models": list(model_manager.loaded_models.keys())}


 @router.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
-def load_model_route(model_name: str) -> Response:
+def load_model_route(model_manager: ModelManagerDependency, model_name: str) -> Response:
     if model_name in model_manager.loaded_models:
         return Response(status_code=409, content="Model already loaded")
     model_manager.load_model(model_name)
@@ -44,7 +47,7 @@ def load_model_route(model_name: str) -> Response:


 @router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
-def stop_running_model(model_name: str) -> Response:
+def stop_running_model(model_manager: ModelManagerDependency, model_name: str) -> Response:
     model = model_manager.loaded_models.get(model_name)
     if model is not None:
         del model_manager.loaded_models[model_name]

src/faster_whisper_server/routers/stt.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations

 import asyncio
 from io import BytesIO
+import logging
 from typing import TYPE_CHECKING, Annotated, Literal

 from fastapi import (
@@ -16,6 +17,8 @@ from fastapi import (
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
 from faster_whisper.vad import VadOptions, get_speech_timestamps
+from pydantic import AfterValidator
+
 from faster_whisper_server.asr import FasterWhisperASR
 from faster_whisper_server.audio import AudioStream, audio_samples_from_file
 from faster_whisper_server.config import (
@@ -23,17 +26,14 @@ from faster_whisper_server.config import (
     Language,
     ResponseFormat,
     Task,
-    config,
 )
 from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
-from faster_whisper_server.logger import logger
-from faster_whisper_server.model_manager import model_manager
+from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
 from faster_whisper_server.server_models import (
     TranscriptionJsonResponse,
     TranscriptionVerboseJsonResponse,
 )
 from faster_whisper_server.transcriber import audio_transcriber
-from pydantic import AfterValidator

 if TYPE_CHECKING:
     from collections.abc import Generator, Iterable
@@ -41,6 +41,8 @@ if TYPE_CHECKING:
     from faster_whisper.transcribe import TranscriptionInfo


+logger = logging.getLogger(__name__)
+
 router = APIRouter()

@@ -103,6 +105,7 @@ def handle_default_openai_model(model_name: str) -> str:

     For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
     """
+    config = get_config()  # HACK
     if model_name == "whisper-1":
         logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
         return config.whisper.model
@@ -117,13 +120,19 @@ ModelName = Annotated[str, AfterValidator(handle_default_openai_model)]
     response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
 )
 def translate_file(
+    config: ConfigDependency,
+    model_manager: ModelManagerDependency,
     file: Annotated[UploadFile, Form()],
-    model: Annotated[ModelName, Form()] = config.whisper.model,
+    model: Annotated[ModelName | None, Form()] = None,
     prompt: Annotated[str | None, Form()] = None,
-    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
+    response_format: Annotated[ResponseFormat | None, Form()] = None,
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
 ) -> Response | StreamingResponse:
+    if model is None:
+        model = config.whisper.model
+    if response_format is None:
+        response_format = config.default_response_format
     whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -147,11 +156,13 @@ def translate_file(
     response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
 )
 def transcribe_file(
+    config: ConfigDependency,
+    model_manager: ModelManagerDependency,
     file: Annotated[UploadFile, Form()],
-    model: Annotated[ModelName, Form()] = config.whisper.model,
-    language: Annotated[Language | None, Form()] = config.default_language,
+    model: Annotated[ModelName | None, Form()] = None,
+    language: Annotated[Language | None, Form()] = None,
     prompt: Annotated[str | None, Form()] = None,
-    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
+    response_format: Annotated[ResponseFormat | None, Form()] = None,
     temperature: Annotated[float, Form()] = 0.0,
     timestamp_granularities: Annotated[
         list[Literal["segment", "word"]],
@@ -160,6 +171,12 @@ def transcribe_file(
     stream: Annotated[bool, Form()] = False,
     hotwords: Annotated[str | None, Form()] = None,
 ) -> Response | StreamingResponse:
+    if model is None:
+        model = config.whisper.model
+    if language is None:
+        language = config.default_language
+    if response_format is None:
+        response_format = config.default_response_format
     whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -180,6 +197,7 @@ def transcribe_file(


 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
+    config = get_config()  # HACK
     try:
         while True:
             bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
@@ -211,12 +229,20 @@ async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:

 @router.websocket("/v1/audio/transcriptions")
 async def transcribe_stream(
+    config: ConfigDependency,
+    model_manager: ModelManagerDependency,
     ws: WebSocket,
-    model: Annotated[ModelName, Query()] = config.whisper.model,
-    language: Annotated[Language | None, Query()] = config.default_language,
-    response_format: Annotated[ResponseFormat, Query()] = config.default_response_format,
+    model: Annotated[ModelName | None, Query()] = None,
+    language: Annotated[Language | None, Query()] = None,
+    response_format: Annotated[ResponseFormat | None, Query()] = None,
     temperature: Annotated[float, Query()] = 0.0,
 ) -> None:
+    if model is None:
+        model = config.whisper.model
+    if language is None:
+        language = config.default_language
+    if response_format is None:
+        response_format = config.default_response_format
     await ws.accept()
     transcribe_opts = {
         "language": language,
@@ -229,7 +255,7 @@ async def transcribe_stream(
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
         tg.create_task(audio_receiver(ws, audio_stream))
-        async for transcription in audio_transcriber(asr, audio_stream):
+        async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
             logger.debug(f"Sending transcription: {transcription.text}")
             if ws.client_state == WebSocketState.DISCONNECTED:
                 break

src/faster_whisper_server/transcriber.py
CHANGED
@@ -1,17 +1,18 @@
 from __future__ import annotations

+import logging
 from typing import TYPE_CHECKING

 from faster_whisper_server.audio import Audio, AudioStream
-from faster_whisper_server.config import config
 from faster_whisper_server.core import Transcription, Word, common_prefix, to_full_sentences, word_to_text
-from faster_whisper_server.logger import logger

 if TYPE_CHECKING:
     from collections.abc import AsyncGenerator

     from faster_whisper_server.asr import FasterWhisperASR

+logger = logging.getLogger(__name__)
+

 class LocalAgreement:
     def __init__(self) -> None:
@@ -47,11 +48,12 @@ def prompt(confirmed: Transcription) -> str | None:
 async def audio_transcriber(
     asr: FasterWhisperASR,
     audio_stream: AudioStream,
+    min_duration: float,
 ) -> AsyncGenerator[Transcription, None]:
     local_agreement = LocalAgreement()
     full_audio = Audio()
     confirmed = Transcription()
-    async for chunk in audio_stream.chunks(config.min_duration):
+    async for chunk in audio_stream.chunks(min_duration):
         full_audio.extend(chunk)
         audio = full_audio.after(needs_audio_after(confirmed))
         transcription, _ = await asr.transcribe(audio, prompt(confirmed))

tests/conftest.py
CHANGED
@@ -1,7 +1,9 @@
 from collections.abc import AsyncGenerator, Generator
 import logging
+import os

 from fastapi.testclient import TestClient
+from faster_whisper_server.main import create_app
 from httpx import ASGITransport, AsyncClient
 from openai import OpenAI
 import pytest
@@ -18,17 +20,15 @@ def pytest_configure() -> None:

 @pytest.fixture()
 def client() -> Generator[TestClient, None, None]:
-    with TestClient(app) as client:
+    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
+    with TestClient(create_app()) as client:
         yield client


 @pytest_asyncio.fixture()
 async def aclient() -> AsyncGenerator[AsyncClient, None]:
-    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as aclient:
+    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
+    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as aclient:
         yield aclient
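
The fixtures configure the app through the environment rather than `dependency_overrides`: `WHISPER__MODEL` is set before `create_app()` runs, so the `Config()` constructed inside `get_config` picks it up (the double underscore suggests pydantic-settings' nested delimiter, mapping onto `config.whisper.model`). Because `get_config` is cached, ordering matters across tests; a hypothetical fixture that also resets the cache, as a sketch:

import os

import pytest
from fastapi.testclient import TestClient

from faster_whisper_server.dependencies import get_config
from faster_whisper_server.main import create_app


@pytest.fixture()
def tiny_model_client():
    os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
    get_config.cache_clear()  # re-read the environment instead of reusing a cached Config
    with TestClient(create_app()) as client:
        yield client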