Commit 7cc3853 by Fedir Zadniprovskyi
Parent: 168a38a

feat: tts
Files changed:

- src/faster_whisper_server/dependencies.py (+13, -4)
- src/faster_whisper_server/hf_utils.py (+161, -2)
- src/faster_whisper_server/main.py (+11, -0)
- src/faster_whisper_server/model_manager.py (+84, -36)
- src/faster_whisper_server/routers/list_models.py (+4, -24)
- src/faster_whisper_server/routers/speech.py (+164, -0)
- src/faster_whisper_server/text_utils.py (+2, -2)
- tests/conftest.py (+8, -0)
- tests/speech_test.py (+158, -0)
src/faster_whisper_server/dependencies.py  (CHANGED)

```diff
@@ -4,7 +4,7 @@ from typing import Annotated
 from fastapi import Depends
 
 from faster_whisper_server.config import Config
-from faster_whisper_server.model_manager import ModelManager
+from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
 
 
 @lru_cache
@@ -16,9 +16,18 @@ ConfigDependency = Annotated[Config, Depends(get_config)]
 
 
 @lru_cache
-def get_model_manager() -> ModelManager:
+def get_model_manager() -> WhisperModelManager:
     config = get_config()  # HACK
-    return ModelManager(config.whisper)
+    return WhisperModelManager(config.whisper)
 
 
-ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
+ModelManagerDependency = Annotated[WhisperModelManager, Depends(get_model_manager)]
+
+
+@lru_cache
+def get_piper_model_manager() -> PiperModelManager:
+    config = get_config()  # HACK
+    return PiperModelManager(config.whisper.ttl)  # HACK
+
+
+PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]
```
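The two `Annotated` aliases are what routers consume. A minimal sketch of how a route would use the new Piper dependency (the endpoint here is hypothetical; the real consumer is `routers/speech.py` below):

```python
# Illustrative only, not part of the commit. FastAPI resolves the Annotated
# dependency and injects the process-wide (lru_cache'd) PiperModelManager.
from fastapi import FastAPI

from faster_whisper_server.dependencies import PiperModelManagerDependency

app = FastAPI()


@app.get("/voice-sample-rate")  # hypothetical endpoint for illustration
def voice_sample_rate(piper_model_manager: PiperModelManagerDependency) -> dict[str, int]:
    # `load_model` returns a context manager; the voice is loaded on first use
    # and scheduled for unload `ttl` seconds after the last user releases it.
    with piper_model_manager.load_model("en_US-amy-medium") as piper_tts:
        return {"sample_rate": piper_tts.config.sample_rate}
```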
src/faster_whisper_server/hf_utils.py  (CHANGED)

```diff
@@ -1,9 +1,16 @@
 from collections.abc import Generator
+from functools import lru_cache
+import json
 import logging
 from pathlib import Path
 import typing
+from typing import Any, Literal
 
 import huggingface_hub
+from huggingface_hub.constants import HF_HUB_CACHE
+from pydantic import BaseModel
+
+from faster_whisper_server.api_models import Model
 
 logger = logging.getLogger(__name__)
 
@@ -12,10 +19,36 @@ TASK_NAME = "automatic-speech-recognition"
 
 
 def does_local_model_exist(model_id: str) -> bool:
-    return any(model_id == model.repo_id for model, _ in list_local_models())
+    return any(model_id == model.repo_id for model, _ in list_local_whisper_models())
 
 
-def list_local_models() -> Generator[tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None]:
+def list_whisper_models() -> Generator[Model, None, None]:
+    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
+    models = list(models)
+    models.sort(key=lambda model: model.downloads or -1, reverse=True)
+    for model in models:
+        assert model.created_at is not None
+        assert model.card_data is not None
+        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
+        if model.card_data.language is None:
+            language = []
+        elif isinstance(model.card_data.language, str):
+            language = [model.card_data.language]
+        else:
+            language = model.card_data.language
+        transformed_model = Model(
+            id=model.id,
+            created=int(model.created_at.timestamp()),
+            object_="model",
+            owned_by=model.id.split("/")[0],
+            language=language,
+        )
+        yield transformed_model
+
+
+def list_local_whisper_models() -> (
+    Generator[tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None]
+):
     hf_cache = huggingface_hub.scan_cache_dir()
     hf_models = [repo for repo in list(hf_cache.repos) if repo.repo_type == "model"]
     for model in hf_models:
@@ -36,3 +69,129 @@ def list_local_models() -> Generator[tuple[huggingface_hub.CachedRepoInfo, huggi
         and TASK_NAME in model_card_data.tags
     ):
         yield model, model_card_data
+
+
+def get_whisper_models() -> Generator[Model, None, None]:
+    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
+    models = list(models)
+    models.sort(key=lambda model: model.downloads or -1, reverse=True)
+    for model in models:
+        assert model.created_at is not None
+        assert model.card_data is not None
+        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
+        if model.card_data.language is None:
+            language = []
+        elif isinstance(model.card_data.language, str):
+            language = [model.card_data.language]
+        else:
+            language = model.card_data.language
+        transformed_model = Model(
+            id=model.id,
+            created=int(model.created_at.timestamp()),
+            object_="model",
+            owned_by=model.id.split("/")[0],
+            language=language,
+        )
+        yield transformed_model
+
+
+class PiperModel(BaseModel):
+    id: str
+    object: Literal["model"] = "model"
+    created: int
+    owned_by: Literal["rhasspy"] = "rhasspy"
+    path: Path
+    config_path: Path
+
+
+def get_model_path(model_id: str, *, cache_dir: str | Path | None = None) -> Path | None:
+    if cache_dir is None:
+        cache_dir = HF_HUB_CACHE
+
+    cache_dir = Path(cache_dir).expanduser().resolve()
+    if not cache_dir.exists():
+        raise huggingface_hub.CacheNotFound(
+            f"Cache directory not found: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable.",  # noqa: E501
+            cache_dir=cache_dir,
+        )
+
+    if cache_dir.is_file():
+        raise ValueError(
+            f"Scan cache expects a directory but found a file: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable."  # noqa: E501
+        )
+
+    for repo_path in cache_dir.iterdir():
+        if not repo_path.is_dir():
+            continue
+        if repo_path.name == ".locks":  # skip './.locks/' folder
+            continue
+        repo_type, repo_id = repo_path.name.split("--", maxsplit=1)
+        repo_type = repo_type[:-1]  # "models" -> "model"
+        repo_id = repo_id.replace("--", "/")  # google--fleurs -> "google/fleurs"
+        if repo_type != "model":
+            continue
+        if model_id == repo_id:
+            return repo_path
+
+    return None
+
+
+def list_model_files(
+    model_id: str, glob_pattern: str = "**/*", *, cache_dir: str | Path | None = None
+) -> Generator[Path, None, None]:
+    repo_path = get_model_path(model_id, cache_dir=cache_dir)
+    if repo_path is None:
+        return None
+    snapshots_path = repo_path / "snapshots"
+    if not snapshots_path.exists():
+        return None
+    yield from list(snapshots_path.glob(glob_pattern))
+
+
+def list_piper_models() -> Generator[PiperModel, None, None]:
+    model_weights_files = list_model_files("rhasspy/piper-voices", glob_pattern="**/*.onnx")
+    for model_weights_file in model_weights_files:
+        model_config_file = model_weights_file.with_suffix(".json")
+        yield PiperModel(
+            id=model_weights_file.name,
+            created=int(model_weights_file.stat().st_mtime),
+            path=model_weights_file,
+            config_path=model_config_file,
+        )
+
+
+# NOTE: It's debatable whether caching should be done here or by the caller. Should be revisited.
+
+
+@lru_cache
+def read_piper_voices_config() -> dict[str, Any]:
+    voices_file = next(list_model_files("rhasspy/piper-voices", glob_pattern="**/voices.json"), None)
+    if voices_file is None:
+        raise FileNotFoundError("Could not find voices.json file")  # noqa: EM101
+    return json.loads(voices_file.read_text())
+
+
+@lru_cache
+def get_piper_voice_model_file(voice: str) -> Path:
+    model_file = next(list_model_files("rhasspy/piper-voices", glob_pattern=f"**/{voice}.onnx"), None)
+    if model_file is None:
+        raise FileNotFoundError(f"Could not find model file for '{voice}' voice")
+    return model_file
+
+
+class PiperVoiceConfigAudio(BaseModel):
+    sample_rate: int
+    quality: int
+
+
+class PiperVoiceConfig(BaseModel):
+    audio: PiperVoiceConfigAudio
+    # NOTE: there are more fields in the config, but we don't care about them
+
+
+@lru_cache
+def read_piper_voice_config(voice: str) -> PiperVoiceConfig:
+    model_config_file = next(list_model_files("rhasspy/piper-voices", glob_pattern=f"**/{voice}.onnx.json"), None)
+    if model_config_file is None:
+        raise FileNotFoundError(f"Could not find config file for '{voice}' voice")
+    return PiperVoiceConfig.model_validate_json(model_config_file.read_text())
```
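Taken together, the new helpers resolve voices straight out of the HF cache layout (`models--rhasspy--piper-voices/snapshots/...`). A short usage sketch, assuming `rhasspy/piper-voices` has already been downloaded (as the test fixture below arranges):

```python
# Illustrative only, not part of the commit.
from faster_whisper_server.hf_utils import (
    get_piper_voice_model_file,
    list_piper_models,
    read_piper_voice_config,
    read_piper_voices_config,
)

# Every cached `.onnx` voice, wrapped in the `PiperModel` pydantic model:
for piper_model in list_piper_models():
    print(piper_model.id, piper_model.path)

# `voices.json` maps voice names to metadata; `verify_voice_is_valid` in
# `routers/speech.py` validates requests against its keys:
voices = read_piper_voices_config()
assert "en_US-amy-medium" in voices

# Per-voice `.onnx.json` config; only the `audio` section is modelled:
config = read_piper_voice_config("en_US-amy-medium")
print(config.audio.sample_rate)  # 22050 for this voice

# Path to the voice's weights, memoized with `lru_cache`:
print(get_piper_voice_model_file("en_US-amy-medium"))
```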
src/faster_whisper_server/main.py  (CHANGED)

```diff
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 from contextlib import asynccontextmanager
 import logging
+import platform
 from typing import TYPE_CHECKING
 
 from fastapi import (
@@ -30,6 +31,14 @@ def create_app() -> FastAPI:
 
     logger = logging.getLogger(__name__)
 
+    if platform.machine() == "x86_64":
+        from faster_whisper_server.routers.speech import (
+            router as speech_router,
+        )
+    else:
+        logger.warning("`/v1/audio/speech` is only supported on x86_64 machines")
+        speech_router = None
+
     config = get_config()  # HACK
     logger.debug(f"Config: {config}")
 
@@ -46,6 +55,8 @@ def create_app() -> FastAPI:
     app.include_router(stt_router)
     app.include_router(list_models_router)
    app.include_router(misc_router)
+    if speech_router is not None:
+        app.include_router(speech_router)
 
     if config.allow_origins is not None:
         app.add_middleware(
```
src/faster_whisper_server/model_manager.py  (CHANGED)

```diff
@@ -9,63 +9,58 @@ from typing import TYPE_CHECKING
 
 from faster_whisper import WhisperModel
 
+from faster_whisper_server.hf_utils import get_piper_voice_model_file
+
 if TYPE_CHECKING:
     from collections.abc import Callable
 
+    from piper.voice import PiperVoice
+
     from faster_whisper_server.config import (
         WhisperConfig,
     )
 
 logger = logging.getLogger(__name__)
 
+
 # TODO: enable concurrent model downloads
 
 
-class SelfDisposingWhisperModel:
+class SelfDisposingModel[T]:
     def __init__(
-        self,
-        model_id: str,
-        whisper_config: WhisperConfig,
-        *,
-        on_unload: Callable[[str], None] | None = None,
+        self, model_id: str, load_fn: Callable[[], T], ttl: int, unload_fn: Callable[[str], None] | None = None
     ) -> None:
         self.model_id = model_id
-        self.whisper_config = whisper_config
-        self.on_unload = on_unload
+        self.load_fn = load_fn
+        self.ttl = ttl
+        self.unload_fn = unload_fn
 
         self.ref_count: int = 0
         self.rlock = threading.RLock()
         self.expire_timer: threading.Timer | None = None
-        self.whisper: WhisperModel | None = None
+        self.model: T | None = None
 
     def unload(self) -> None:
         with self.rlock:
-            if self.whisper is None:
+            if self.model is None:
                 raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
             if self.ref_count > 0:
                 raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
             if self.expire_timer:
                 self.expire_timer.cancel()
-            self.whisper = None
+            self.model = None
             # WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
             gc.collect()
             logger.info(f"Model {self.model_id} unloaded")
-            if self.on_unload is not None:
-                self.on_unload(self.model_id)
+            if self.unload_fn is not None:
+                self.unload_fn(self.model_id)
 
     def _load(self) -> None:
         with self.rlock:
-            assert self.whisper is None
+            assert self.model is None
             logger.debug(f"Loading model {self.model_id}")
             start = time.perf_counter()
-            self.whisper = WhisperModel(
-                self.model_id,
-                device=self.whisper_config.inference_device,
-                device_index=self.whisper_config.device_index,
-                compute_type=self.whisper_config.compute_type,
-                cpu_threads=self.whisper_config.cpu_threads,
-                num_workers=self.whisper_config.num_workers,
-            )
+            self.model = self.load_fn()
             logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")
 
     def _increment_ref(self) -> None:
@@ -81,34 +76,84 @@ class SelfDisposingWhisperModel:
             self.ref_count -= 1
             logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
             if self.ref_count <= 0:
-                if self.whisper_config.ttl > 0:
-                    logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s")
-                    self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload)
+                if self.ttl > 0:
+                    logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.ttl}s")
+                    self.expire_timer = threading.Timer(self.ttl, self.unload)
                     self.expire_timer.start()
-                elif self.whisper_config.ttl == 0:
+                elif self.ttl == 0:
                     logger.info(f"Model {self.model_id} is idle, unloading immediately")
                     self.unload()
                 else:
                     logger.info(f"Model {self.model_id} is idle, not unloading")
 
-    def __enter__(self) -> WhisperModel:
+    def __enter__(self) -> T:
         with self.rlock:
-            if self.whisper is None:
+            if self.model is None:
                 self._load()
             self._increment_ref()
-            assert self.whisper is not None
-            return self.whisper
+            assert self.model is not None
+            return self.model
 
     def __exit__(self, *_args) -> None:  # noqa: ANN002
         self._decrement_ref()
 
 
-class ModelManager:
+class WhisperModelManager:
     def __init__(self, whisper_config: WhisperConfig) -> None:
         self.whisper_config = whisper_config
-        self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
+        self.loaded_models: OrderedDict[str, SelfDisposingModel[WhisperModel]] = OrderedDict()
+        self._lock = threading.Lock()
+
+    def _load_fn(self, model_id: str) -> WhisperModel:
+        return WhisperModel(
+            model_id,
+            device=self.whisper_config.inference_device,
+            device_index=self.whisper_config.device_index,
+            compute_type=self.whisper_config.compute_type,
+            cpu_threads=self.whisper_config.cpu_threads,
+            num_workers=self.whisper_config.num_workers,
+        )
+
+    def _handle_model_unload(self, model_name: str) -> None:
+        with self._lock:
+            if model_name in self.loaded_models:
+                del self.loaded_models[model_name]
+
+    def unload_model(self, model_name: str) -> None:
+        with self._lock:
+            model = self.loaded_models.get(model_name)
+            if model is None:
+                raise KeyError(f"Model {model_name} not found")
+            self.loaded_models[model_name].unload()
+
+    def load_model(self, model_name: str) -> SelfDisposingModel[WhisperModel]:
+        logger.debug(f"Loading model {model_name}")
+        with self._lock:
+            logger.debug("Acquired lock")
+            if model_name in self.loaded_models:
+                logger.debug(f"{model_name} model already loaded")
+                return self.loaded_models[model_name]
+            self.loaded_models[model_name] = SelfDisposingModel[WhisperModel](
+                model_name,
+                load_fn=lambda: self._load_fn(model_name),
+                ttl=self.whisper_config.ttl,
+                unload_fn=self._handle_model_unload,
+            )
+            return self.loaded_models[model_name]
+
+
+class PiperModelManager:
+    def __init__(self, ttl: int) -> None:
+        self.ttl = ttl
+        self.loaded_models: OrderedDict[str, SelfDisposingModel[PiperVoice]] = OrderedDict()
         self._lock = threading.Lock()
 
+    def _load_fn(self, model_id: str) -> PiperVoice:
+        from piper.voice import PiperVoice
+
+        model_path = get_piper_voice_model_file(model_id)
+        return PiperVoice.load(model_path)
+
     def _handle_model_unload(self, model_name: str) -> None:
         with self._lock:
             if model_name in self.loaded_models:
@@ -121,14 +166,17 @@ class ModelManager:
                 raise KeyError(f"Model {model_name} not found")
             self.loaded_models[model_name].unload()
 
-    def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
+    def load_model(self, model_name: str) -> SelfDisposingModel[PiperVoice]:
+        from piper.voice import PiperVoice
+
         with self._lock:
             if model_name in self.loaded_models:
                 logger.debug(f"{model_name} model already loaded")
                 return self.loaded_models[model_name]
-            self.loaded_models[model_name] = SelfDisposingWhisperModel(
+            self.loaded_models[model_name] = SelfDisposingModel[PiperVoice](
                 model_name,
-                self.whisper_config,
-                on_unload=self._handle_model_unload,
+                load_fn=lambda: self._load_fn(model_name),
+                ttl=self.ttl,
+                unload_fn=self._handle_model_unload,
             )
             return self.loaded_models[model_name]
```
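The generic `SelfDisposingModel[T]` now carries the whole load/ref-count/TTL lifecycle, with backend-specific construction injected through `load_fn`. A minimal sketch of that lifecycle with a dummy `load_fn` (illustrative, not part of the commit):

```python
from faster_whisper_server.model_manager import SelfDisposingModel

# T is just `str` here; in the commit it is `WhisperModel` or `PiperVoice`.
model = SelfDisposingModel[str]("dummy", load_fn=lambda: "loaded!", ttl=0)

with model as value:  # __enter__ calls load_fn on first use and bumps ref_count
    print(value)  # "loaded!"
# __exit__ decrements ref_count; at zero: ttl > 0 schedules a threading.Timer
# to unload, ttl == 0 unloads immediately, and a negative ttl never unloads.
```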
src/faster_whisper_server/routers/list_models.py  (CHANGED)

```diff
@@ -13,6 +13,7 @@ from faster_whisper_server.api_models import (
     ListModelsResponse,
     Model,
 )
+from faster_whisper_server.hf_utils import list_whisper_models
 
 if TYPE_CHECKING:
     from huggingface_hub.hf_api import ModelInfo
@@ -22,34 +23,13 @@ router = APIRouter()
 
 @router.get("/v1/models")
 def get_models() -> ListModelsResponse:
-    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
-    models = list(models)
-    models.sort(key=lambda model: model.downloads or -1, reverse=True)
-    transformed_models: list[Model] = []
-    for model in models:
-        assert model.created_at is not None
-        assert model.card_data is not None
-        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
-        if model.card_data.language is None:
-            language = []
-        elif isinstance(model.card_data.language, str):
-            language = [model.card_data.language]
-        else:
-            language = model.card_data.language
-        transformed_model = Model(
-            id=model.id,
-            created=int(model.created_at.timestamp()),
-            object_="model",
-            owned_by=model.id.split("/")[0],
-            language=language,
-        )
-        transformed_models.append(transformed_model)
-    return ListModelsResponse(data=transformed_models)
+    whisper_models = list(list_whisper_models())
+    return ListModelsResponse(data=whisper_models)
 
 
 @router.get("/v1/models/{model_name:path}")
-# NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
 def get_model(
+    # NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
     model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
 ) -> Model:
     models = huggingface_hub.list_models(
```
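`GET /v1/models` now delegates to `list_whisper_models()` and keeps its previous behavior (CTranslate2 ASR models, sorted by downloads). A hedged client-side check, assuming a server on `localhost:8000` and a placeholder API key:

```python
# Illustrative only; base_url and api_key are assumptions about a local deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="does-not-matter")
models = client.models.list()
print(len(models.data), "models, most downloaded first")
print(models.data[0].id)  # e.g. a Systran/faster-whisper-* repo
```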
src/faster_whisper_server/routers/speech.py  (ADDED, 164 lines)

```python
from collections.abc import Generator
import io
import logging
import time
from typing import Annotated, Literal, Self

from fastapi import APIRouter
from fastapi.responses import StreamingResponse
import numpy as np
from piper.voice import PiperVoice
from pydantic import BaseModel, BeforeValidator, Field, ValidationError, model_validator
import soundfile as sf

from faster_whisper_server.dependencies import PiperModelManagerDependency
from faster_whisper_server.hf_utils import read_piper_voices_config

DEFAULT_MODEL = "piper"
# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-response_format
DEFAULT_RESPONSE_FORMAT = "mp3"
DEFAULT_VOICE = "en_US-amy-medium"  # TODO: make configurable
DEFAULT_VOICE_SAMPLE_RATE = 22050  # NOTE: Dependant on the voice

# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-model
# https://platform.openai.com/docs/models/tts
OPENAI_SUPPORTED_SPEECH_MODEL = ("tts-1", "tts-1-hd")

# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-voice
# https://platform.openai.com/docs/guides/text-to-speech/voice-options
OPENAI_SUPPORTED_SPEECH_VOICE_NAMES = ("alloy", "echo", "fable", "onyx", "nova", "shimmer")

# https://platform.openai.com/docs/guides/text-to-speech/supported-output-formats
type ResponseFormat = Literal["mp3", "flac", "wav", "pcm"]
SUPPORTED_RESPONSE_FORMATS = ("mp3", "flac", "wav", "pcm")
UNSUPORTED_RESPONSE_FORMATS = ("opus", "aac")

MIN_SAMPLE_RATE = 8000
MAX_SAMPLE_RATE = 48000


logger = logging.getLogger(__name__)

router = APIRouter()


# aip 'Write a function `resample_audio` which would take in RAW PCM 16-bit signed, little-endian audio data represented as bytes (`audio_bytes`) and resample it (either downsample or upsample) from `sample_rate` to `target_sample_rate` using numpy'  # noqa: E501
def resample_audio(audio_bytes: bytes, sample_rate: int, target_sample_rate: int) -> bytes:
    audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
    duration = len(audio_data) / sample_rate
    target_length = int(duration * target_sample_rate)
    resampled_data = np.interp(
        np.linspace(0, len(audio_data), target_length, endpoint=False), np.arange(len(audio_data)), audio_data
    )
    return resampled_data.astype(np.int16).tobytes()


def generate_audio(
    piper_tts: PiperVoice, text: str, *, speed: float = 1.0, sample_rate: int | None = None
) -> Generator[bytes, None, None]:
    if sample_rate is None:
        sample_rate = piper_tts.config.sample_rate
    start = time.perf_counter()
    for audio_bytes in piper_tts.synthesize_stream_raw(text, length_scale=1.0 / speed):
        if sample_rate != piper_tts.config.sample_rate:
            audio_bytes = resample_audio(audio_bytes, piper_tts.config.sample_rate, sample_rate)  # noqa: PLW2901
        yield audio_bytes
    logger.info(f"Generated audio for {len(text)} characters in {time.perf_counter() - start}s")


def convert_audio_format(
    audio_bytes: bytes,
    sample_rate: int,
    audio_format: ResponseFormat,
    format: str = "RAW",  # noqa: A002
    channels: int = 1,
    subtype: str = "PCM_16",
    endian: str = "LITTLE",
) -> bytes:
    # NOTE: the default dtype is float64. Should something else be used? Would that improve performance?
    data, _ = sf.read(
        io.BytesIO(audio_bytes),
        samplerate=sample_rate,
        format=format,
        channels=channels,
        subtype=subtype,
        endian=endian,
    )
    converted_audio_bytes_buffer = io.BytesIO()
    sf.write(converted_audio_bytes_buffer, data, samplerate=sample_rate, format=audio_format)
    return converted_audio_bytes_buffer.getvalue()


def handle_openai_supported_model_ids(model_id: str) -> str:
    if model_id in OPENAI_SUPPORTED_SPEECH_MODEL:
        logger.warning(f"{model_id} is not a valid model name. Using '{DEFAULT_MODEL}' instead.")
        return DEFAULT_MODEL
    return model_id


ModelId = Annotated[
    Literal["piper"],
    BeforeValidator(handle_openai_supported_model_ids),
    Field(
        description=f"The ID of the model. The only supported model is '{DEFAULT_MODEL}'.",
        examples=[DEFAULT_MODEL],
    ),
]


def handle_openai_supported_voices(voice: str) -> str:
    if voice in OPENAI_SUPPORTED_SPEECH_VOICE_NAMES:
        logger.warning(f"{voice} is not a valid voice name. Using '{DEFAULT_VOICE}' instead.")
        return DEFAULT_VOICE
    return voice


Voice = Annotated[str, BeforeValidator(handle_openai_supported_voices)]  # TODO: description and examples


class CreateSpeechRequestBody(BaseModel):
    model: ModelId = DEFAULT_MODEL
    input: str = Field(
        ...,
        description="The text to generate audio for. ",
        examples=[
            "A rainbow is an optical phenomenon caused by refraction, internal reflection and dispersion of light in water droplets resulting in a continuous spectrum of light appearing in the sky. The rainbow takes the form of a multicoloured circular arc. Rainbows caused by sunlight always appear in the section of sky directly opposite the Sun. Rainbows can be caused by many forms of airborne water. These include not only rain, but also mist, spray, and airborne dew."  # noqa: E501
        ],
    )
    voice: Voice = DEFAULT_VOICE
    response_format: ResponseFormat = Field(
        DEFAULT_RESPONSE_FORMAT,
        description=f"The format to audio in. Supported formats are {", ".join(SUPPORTED_RESPONSE_FORMATS)}. {", ".join(UNSUPORTED_RESPONSE_FORMATS)} are not supported",  # noqa: E501
        examples=list(SUPPORTED_RESPONSE_FORMATS),
    )
    # https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-voice
    speed: float = Field(1.0, ge=0.25, le=4.0)
    """The speed of the generated audio. Select a value from 0.25 to 4.0. 1.0 is the default."""
    sample_rate: int | None = Field(None, ge=MIN_SAMPLE_RATE, le=MAX_SAMPLE_RATE)  # TODO: document

    # TODO: move into `Voice`
    @model_validator(mode="after")
    def verify_voice_is_valid(self) -> Self:
        valid_voices = read_piper_voices_config()
        if self.voice not in valid_voices:
            raise ValidationError(f"Voice '{self.voice}' is not supported. Supported voices: {valid_voices.keys()}")
        return self


# https://platform.openai.com/docs/api-reference/audio/createSpeech
@router.post("/v1/audio/speech")
def synthesize(
    piper_model_manager: PiperModelManagerDependency,
    body: CreateSpeechRequestBody,
) -> StreamingResponse:
    with piper_model_manager.load_model(body.voice) as piper_tts:
        audio_generator = generate_audio(piper_tts, body.input, speed=body.speed, sample_rate=body.sample_rate)
        if body.response_format != "pcm":
            audio_generator = (
                convert_audio_format(
                    audio_bytes, body.sample_rate or piper_tts.config.sample_rate, body.response_format
                )
                for audio_bytes in audio_generator
            )

        return StreamingResponse(audio_generator, media_type=f"audio/{body.response_format}")
```
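Since the route mirrors OpenAI's `POST /v1/audio/speech`, the official client can drive it, as the tests below do. A sketch against a locally running server (base URL, API key, and output path are assumptions):

```python
# Illustrative only, not part of the commit.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="does-not-matter")
res = client.audio.speech.create(
    model="piper",
    voice="en_US-amy-medium",  # type: ignore  # not an OpenAI voice name
    input="Hello, world!",
    response_format="wav",
    speed=1.0,
    extra_body={"sample_rate": 16000},  # server-specific extension, see CreateSpeechRequestBody
)
with open("speech.wav", "wb") as f:
    f.write(res.read())
```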
src/faster_whisper_server/text_utils.py  (CHANGED)

```diff
@@ -3,8 +3,6 @@ from __future__ import annotations
 import re
 from typing import TYPE_CHECKING
 
-from faster_whisper_server.dependencies import get_config
-
 if TYPE_CHECKING:
     from collections.abc import Iterable
 
@@ -40,6 +38,8 @@ class Transcription:
         self.words.extend(words)
 
     def _ensure_no_word_overlap(self, words: list[TranscriptionWord]) -> None:
+        from faster_whisper_server.dependencies import get_config  # HACK: avoid circular import
+
        config = get_config()  # HACK
         if len(self.words) > 0 and len(words) > 0:
             if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
```
tests/conftest.py  (CHANGED)

```diff
@@ -4,6 +4,7 @@ import os
 
 from fastapi.testclient import TestClient
 from httpx import ASGITransport, AsyncClient
+from huggingface_hub import snapshot_download
 from openai import AsyncOpenAI
 import pytest
 import pytest_asyncio
@@ -44,3 +45,10 @@ def actual_openai_client() -> AsyncOpenAI:
     return AsyncOpenAI(
         base_url="https://api.openai.com/v1"
     )  # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value
+
+
+# TODO: remove the download after running the tests
+@pytest.fixture(scope="session", autouse=True)
+def download_piper_voices() -> None:
+    # Only download `voices.json` and the default voice
+    snapshot_download("rhasspy/piper-voices", allow_patterns=["voices.json", "en/en_US/amy/**"])
```
tests/speech_test.py  (ADDED, 158 lines)

```python
import io
import platform

from openai import APIConnectionError, AsyncOpenAI, UnprocessableEntityError
import pytest
import soundfile as sf

from faster_whisper_server.routers.speech import (
    DEFAULT_MODEL,
    DEFAULT_RESPONSE_FORMAT,
    DEFAULT_VOICE,
    SUPPORTED_RESPONSE_FORMATS,
    ResponseFormat,
)

DEFAULT_INPUT = "Hello, world!"

platform_machine = platform.machine()


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize("response_format", SUPPORTED_RESPONSE_FORMATS)
async def test_create_speech_formats(openai_client: AsyncOpenAI, response_format: ResponseFormat) -> None:
    await openai_client.audio.speech.create(
        model=DEFAULT_MODEL,
        voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
        input=DEFAULT_INPUT,
        response_format=response_format,
    )


GOOD_MODEL_VOICE_PAIRS: list[tuple[str, str]] = [
    ("tts-1", "alloy"),  # OpenAI and OpenAI
    ("tts-1-hd", "echo"),  # OpenAI and OpenAI
    ("tts-1", DEFAULT_VOICE),  # OpenAI and Piper
    (DEFAULT_MODEL, "echo"),  # Piper and OpenAI
    (DEFAULT_MODEL, DEFAULT_VOICE),  # Piper and Piper
]


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize(("model", "voice"), GOOD_MODEL_VOICE_PAIRS)
async def test_create_speech_good_model_voice_pair(openai_client: AsyncOpenAI, model: str, voice: str) -> None:
    await openai_client.audio.speech.create(
        model=model,
        voice=voice,  # type: ignore  # noqa: PGH003
        input=DEFAULT_INPUT,
        response_format=DEFAULT_RESPONSE_FORMAT,
    )


BAD_MODEL_VOICE_PAIRS: list[tuple[str, str]] = [
    ("tts-1", "invalid"),  # OpenAI and invalid
    ("invalid", "echo"),  # Invalid and OpenAI
    (DEFAULT_MODEL, "invalid"),  # Piper and invalid
    ("invalid", DEFAULT_VOICE),  # Invalid and Piper
    ("invalid", "invalid"),  # Invalid and invalid
]


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize(("model", "voice"), BAD_MODEL_VOICE_PAIRS)
async def test_create_speech_bad_model_voice_pair(openai_client: AsyncOpenAI, model: str, voice: str) -> None:
    # NOTE: not sure why `APIConnectionError` is sometimes raised
    with pytest.raises((UnprocessableEntityError, APIConnectionError)):
        await openai_client.audio.speech.create(
            model=model,
            voice=voice,  # type: ignore  # noqa: PGH003
            input=DEFAULT_INPUT,
            response_format=DEFAULT_RESPONSE_FORMAT,
        )


SUPPORTED_SPEEDS = [0.25, 0.5, 1.0, 2.0, 4.0]


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
async def test_create_speech_with_varying_speed(openai_client: AsyncOpenAI) -> None:
    previous_size: int | None = None
    for speed in SUPPORTED_SPEEDS:
        res = await openai_client.audio.speech.create(
            model=DEFAULT_MODEL,
            voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
            input=DEFAULT_INPUT,
            response_format="pcm",
            speed=speed,
        )
        audio_bytes = res.read()
        if previous_size is not None:
            assert len(audio_bytes) * 1.5 < previous_size  # TODO: document magic number
        previous_size = len(audio_bytes)


UNSUPPORTED_SPEEDS = [0.1, 4.1]


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize("speed", UNSUPPORTED_SPEEDS)
async def test_create_speech_with_unsupported_speed(openai_client: AsyncOpenAI, speed: float) -> None:
    with pytest.raises(UnprocessableEntityError):
        await openai_client.audio.speech.create(
            model=DEFAULT_MODEL,
            voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
            input=DEFAULT_INPUT,
            response_format="pcm",
            speed=speed,
        )


VALID_SAMPLE_RATES = [16000, 22050, 24000, 48000]


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize("sample_rate", VALID_SAMPLE_RATES)
async def test_speech_valid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None:
    res = await openai_client.audio.speech.create(
        model=DEFAULT_MODEL,
        voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
        input=DEFAULT_INPUT,
        response_format="wav",
        extra_body={"sample_rate": sample_rate},
    )
    _, actual_sample_rate = sf.read(io.BytesIO(res.content))
    assert actual_sample_rate == sample_rate


INVALID_SAMPLE_RATES = [7999, 48001]


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize("sample_rate", INVALID_SAMPLE_RATES)
async def test_speech_invalid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None:
    with pytest.raises(UnprocessableEntityError):
        await openai_client.audio.speech.create(
            model=DEFAULT_MODEL,
            voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
            input=DEFAULT_INPUT,
            response_format="wav",
            extra_body={"sample_rate": sample_rate},
        )


# TODO: implement the following test

# NUMBER_OF_MODELS = 1
# NUMBER_OF_VOICES = 124
#
#
# @pytest.mark.asyncio
# async def test_list_tts_models(openai_client: AsyncOpenAI) -> None:
#     raise NotImplementedError
```