Fedir Zadniprovskyi committed
Commit 7cc3853
1 Parent(s): 168a38a
src/faster_whisper_server/dependencies.py CHANGED
@@ -4,7 +4,7 @@ from typing import Annotated
 from fastapi import Depends
 
 from faster_whisper_server.config import Config
-from faster_whisper_server.model_manager import ModelManager
+from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager
 
 
 @lru_cache
@@ -16,9 +16,18 @@ ConfigDependency = Annotated[Config, Depends(get_config)]
 
 
 @lru_cache
-def get_model_manager() -> ModelManager:
+def get_model_manager() -> WhisperModelManager:
     config = get_config()  # HACK
-    return ModelManager(config.whisper)
+    return WhisperModelManager(config.whisper)
 
 
-ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
+ModelManagerDependency = Annotated[WhisperModelManager, Depends(get_model_manager)]
+
+
+@lru_cache
+def get_piper_model_manager() -> PiperModelManager:
+    config = get_config()  # HACK
+    return PiperModelManager(config.whisper.ttl)  # HACK
+
+
+PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)]
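Usage sketch (not part of this commit): with `PiperModelManagerDependency` in place, a route can resolve the `lru_cache`'d `PiperModelManager` singleton through FastAPI's `Annotated` dependency injection, just like existing routes resolve `ModelManagerDependency`. The endpoint below is hypothetical and only illustrates the wiring:

    from fastapi import APIRouter

    from faster_whisper_server.dependencies import PiperModelManagerDependency

    router = APIRouter()


    @router.get("/v1/audio/voices/loaded")  # hypothetical endpoint, for illustration only
    def list_loaded_voices(piper_model_manager: PiperModelManagerDependency) -> list[str]:
        # the same manager instance is shared across requests thanks to @lru_cache
        return list(piper_model_manager.loaded_models.keys())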
src/faster_whisper_server/hf_utils.py CHANGED
@@ -1,9 +1,16 @@
 from collections.abc import Generator
+from functools import lru_cache
+import json
 import logging
 from pathlib import Path
 import typing
+from typing import Any, Literal
 
 import huggingface_hub
+from huggingface_hub.constants import HF_HUB_CACHE
+from pydantic import BaseModel
+
+from faster_whisper_server.api_models import Model
 
 logger = logging.getLogger(__name__)
 
@@ -12,10 +19,36 @@ TASK_NAME = "automatic-speech-recognition"
 
 
 def does_local_model_exist(model_id: str) -> bool:
-    return any(model_id == model.repo_id for model, _ in list_local_models())
+    return any(model_id == model.repo_id for model, _ in list_local_whisper_models())
+
+
+def list_whisper_models() -> Generator[Model, None, None]:
+    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
+    models = list(models)
+    models.sort(key=lambda model: model.downloads or -1, reverse=True)
+    for model in models:
+        assert model.created_at is not None
+        assert model.card_data is not None
+        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
+        if model.card_data.language is None:
+            language = []
+        elif isinstance(model.card_data.language, str):
+            language = [model.card_data.language]
+        else:
+            language = model.card_data.language
+        transformed_model = Model(
+            id=model.id,
+            created=int(model.created_at.timestamp()),
+            object_="model",
+            owned_by=model.id.split("/")[0],
+            language=language,
+        )
+        yield transformed_model
 
 
-def list_local_models() -> Generator[tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None]:
+def list_local_whisper_models() -> (
+    Generator[tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None]
+):
     hf_cache = huggingface_hub.scan_cache_dir()
     hf_models = [repo for repo in list(hf_cache.repos) if repo.repo_type == "model"]
     for model in hf_models:
@@ -36,3 +69,129 @@ def list_local_models() -> Generator[tuple[huggingface_hub.CachedRepoInfo, huggi
             and TASK_NAME in model_card_data.tags
         ):
             yield model, model_card_data
+
+
+def get_whisper_models() -> Generator[Model, None, None]:
+    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
+    models = list(models)
+    models.sort(key=lambda model: model.downloads or -1, reverse=True)
+    for model in models:
+        assert model.created_at is not None
+        assert model.card_data is not None
+        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
+        if model.card_data.language is None:
+            language = []
+        elif isinstance(model.card_data.language, str):
+            language = [model.card_data.language]
+        else:
+            language = model.card_data.language
+        transformed_model = Model(
+            id=model.id,
+            created=int(model.created_at.timestamp()),
+            object_="model",
+            owned_by=model.id.split("/")[0],
+            language=language,
+        )
+        yield transformed_model
+
+
+class PiperModel(BaseModel):
+    id: str
+    object: Literal["model"] = "model"
+    created: int
+    owned_by: Literal["rhasspy"] = "rhasspy"
+    path: Path
+    config_path: Path
+
+
+def get_model_path(model_id: str, *, cache_dir: str | Path | None = None) -> Path | None:
+    if cache_dir is None:
+        cache_dir = HF_HUB_CACHE
+
+    cache_dir = Path(cache_dir).expanduser().resolve()
+    if not cache_dir.exists():
+        raise huggingface_hub.CacheNotFound(
+            f"Cache directory not found: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable.",  # noqa: E501
+            cache_dir=cache_dir,
+        )
+
+    if cache_dir.is_file():
+        raise ValueError(
+            f"Scan cache expects a directory but found a file: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable."  # noqa: E501
+        )
+
+    for repo_path in cache_dir.iterdir():
+        if not repo_path.is_dir():
+            continue
+        if repo_path.name == ".locks":  # skip './.locks/' folder
+            continue
+        repo_type, repo_id = repo_path.name.split("--", maxsplit=1)
+        repo_type = repo_type[:-1]  # "models" -> "model"
+        repo_id = repo_id.replace("--", "/")  # google--fleurs -> "google/fleurs"
+        if repo_type != "model":
+            continue
+        if model_id == repo_id:
+            return repo_path
+
+    return None
+
+
+def list_model_files(
+    model_id: str, glob_pattern: str = "**/*", *, cache_dir: str | Path | None = None
+) -> Generator[Path, None, None]:
+    repo_path = get_model_path(model_id, cache_dir=cache_dir)
+    if repo_path is None:
+        return None
+    snapshots_path = repo_path / "snapshots"
+    if not snapshots_path.exists():
+        return None
+    yield from list(snapshots_path.glob(glob_pattern))
+
+
+def list_piper_models() -> Generator[PiperModel, None, None]:
+    model_weights_files = list_model_files("rhasspy/piper-voices", glob_pattern="**/*.onnx")
+    for model_weights_file in model_weights_files:
+        model_config_file = model_weights_file.with_suffix(".json")
+        yield PiperModel(
+            id=model_weights_file.name,
+            created=int(model_weights_file.stat().st_mtime),
+            path=model_weights_file,
+            config_path=model_config_file,
+        )
+
+
+# NOTE: It's debatable whether caching should be done here or by the caller. Should be revisited.
+
+
+@lru_cache
+def read_piper_voices_config() -> dict[str, Any]:
+    voices_file = next(list_model_files("rhasspy/piper-voices", glob_pattern="**/voices.json"), None)
+    if voices_file is None:
+        raise FileNotFoundError("Could not find voices.json file")  # noqa: EM101
+    return json.loads(voices_file.read_text())
+
+
+@lru_cache
+def get_piper_voice_model_file(voice: str) -> Path:
+    model_file = next(list_model_files("rhasspy/piper-voices", glob_pattern=f"**/{voice}.onnx"), None)
+    if model_file is None:
+        raise FileNotFoundError(f"Could not find model file for '{voice}' voice")
+    return model_file
+
+
+class PiperVoiceConfigAudio(BaseModel):
+    sample_rate: int
+    quality: int
+
+
+class PiperVoiceConfig(BaseModel):
+    audio: PiperVoiceConfigAudio
+    # NOTE: there are more fields in the config, but we don't care about them
+
+
+@lru_cache
+def read_piper_voice_config(voice: str) -> PiperVoiceConfig:
+    model_config_file = next(list_model_files("rhasspy/piper-voices", glob_pattern=f"**/{voice}.onnx.json"), None)
+    if model_config_file is None:
+        raise FileNotFoundError(f"Could not find config file for '{voice}' voice")
    return PiperVoiceConfig.model_validate_json(model_config_file.read_text())
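Usage sketch (not part of this commit): the new helpers resolve Piper voices straight out of the local Hugging Face cache, so they only work after `rhasspy/piper-voices` has been (at least partially) downloaded, e.g. via `huggingface_hub.snapshot_download` as the fixture in tests/conftest.py does:

    from faster_whisper_server.hf_utils import list_piper_models, read_piper_voice_config

    # enumerate every cached *.onnx voice under the `rhasspy/piper-voices` snapshot
    for piper_model in list_piper_models():
        print(piper_model.id, piper_model.path)

    # parse the `<voice>.onnx.json` config of a single voice
    voice_config = read_piper_voice_config("en_US-amy-medium")
    print(voice_config.audio.sample_rate)  # 22050 for this voice, per DEFAULT_VOICE_SAMPLE_RATE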
src/faster_whisper_server/main.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 from contextlib import asynccontextmanager
 import logging
+import platform
 from typing import TYPE_CHECKING
 
 from fastapi import (
@@ -30,6 +31,14 @@ def create_app() -> FastAPI:
 
     logger = logging.getLogger(__name__)
 
+    if platform.machine() == "x86_64":
+        from faster_whisper_server.routers.speech import (
+            router as speech_router,
+        )
+    else:
+        logger.warning("`/v1/audio/speech` is only supported on x86_64 machines")
+        speech_router = None
+
     config = get_config()  # HACK
     logger.debug(f"Config: {config}")
 
@@ -46,6 +55,8 @@ def create_app() -> FastAPI:
     app.include_router(stt_router)
     app.include_router(list_models_router)
     app.include_router(misc_router)
+    if speech_router is not None:
+        app.include_router(speech_router)
 
     if config.allow_origins is not None:
         app.add_middleware(
src/faster_whisper_server/model_manager.py CHANGED
@@ -9,63 +9,58 @@ from typing import TYPE_CHECKING
 
 from faster_whisper import WhisperModel
 
+from faster_whisper_server.hf_utils import get_piper_voice_model_file
+
 if TYPE_CHECKING:
     from collections.abc import Callable
 
+    from piper.voice import PiperVoice
+
     from faster_whisper_server.config import (
         WhisperConfig,
     )
 
 logger = logging.getLogger(__name__)
 
+
 # TODO: enable concurrent model downloads
 
 
-class SelfDisposingWhisperModel:
+class SelfDisposingModel[T]:
     def __init__(
-        self,
-        model_id: str,
-        whisper_config: WhisperConfig,
-        *,
-        on_unload: Callable[[str], None] | None = None,
+        self, model_id: str, load_fn: Callable[[], T], ttl: int, unload_fn: Callable[[str], None] | None = None
     ) -> None:
         self.model_id = model_id
-        self.whisper_config = whisper_config
-        self.on_unload = on_unload
+        self.load_fn = load_fn
+        self.ttl = ttl
+        self.unload_fn = unload_fn
 
         self.ref_count: int = 0
         self.rlock = threading.RLock()
         self.expire_timer: threading.Timer | None = None
-        self.whisper: WhisperModel | None = None
+        self.model: T | None = None
 
     def unload(self) -> None:
         with self.rlock:
-            if self.whisper is None:
+            if self.model is None:
                 raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
             if self.ref_count > 0:
                 raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
             if self.expire_timer:
                 self.expire_timer.cancel()
-            self.whisper = None
+            self.model = None
             # WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
             gc.collect()
             logger.info(f"Model {self.model_id} unloaded")
-            if self.on_unload is not None:
-                self.on_unload(self.model_id)
+            if self.unload_fn is not None:
+                self.unload_fn(self.model_id)
 
     def _load(self) -> None:
         with self.rlock:
-            assert self.whisper is None
+            assert self.model is None
             logger.debug(f"Loading model {self.model_id}")
             start = time.perf_counter()
-            self.whisper = WhisperModel(
-                self.model_id,
-                device=self.whisper_config.inference_device,
-                device_index=self.whisper_config.device_index,
-                compute_type=self.whisper_config.compute_type,
-                cpu_threads=self.whisper_config.cpu_threads,
-                num_workers=self.whisper_config.num_workers,
-            )
+            self.model = self.load_fn()
             logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")
 
     def _increment_ref(self) -> None:
@@ -81,34 +76,84 @@ class SelfDisposingWhisperModel:
             self.ref_count -= 1
             logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
             if self.ref_count <= 0:
-                if self.whisper_config.ttl > 0:
-                    logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s")
-                    self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload)
+                if self.ttl > 0:
+                    logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.ttl}s")
+                    self.expire_timer = threading.Timer(self.ttl, self.unload)
                     self.expire_timer.start()
-                elif self.whisper_config.ttl == 0:
+                elif self.ttl == 0:
                     logger.info(f"Model {self.model_id} is idle, unloading immediately")
                     self.unload()
                 else:
                     logger.info(f"Model {self.model_id} is idle, not unloading")
 
-    def __enter__(self) -> WhisperModel:
+    def __enter__(self) -> T:
         with self.rlock:
-            if self.whisper is None:
+            if self.model is None:
                 self._load()
             self._increment_ref()
-            assert self.whisper is not None
-            return self.whisper
+            assert self.model is not None
+            return self.model
 
     def __exit__(self, *_args) -> None:  # noqa: ANN002
         self._decrement_ref()
 
 
-class ModelManager:
+class WhisperModelManager:
     def __init__(self, whisper_config: WhisperConfig) -> None:
         self.whisper_config = whisper_config
-        self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
+        self.loaded_models: OrderedDict[str, SelfDisposingModel[WhisperModel]] = OrderedDict()
+        self._lock = threading.Lock()
+
+    def _load_fn(self, model_id: str) -> WhisperModel:
+        return WhisperModel(
+            model_id,
+            device=self.whisper_config.inference_device,
+            device_index=self.whisper_config.device_index,
+            compute_type=self.whisper_config.compute_type,
+            cpu_threads=self.whisper_config.cpu_threads,
+            num_workers=self.whisper_config.num_workers,
+        )
+
+    def _handle_model_unload(self, model_name: str) -> None:
+        with self._lock:
+            if model_name in self.loaded_models:
+                del self.loaded_models[model_name]
+
+    def unload_model(self, model_name: str) -> None:
+        with self._lock:
+            model = self.loaded_models.get(model_name)
+            if model is None:
+                raise KeyError(f"Model {model_name} not found")
+            self.loaded_models[model_name].unload()
+
+    def load_model(self, model_name: str) -> SelfDisposingModel[WhisperModel]:
+        logger.debug(f"Loading model {model_name}")
+        with self._lock:
+            logger.debug("Acquired lock")
+            if model_name in self.loaded_models:
+                logger.debug(f"{model_name} model already loaded")
+                return self.loaded_models[model_name]
+            self.loaded_models[model_name] = SelfDisposingModel[WhisperModel](
+                model_name,
+                load_fn=lambda: self._load_fn(model_name),
+                ttl=self.whisper_config.ttl,
+                unload_fn=self._handle_model_unload,
+            )
+            return self.loaded_models[model_name]
+
+
+class PiperModelManager:
+    def __init__(self, ttl: int) -> None:
+        self.ttl = ttl
+        self.loaded_models: OrderedDict[str, SelfDisposingModel[PiperVoice]] = OrderedDict()
         self._lock = threading.Lock()
 
+    def _load_fn(self, model_id: str) -> PiperVoice:
+        from piper.voice import PiperVoice
+
+        model_path = get_piper_voice_model_file(model_id)
+        return PiperVoice.load(model_path)
+
     def _handle_model_unload(self, model_name: str) -> None:
         with self._lock:
             if model_name in self.loaded_models:
@@ -121,14 +166,17 @@ class ModelManager:
                 raise KeyError(f"Model {model_name} not found")
             self.loaded_models[model_name].unload()
 
-    def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
+    def load_model(self, model_name: str) -> SelfDisposingModel[PiperVoice]:
+        from piper.voice import PiperVoice
+
         with self._lock:
             if model_name in self.loaded_models:
                 logger.debug(f"{model_name} model already loaded")
                 return self.loaded_models[model_name]
-            self.loaded_models[model_name] = SelfDisposingWhisperModel(
+            self.loaded_models[model_name] = SelfDisposingModel[PiperVoice](
                 model_name,
-                self.whisper_config,
-                on_unload=self._handle_model_unload,
+                load_fn=lambda: self._load_fn(model_name),
+                ttl=self.ttl,
+                unload_fn=self._handle_model_unload,
             )
             return self.loaded_models[model_name]
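Usage sketch (not part of this commit): `SelfDisposingModel` is borrowed as a context manager; the model is loaded lazily on the first `__enter__`, and the TTL countdown only starts once the last borrower has exited. A minimal illustration, assuming `config` is a loaded faster_whisper_server `Config`:

    manager = WhisperModelManager(config.whisper)

    with manager.load_model("Systran/faster-distil-whisper-large-v3") as whisper:
        # ref_count is now 1, so no unload timer can fire mid-transcription
        segments, info = whisper.transcribe("audio.wav")

    # ref_count is back to 0: ttl > 0 schedules an offload timer, ttl == 0
    # unloads immediately, and a negative ttl keeps the model loaded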
src/faster_whisper_server/routers/list_models.py CHANGED
@@ -13,6 +13,7 @@ from faster_whisper_server.api_models import (
     ListModelsResponse,
     Model,
 )
+from faster_whisper_server.hf_utils import list_whisper_models
 
 if TYPE_CHECKING:
     from huggingface_hub.hf_api import ModelInfo
@@ -22,34 +23,13 @@ router = APIRouter()
 
 @router.get("/v1/models")
 def get_models() -> ListModelsResponse:
-    models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
-    models = list(models)
-    models.sort(key=lambda model: model.downloads or -1, reverse=True)
-    transformed_models: list[Model] = []
-    for model in models:
-        assert model.created_at is not None
-        assert model.card_data is not None
-        assert model.card_data.language is None or isinstance(model.card_data.language, str | list)
-        if model.card_data.language is None:
-            language = []
-        elif isinstance(model.card_data.language, str):
-            language = [model.card_data.language]
-        else:
-            language = model.card_data.language
-        transformed_model = Model(
-            id=model.id,
-            created=int(model.created_at.timestamp()),
-            object_="model",
-            owned_by=model.id.split("/")[0],
-            language=language,
-        )
-        transformed_models.append(transformed_model)
-    return ListModelsResponse(data=transformed_models)
+    whisper_models = list(list_whisper_models())
+    return ListModelsResponse(data=whisper_models)
 
 
 @router.get("/v1/models/{model_name:path}")
-# NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
 def get_model(
+    # NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
     model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
 ) -> Model:
     models = huggingface_hub.list_models(
src/faster_whisper_server/routers/speech.py ADDED
@@ -0,0 +1,164 @@
+from collections.abc import Generator
+import io
+import logging
+import time
+from typing import Annotated, Literal, Self
+
+from fastapi import APIRouter
+from fastapi.responses import StreamingResponse
+import numpy as np
+from piper.voice import PiperVoice
+from pydantic import BaseModel, BeforeValidator, Field, ValidationError, model_validator
+import soundfile as sf
+
+from faster_whisper_server.dependencies import PiperModelManagerDependency
+from faster_whisper_server.hf_utils import read_piper_voices_config
+
+DEFAULT_MODEL = "piper"
+# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-response_format
+DEFAULT_RESPONSE_FORMAT = "mp3"
+DEFAULT_VOICE = "en_US-amy-medium"  # TODO: make configurable
+DEFAULT_VOICE_SAMPLE_RATE = 22050  # NOTE: Dependant on the voice
+
+# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-model
+# https://platform.openai.com/docs/models/tts
+OPENAI_SUPPORTED_SPEECH_MODEL = ("tts-1", "tts-1-hd")
+
+# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-voice
+# https://platform.openai.com/docs/guides/text-to-speech/voice-options
+OPENAI_SUPPORTED_SPEECH_VOICE_NAMES = ("alloy", "echo", "fable", "onyx", "nova", "shimmer")
+
+# https://platform.openai.com/docs/guides/text-to-speech/supported-output-formats
+type ResponseFormat = Literal["mp3", "flac", "wav", "pcm"]
+SUPPORTED_RESPONSE_FORMATS = ("mp3", "flac", "wav", "pcm")
+UNSUPORTED_RESPONSE_FORMATS = ("opus", "aac")
+
+MIN_SAMPLE_RATE = 8000
+MAX_SAMPLE_RATE = 48000
+
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+
+# aip 'Write a function `resample_audio` which would take in RAW PCM 16-bit signed, little-endian audio data represented as bytes (`audio_bytes`) and resample it (either downsample or upsample) from `sample_rate` to `target_sample_rate` using numpy'  # noqa: E501
+def resample_audio(audio_bytes: bytes, sample_rate: int, target_sample_rate: int) -> bytes:
+    audio_data = np.frombuffer(audio_bytes, dtype=np.int16)
+    duration = len(audio_data) / sample_rate
+    target_length = int(duration * target_sample_rate)
+    resampled_data = np.interp(
+        np.linspace(0, len(audio_data), target_length, endpoint=False), np.arange(len(audio_data)), audio_data
+    )
+    return resampled_data.astype(np.int16).tobytes()
+
+
+def generate_audio(
+    piper_tts: PiperVoice, text: str, *, speed: float = 1.0, sample_rate: int | None = None
+) -> Generator[bytes, None, None]:
+    if sample_rate is None:
+        sample_rate = piper_tts.config.sample_rate
+    start = time.perf_counter()
+    for audio_bytes in piper_tts.synthesize_stream_raw(text, length_scale=1.0 / speed):
+        if sample_rate != piper_tts.config.sample_rate:
+            audio_bytes = resample_audio(audio_bytes, piper_tts.config.sample_rate, sample_rate)  # noqa: PLW2901
+        yield audio_bytes
+    logger.info(f"Generated audio for {len(text)} characters in {time.perf_counter() - start}s")
+
+
+def convert_audio_format(
+    audio_bytes: bytes,
+    sample_rate: int,
+    audio_format: ResponseFormat,
+    format: str = "RAW",  # noqa: A002
+    channels: int = 1,
+    subtype: str = "PCM_16",
+    endian: str = "LITTLE",
+) -> bytes:
+    # NOTE: the default dtype is float64. Should something else be used? Would that improve performance?
+    data, _ = sf.read(
+        io.BytesIO(audio_bytes),
+        samplerate=sample_rate,
+        format=format,
+        channels=channels,
+        subtype=subtype,
+        endian=endian,
+    )
+    converted_audio_bytes_buffer = io.BytesIO()
+    sf.write(converted_audio_bytes_buffer, data, samplerate=sample_rate, format=audio_format)
+    return converted_audio_bytes_buffer.getvalue()
+
+
+def handle_openai_supported_model_ids(model_id: str) -> str:
+    if model_id in OPENAI_SUPPORTED_SPEECH_MODEL:
+        logger.warning(f"{model_id} is not a valid model name. Using '{DEFAULT_MODEL}' instead.")
+        return DEFAULT_MODEL
+    return model_id
+
+
+ModelId = Annotated[
+    Literal["piper"],
+    BeforeValidator(handle_openai_supported_model_ids),
+    Field(
+        description=f"The ID of the model. The only supported model is '{DEFAULT_MODEL}'.",
+        examples=[DEFAULT_MODEL],
+    ),
+]
+
+
+def handle_openai_supported_voices(voice: str) -> str:
+    if voice in OPENAI_SUPPORTED_SPEECH_VOICE_NAMES:
+        logger.warning(f"{voice} is not a valid voice name. Using '{DEFAULT_VOICE}' instead.")
+        return DEFAULT_VOICE
+    return voice
+
+
+Voice = Annotated[str, BeforeValidator(handle_openai_supported_voices)]  # TODO: description and examples
+
+
+class CreateSpeechRequestBody(BaseModel):
+    model: ModelId = DEFAULT_MODEL
+    input: str = Field(
+        ...,
+        description="The text to generate audio for. ",
+        examples=[
+            "A rainbow is an optical phenomenon caused by refraction, internal reflection and dispersion of light in water droplets resulting in a continuous spectrum of light appearing in the sky. The rainbow takes the form of a multicoloured circular arc. Rainbows caused by sunlight always appear in the section of sky directly opposite the Sun. Rainbows can be caused by many forms of airborne water. These include not only rain, but also mist, spray, and airborne dew."  # noqa: E501
+        ],
+    )
+    voice: Voice = DEFAULT_VOICE
+    response_format: ResponseFormat = Field(
+        DEFAULT_RESPONSE_FORMAT,
+        description=f"The format to audio in. Supported formats are {", ".join(SUPPORTED_RESPONSE_FORMATS)}. {", ".join(UNSUPORTED_RESPONSE_FORMATS)} are not supported",  # noqa: E501
+        examples=list(SUPPORTED_RESPONSE_FORMATS),
+    )
+    # https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-voice
+    speed: float = Field(1.0, ge=0.25, le=4.0)
+    """The speed of the generated audio. Select a value from 0.25 to 4.0. 1.0 is the default."""
+    sample_rate: int | None = Field(None, ge=MIN_SAMPLE_RATE, le=MAX_SAMPLE_RATE)  # TODO: document
+
+    # TODO: move into `Voice`
+    @model_validator(mode="after")
+    def verify_voice_is_valid(self) -> Self:
+        valid_voices = read_piper_voices_config()
+        if self.voice not in valid_voices:
+            raise ValidationError(f"Voice '{self.voice}' is not supported. Supported voices: {valid_voices.keys()}")
+        return self
+
+
+# https://platform.openai.com/docs/api-reference/audio/createSpeech
+@router.post("/v1/audio/speech")
+def synthesize(
+    piper_model_manager: PiperModelManagerDependency,
+    body: CreateSpeechRequestBody,
+) -> StreamingResponse:
+    with piper_model_manager.load_model(body.voice) as piper_tts:
+        audio_generator = generate_audio(piper_tts, body.input, speed=body.speed, sample_rate=body.sample_rate)
+        if body.response_format != "pcm":
+            audio_generator = (
+                convert_audio_format(
+                    audio_bytes, body.sample_rate or piper_tts.config.sample_rate, body.response_format
+                )
+                for audio_bytes in audio_generator
+            )
+
+        return StreamingResponse(audio_generator, media_type=f"audio/{body.response_format}")
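Usage sketch (not part of this commit): the route mirrors OpenAI's `/v1/audio/speech`, so the `openai` client already used by the test suite can drive it. The base URL and port below are assumptions about a local deployment:

    from pathlib import Path

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="cant-be-empty")
    res = client.audio.speech.create(
        model="piper",
        voice="en_US-amy-medium",
        input="A rainbow is an optical phenomenon.",
        response_format="wav",
        extra_body={"sample_rate": 16000},  # server-specific extension, see CreateSpeechRequestBody
    )
    Path("output.wav").write_bytes(res.read())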
src/faster_whisper_server/text_utils.py CHANGED
@@ -3,8 +3,6 @@ from __future__ import annotations
 import re
 from typing import TYPE_CHECKING
 
-from faster_whisper_server.dependencies import get_config
-
 if TYPE_CHECKING:
     from collections.abc import Iterable
 
@@ -40,6 +38,8 @@ class Transcription:
         self.words.extend(words)
 
     def _ensure_no_word_overlap(self, words: list[TranscriptionWord]) -> None:
+        from faster_whisper_server.dependencies import get_config  # HACK: avoid circular import
+
         config = get_config()  # HACK
         if len(self.words) > 0 and len(words) > 0:
             if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
tests/conftest.py CHANGED
@@ -4,6 +4,7 @@ import os
 
 from fastapi.testclient import TestClient
 from httpx import ASGITransport, AsyncClient
+from huggingface_hub import snapshot_download
 from openai import AsyncOpenAI
 import pytest
 import pytest_asyncio
@@ -44,3 +45,10 @@ def actual_openai_client() -> AsyncOpenAI:
     return AsyncOpenAI(
         base_url="https://api.openai.com/v1"
     )  # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value
+
+
+# TODO: remove the download after running the tests
+@pytest.fixture(scope="session", autouse=True)
+def download_piper_voices() -> None:
+    # Only download `voices.json` and the default voice
+    snapshot_download("rhasspy/piper-voices", allow_patterns=["voices.json", "en/en_US/amy/**"])
tests/speech_test.py ADDED
@@ -0,0 +1,158 @@
+import io
+import platform
+
+from openai import APIConnectionError, AsyncOpenAI, UnprocessableEntityError
+import pytest
+import soundfile as sf
+
+from faster_whisper_server.routers.speech import (
+    DEFAULT_MODEL,
+    DEFAULT_RESPONSE_FORMAT,
+    DEFAULT_VOICE,
+    SUPPORTED_RESPONSE_FORMATS,
+    ResponseFormat,
+)
+
+DEFAULT_INPUT = "Hello, world!"
+
+platform_machine = platform.machine()
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+@pytest.mark.parametrize("response_format", SUPPORTED_RESPONSE_FORMATS)
+async def test_create_speech_formats(openai_client: AsyncOpenAI, response_format: ResponseFormat) -> None:
+    await openai_client.audio.speech.create(
+        model=DEFAULT_MODEL,
+        voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
+        input=DEFAULT_INPUT,
+        response_format=response_format,
+    )
+
+
+GOOD_MODEL_VOICE_PAIRS: list[tuple[str, str]] = [
+    ("tts-1", "alloy"),  # OpenAI and OpenAI
+    ("tts-1-hd", "echo"),  # OpenAI and OpenAI
+    ("tts-1", DEFAULT_VOICE),  # OpenAI and Piper
+    (DEFAULT_MODEL, "echo"),  # Piper and OpenAI
+    (DEFAULT_MODEL, DEFAULT_VOICE),  # Piper and Piper
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+@pytest.mark.parametrize(("model", "voice"), GOOD_MODEL_VOICE_PAIRS)
+async def test_create_speech_good_model_voice_pair(openai_client: AsyncOpenAI, model: str, voice: str) -> None:
+    await openai_client.audio.speech.create(
+        model=model,
+        voice=voice,  # type: ignore  # noqa: PGH003
+        input=DEFAULT_INPUT,
+        response_format=DEFAULT_RESPONSE_FORMAT,
+    )
+
+
+BAD_MODEL_VOICE_PAIRS: list[tuple[str, str]] = [
+    ("tts-1", "invalid"),  # OpenAI and invalid
+    ("invalid", "echo"),  # Invalid and OpenAI
+    (DEFAULT_MODEL, "invalid"),  # Piper and invalid
+    ("invalid", DEFAULT_VOICE),  # Invalid and Piper
+    ("invalid", "invalid"),  # Invalid and invalid
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+@pytest.mark.parametrize(("model", "voice"), BAD_MODEL_VOICE_PAIRS)
+async def test_create_speech_bad_model_voice_pair(openai_client: AsyncOpenAI, model: str, voice: str) -> None:
+    # NOTE: not sure why `APIConnectionError` is sometimes raised
+    with pytest.raises((UnprocessableEntityError, APIConnectionError)):
+        await openai_client.audio.speech.create(
+            model=model,
+            voice=voice,  # type: ignore  # noqa: PGH003
+            input=DEFAULT_INPUT,
+            response_format=DEFAULT_RESPONSE_FORMAT,
+        )
+
+
+SUPPORTED_SPEEDS = [0.25, 0.5, 1.0, 2.0, 4.0]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+async def test_create_speech_with_varying_speed(openai_client: AsyncOpenAI) -> None:
+    previous_size: int | None = None
+    for speed in SUPPORTED_SPEEDS:
+        res = await openai_client.audio.speech.create(
+            model=DEFAULT_MODEL,
+            voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
+            input=DEFAULT_INPUT,
+            response_format="pcm",
+            speed=speed,
+        )
+        audio_bytes = res.read()
+        if previous_size is not None:
+            assert len(audio_bytes) * 1.5 < previous_size  # TODO: document magic number
+        previous_size = len(audio_bytes)
+
+
+UNSUPPORTED_SPEEDS = [0.1, 4.1]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+@pytest.mark.parametrize("speed", UNSUPPORTED_SPEEDS)
+async def test_create_speech_with_unsupported_speed(openai_client: AsyncOpenAI, speed: float) -> None:
+    with pytest.raises(UnprocessableEntityError):
+        await openai_client.audio.speech.create(
+            model=DEFAULT_MODEL,
+            voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
+            input=DEFAULT_INPUT,
+            response_format="pcm",
+            speed=speed,
+        )
+
+
+VALID_SAMPLE_RATES = [16000, 22050, 24000, 48000]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+@pytest.mark.parametrize("sample_rate", VALID_SAMPLE_RATES)
+async def test_speech_valid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None:
+    res = await openai_client.audio.speech.create(
+        model=DEFAULT_MODEL,
+        voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
+        input=DEFAULT_INPUT,
+        response_format="wav",
+        extra_body={"sample_rate": sample_rate},
+    )
+    _, actual_sample_rate = sf.read(io.BytesIO(res.content))
+    assert actual_sample_rate == sample_rate
+
+
+INVALID_SAMPLE_RATES = [7999, 48001]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
+@pytest.mark.parametrize("sample_rate", INVALID_SAMPLE_RATES)
+async def test_speech_invalid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None:
+    with pytest.raises(UnprocessableEntityError):
+        await openai_client.audio.speech.create(
+            model=DEFAULT_MODEL,
+            voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
+            input=DEFAULT_INPUT,
+            response_format="wav",
+            extra_body={"sample_rate": sample_rate},
+        )
+
+
+# TODO: implement the following test
+
+# NUMBER_OF_MODELS = 1
+# NUMBER_OF_VOICES = 124
+#
+#
+# @pytest.mark.asyncio
+# async def test_list_tts_models(openai_client: AsyncOpenAI) -> None:
+#     raise NotImplementedError
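Note (not part of this commit): every test above carries the same `platform.machine() != "x86_64"` skip guard that `create_app` uses when registering the speech router. On a supported machine, the suite can be run in isolation with, for example, `pytest tests/speech_test.py`.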