Fedir Zadniprovskyi committed
Commit 313814b · 0 Parent(s)
.dockerignore ADDED
@@ -0,0 +1,12 @@
+ __pycache__
+ tests/data
+ .pytest_cache
+ .git
+ flake.nix
+ flake.lock
+ .envrc
+ .gitignore
+ .direnv
+ .task
+ Taskfile.yaml
+ README.md
.envrc ADDED
@@ -0,0 +1 @@
+ use flake
.gitignore ADDED
@@ -0,0 +1,5 @@
+ __pycache__
+ .pytest_cache
+ tests/data
+ .direnv
+ .task
.pre-commit-config.yaml ADDED
@@ -0,0 +1,25 @@
+ # See https://pre-commit.com for more information
+ # See https://pre-commit.com/hooks.html for more hooks
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v3.2.0
+     hooks:
+       - id: trailing-whitespace
+       - id: end-of-file-fixer
+       - id: check-yaml
+       - id: check-added-large-files
+
+   - repo: https://github.com/pre-commit/mirrors-mypy
+     rev: v1.10.0
+     hooks:
+       - id: mypy
+
+   # - repo: https://github.com/PyCQA/isort
+   #   rev: 5.13.2
+   #   hooks:
+   #     - id: isort
+   #
+   # - repo: https://github.com/psf/black
+   #   rev: 24.4.2
+   #   hooks:
+   #     - id: black
Dockerfile.cpu ADDED
@@ -0,0 +1,14 @@
+ FROM ubuntu:22.04
+ RUN apt-get update && \
+     apt-get install -y curl software-properties-common && \
+     add-apt-repository ppa:deadsnakes/ppa && \
+     apt-get update && \
+     DEBIAN_FRONTEND=noninteractive apt-get -y install python3.11 python3.11-distutils && \
+     curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
+ RUN pip install --no-cache-dir poetry==1.8.2
+ WORKDIR /root/speaches
+ COPY pyproject.toml poetry.lock ./
+ RUN poetry install
+ COPY ./speaches ./speaches
+ ENTRYPOINT ["poetry", "run"]
+ CMD ["uvicorn", "speaches.main:app"]
Dockerfile.cuda ADDED
@@ -0,0 +1,14 @@
+ FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04
+ RUN apt-get update && \
+     apt-get install -y curl software-properties-common && \
+     add-apt-repository ppa:deadsnakes/ppa && \
+     apt-get update && \
+     DEBIAN_FRONTEND=noninteractive apt-get -y install python3.11 python3.11-distutils && \
+     curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
+ RUN pip install --no-cache-dir poetry==1.8.2
+ WORKDIR /root/speaches
+ COPY pyproject.toml poetry.lock ./
+ RUN poetry install
+ COPY ./speaches ./speaches
+ ENTRYPOINT ["poetry", "run"]
+ CMD ["uvicorn", "speaches.main:app"]
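Both Dockerfiles follow the same layering: install Python 3.11 from the deadsnakes PPA, install Poetry, copy the dependency manifests before the source so the `poetry install` layer stays cached, then copy the package. A minimal sketch of building the two images locally (the tags mirror the ones declared in compose.yaml):

```bash
docker build -f Dockerfile.cpu -t fedirz/speaches:cpu .
docker build -f Dockerfile.cuda -t fedirz/speaches:cuda .
```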
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Fedir Zadniprovskyi
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,24 @@
+ # Intro
+ `speaches` is a web server that supports real-time transcription using WebSockets.
+ - [faster-whisper](https://github.com/SYSTRAN/faster-whisper) is used as the backend. Both GPU and CPU inference are supported.
+ - The LocalAgreement2 ([paper](https://aclanthology.org/2023.ijcnlp-demo.3.pdf) | [original implementation](https://github.com/ufal/whisper_streaming)) algorithm is used for real-time transcription.
+ - Can be deployed using Docker (a Compose configuration can be found in [compose.yaml](./compose.yaml)).
+ - All configuration is done through environment variables. See [config.py](./speaches/config.py) and the configuration sketch after this README's diff.
+ - NOTE: only transcription of single-channel, 16000 Hz, raw, 16-bit little-endian audio is supported.
+ - NOTE: this isn't really meant to be used as a standalone tool but rather to add transcription features to other applications.
+ Please create an issue if you find a bug, have a question, or have a feature suggestion.
+ # Quick Start
+ NOTE: You'll need to install [websocat](https://github.com/vi/websocat?tab=readme-ov-file#installation) or an alternative.
+ Spinning up a `speaches` web server:
+ ```bash
+ docker run --detach --gpus=all --publish 8000:8000 --volume ~/.cache/huggingface:/root/.cache/huggingface --name speaches fedirz/speaches:cuda
+ # or
+ docker run --detach --publish 8000:8000 --volume ~/.cache/huggingface:/root/.cache/huggingface --name speaches fedirz/speaches:cpu
+ ```
+ Sending audio data via WebSocket:
+ ```bash
+ arecord -f S16_LE -c1 -r 16000 -t raw -D default | websocat --binary ws://localhost:8000/v1/audio/transcriptions
+ # or
+ ffmpeg -f alsa -i default -ac 1 -ar 16000 -f s16le - | websocat --binary ws://localhost:8000/v1/audio/transcriptions
+ ```
+ # Example
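Since all configuration comes from environment variables and `Config` in config.py uses pydantic-settings with `env_nested_delimiter="_"`, nested fields such as `whisper.model` should map onto flat variable names. A hedged sketch of overriding the model when starting the container (the `WHISPER_MODEL` name is inferred from config.py, not documented in the README):

```bash
docker run --detach --publish 8000:8000 \
  --env WHISPER_MODEL=distil-medium.en \
  --volume ~/.cache/huggingface:/root/.cache/huggingface \
  --name speaches fedirz/speaches:cpu
```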
Taskfile.yaml ADDED
@@ -0,0 +1,15 @@
+ version: "3"
+ tasks:
+   speaches: poetry run uvicorn speaches.main:app {{.CLI_ARGS}}
+   test:
+     cmds:
+       - poetry run pytest -o log_cli=true -o log_cli_level=DEBUG {{.CLI_ARGS}}
+     sources:
+       - "**/*.py"
+   build-and-push:
+     cmds:
+       - docker compose build --push speaches
+     sources:
+       - Dockerfile
+       - speaches/*.py
+   sync: lsyncd -nodaemon -delay 0 -rsyncssh . gpu-box speaches
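`{{.CLI_ARGS}}` in the `speaches` and `test` tasks forwards anything passed after `--` on the task command line. A usage sketch (assumes go-task, which the dev shell in flake.nix provides):

```bash
task speaches -- --host 0.0.0.0 --port 8000
task test -- -k ws_audio
```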
compose.yaml ADDED
@@ -0,0 +1,34 @@
+ services:
+   speaches-cuda:
+     image: fedirz/speaches:cuda
+     build:
+       dockerfile: Dockerfile.cuda
+       context: .
+       tags:
+         - fedirz/speaches:cuda
+     volumes:
+       - ~/.cache/huggingface:/root/.cache/huggingface
+     restart: unless-stopped
+     ports:
+       - 8000:8000
+     environment:
+       - INFERENCE_DEVICE=cuda
+     deploy:
+       resources:
+         reservations:
+           devices:
+             - capabilities: ["gpu"]
+   speaches-cpu:
+     image: fedirz/speaches:cpu
+     build:
+       dockerfile: Dockerfile.cpu
+       context: .
+       tags:
+         - fedirz/speaches:cpu
+     volumes:
+       - ~/.cache/huggingface:/root/.cache/huggingface
+     restart: unless-stopped
+     ports:
+       - 8000:8000
+     environment:
+       - INFERENCE_DEVICE=cpu
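The two services are the same application built from different Dockerfiles, so they are meant to be started individually rather than together (both bind port 8000). A usage sketch:

```bash
docker compose up --detach speaches-cuda
# or, on a machine without an NVIDIA GPU:
docker compose up --detach speaches-cpu
```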
flake.lock ADDED
@@ -0,0 +1,61 @@
+ {
+   "nodes": {
+     "flake-utils": {
+       "inputs": {
+         "systems": "systems"
+       },
+       "locked": {
+         "lastModified": 1710146030,
+         "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
+         "type": "github"
+       },
+       "original": {
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "type": "github"
+       }
+     },
+     "nixpkgs": {
+       "locked": {
+         "lastModified": 1716073433,
+         "narHash": "sha256-9G0BS7I/5z0n35Vx1d+TLxaIKQ93rEf5VLXNLWu7/44=",
+         "owner": "NixOS",
+         "repo": "nixpkgs",
+         "rev": "b7d845292c304e026d86097e6d07409070e80dcc",
+         "type": "github"
+       },
+       "original": {
+         "owner": "NixOS",
+         "ref": "master",
+         "repo": "nixpkgs",
+         "type": "github"
+       }
+     },
+     "root": {
+       "inputs": {
+         "flake-utils": "flake-utils",
+         "nixpkgs": "nixpkgs"
+       }
+     },
+     "systems": {
+       "locked": {
+         "lastModified": 1681028828,
+         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+         "owner": "nix-systems",
+         "repo": "default",
+         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+         "type": "github"
+       },
+       "original": {
+         "owner": "nix-systems",
+         "repo": "default",
+         "type": "github"
+       }
+     }
+   },
+   "root": "root",
+   "version": 7
+ }
flake.nix ADDED
@@ -0,0 +1,40 @@
+ {
+   inputs = {
+     nixpkgs.url = "github:NixOS/nixpkgs/master";
+     flake-utils.url = "github:numtide/flake-utils";
+   };
+   outputs =
+     { nixpkgs, flake-utils, ... }:
+     flake-utils.lib.eachDefaultSystem (
+       system:
+       let
+         pkgs = import nixpkgs {
+           inherit system;
+           config.allowUnfree = true;
+         };
+       in
+       {
+         devShells = {
+           default = pkgs.mkShell {
+             nativeBuildInputs = with pkgs; [
+               (with python311Packages; huggingface-hub)
+               ffmpeg-full
+               go-task
+               lsyncd
+               poetry
+               pre-commit
+               pyright
+               python311
+               websocat
+             ];
+             shellHook = ''
+               source $(poetry env info --path)/bin/activate
+               export LD_LIBRARY_PATH=${pkgs.stdenv.cc.cc.lib}/lib:$LD_LIBRARY_PATH
+               export LD_LIBRARY_PATH=${pkgs.zlib}/lib:$LD_LIBRARY_PATH
+             '';
+           };
+         };
+         formatter = pkgs.nixfmt;
+       }
+     );
+ }
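The flake's dev shell carries every tool referenced elsewhere in this commit (poetry, go-task, websocat, lsyncd, pre-commit). With direnv, the `use flake` line in .envrc loads it automatically; without direnv it can be entered manually (assuming flakes are enabled in your Nix configuration):

```bash
nix develop
```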
poetry.lock ADDED
The diff for this file is too large to render. See raw diff.
pyproject.toml ADDED
@@ -0,0 +1,32 @@
+ [tool.poetry]
+ package-mode = false
+
+ [tool.poetry.dependencies]
+ python = "^3.11"
+ faster-whisper = "^1.0.2"
+ pydantic = "^2.7.1"
+ fastapi = "^0.111.0"
+ uvicorn = "^0.29.0"
+ python-multipart = "^0.0.9"
+ soundfile = "^0.12.1"
+ pydantic-settings = "^2.2.1"
+ websockets = "^12.0"
+ numpy = "^1.26.4"
+
+ [tool.poetry.group.dev.dependencies]
+ pytest = "^8.2.0"
+ pytest-asyncio = "^0.23.6"
+ httpx = "^0.27.0"
+ httpx-ws = "^0.6.0"
+ pytest-xdist = "^3.6.1"
+
+ [tool.poetry.group.client.dependencies]
+ httpx = "^0.27.0"
+ httpx-ws = "^0.6.0"
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"
speaches/__init__.py ADDED
File without changes
speaches/asr.py ADDED
@@ -0,0 +1,80 @@
+ import asyncio
+ import time
+ from typing import Iterable
+
+ from faster_whisper import transcribe
+ from pydantic import BaseModel
+
+ from speaches.audio import Audio
+ from speaches.config import Language
+ from speaches.core import Transcription, Word
+ from speaches.logger import logger
+
+
+ class TranscribeOpts(BaseModel):
+     language: Language | None
+     vad_filter: bool
+     condition_on_previous_text: bool
+
+
+ class FasterWhisperASR:
+     def __init__(
+         self,
+         whisper: transcribe.WhisperModel,
+         transcribe_opts: TranscribeOpts,
+     ) -> None:
+         self.whisper = whisper
+         self.transcribe_opts = transcribe_opts
+
+     def _transcribe(
+         self,
+         audio: Audio,
+         prompt: str | None = None,
+     ) -> tuple[Transcription, transcribe.TranscriptionInfo]:
+         start = time.perf_counter()
+         segments, transcription_info = self.whisper.transcribe(
+             audio.data,
+             initial_prompt=prompt,
+             word_timestamps=True,
+             **self.transcribe_opts.model_dump(),
+         )
+         words = words_from_whisper_segments(segments)
+         for word in words:
+             word.offset(audio.start)
+         transcription = Transcription(words)
+         end = time.perf_counter()
+         logger.info(
+             f"Transcribed {audio} in {end - start:.2f} seconds. Prompt: {prompt}. Transcription: {transcription.text}"
+         )
+         return (transcription, transcription_info)
+
+     async def transcribe(
+         self,
+         audio: Audio,
+         prompt: str | None = None,
+     ) -> tuple[Transcription, transcribe.TranscriptionInfo]:
+         """Wrapper around `_transcribe` so it can be used in an async context."""
+         # Is this the optimal way to execute a blocking call in an async context?
+         # TODO: verify performance when running inference on a CPU
+         return await asyncio.get_running_loop().run_in_executor(
+             None,
+             self._transcribe,
+             audio,
+             prompt,
+         )
+
+
+ def words_from_whisper_segments(segments: Iterable[transcribe.Segment]) -> list[Word]:
+     words: list[Word] = []
+     for segment in segments:
+         assert segment.words is not None
+         words.extend(
+             Word(
+                 start=word.start,
+                 end=word.end,
+                 text=word.word,
+                 probability=word.probability,
+             )
+             for word in segment.words
+         )
+     return words
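The comment in `transcribe` asks whether `run_in_executor` is the optimal way to run the blocking call. `asyncio.to_thread` (Python ≥ 3.9) is equivalent shorthand that also propagates context variables; a minimal standalone sketch of the pattern, with a stand-in for the blocking Whisper call:

```python
import asyncio
import time


def blocking_transcribe(duration: float) -> str:
    # Stand-in for the CPU/GPU-bound WhisperModel.transcribe call.
    time.sleep(duration)
    return "hello world"


async def main() -> None:
    # Shorthand for run_in_executor(None, ...); the event loop stays
    # responsive while the blocking function runs on a worker thread.
    text = await asyncio.to_thread(blocking_transcribe, 0.1)
    print(text)


asyncio.run(main())
```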
speaches/audio.py ADDED
@@ -0,0 +1,96 @@
+ from __future__ import annotations
+
+ import asyncio
+ from typing import Any, AsyncGenerator, BinaryIO
+
+ import numpy as np
+ import soundfile as sf
+ from numpy.typing import NDArray
+
+ from speaches.config import SAMPLES_PER_SECOND
+ from speaches.logger import logger
+
+
+ def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]:
+     audio_and_sample_rate: tuple[NDArray[np.float32], Any] = sf.read(  # type: ignore
+         file,
+         format="RAW",
+         channels=1,
+         samplerate=SAMPLES_PER_SECOND,
+         subtype="PCM_16",
+         dtype="float32",
+         endian="LITTLE",
+     )
+     audio = audio_and_sample_rate[0]
+     return audio
+
+
+ class Audio:
+     def __init__(
+         self,
+         data: NDArray[np.float32] = np.array([], dtype=np.float32),
+         start: float = 0.0,
+     ) -> None:
+         self.data = data
+         self.start = start
+
+     def __repr__(self) -> str:
+         return f"Audio(start={self.start:.2f}, end={self.end:.2f})"
+
+     @property
+     def end(self) -> float:
+         return self.start + self.duration
+
+     @property
+     def duration(self) -> float:
+         return len(self.data) / SAMPLES_PER_SECOND
+
+     def after(self, ts: float) -> Audio:
+         assert ts <= self.duration
+         return Audio(self.data[int(ts * SAMPLES_PER_SECOND) :], start=ts)
+
+     def extend(self, data: NDArray[np.float32]) -> None:
+         # logger.debug(f"Extending audio by {len(data) / SAMPLES_PER_SECOND:.2f}s")
+         self.data = np.append(self.data, data)
+         # logger.debug(f"Audio duration: {self.duration:.2f}s")
+
+
+ # TODO: trim data longer than x
+ class AudioStream(Audio):
+     def __init__(
+         self,
+         data: NDArray[np.float32] = np.array([], dtype=np.float32),
+         start: float = 0.0,
+     ) -> None:
+         super().__init__(data, start)
+         self.closed = False
+
+         self.modify_event = asyncio.Event()
+
+     def extend(self, data: NDArray[np.float32]) -> None:
+         assert not self.closed
+         super().extend(data)
+         self.modify_event.set()
+
+     def close(self) -> None:
+         assert not self.closed
+         self.closed = True
+         self.modify_event.set()
+         logger.info("AudioStream closed")
+
+     async def chunks(
+         self, min_duration: float
+     ) -> AsyncGenerator[NDArray[np.float32], None]:
+         i = 0.0  # end time of the last chunk
+         while True:
+             await self.modify_event.wait()
+             self.modify_event.clear()
+             if self.closed or self.duration - i >= min_duration:
+                 # `i` shouldn't be set to `duration` after the yield,
+                 # because by the time the assignment would happen more data might have been added
+                 i_ = i
+                 i = self.duration
+                 # NOTE: probably better to just do a slice
+                 yield self.after(i_).data
+             if self.closed:
+                 return
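A minimal sketch of how a producer and a consumer interact with `AudioStream` (assuming the `speaches` package is importable, e.g. after `poetry install`): the consumer only wakes up when at least `min_duration` seconds have accumulated or the stream is closed, which is exactly how `audio_transcriber` consumes it.

```python
import asyncio

import numpy as np

from speaches.audio import AudioStream
from speaches.config import SAMPLES_PER_SECOND


async def main() -> None:
    stream = AudioStream()

    async def producer() -> None:
        for _ in range(3):
            # One second of silence per write.
            stream.extend(np.zeros(SAMPLES_PER_SECOND, dtype=np.float32))
            await asyncio.sleep(0)
        stream.close()

    async def consumer() -> None:
        # Yields once at least 2 seconds have accumulated, then flushes the
        # remainder when the stream is closed.
        async for chunk in stream.chunks(min_duration=2.0):
            print(f"chunk of {len(chunk) / SAMPLES_PER_SECOND:.1f}s")


    await asyncio.gather(producer(), consumer())
    # Expected output: "chunk of 2.0s" followed by "chunk of 1.0s"


asyncio.run(main())
```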
speaches/client.py ADDED
@@ -0,0 +1,94 @@
+ # TODO: move out of `speaches` package
+ import asyncio
+ import signal
+
+ import httpx
+ from httpx_ws import AsyncWebSocketSession, WebSocketDisconnect, aconnect_ws
+ from wsproto.connection import ConnectionState
+
+ CHUNK = 1024 * 4
+ AUDIO_RECORD_CMD = "arecord -D default -f S16_LE -r 16000 -c 1 -t raw"
+ COPY_TO_CLIPBOARD_CMD = "wl-copy"
+ NOTIFY_CMD = "notify-desktop"
+
+ client = httpx.AsyncClient(base_url="ws://localhost:8000")
+
+
+ async def audio_sender(ws: AsyncWebSocketSession) -> None:
+     process = await asyncio.create_subprocess_shell(
+         AUDIO_RECORD_CMD,
+         stdout=asyncio.subprocess.PIPE,
+         stderr=asyncio.subprocess.DEVNULL,
+     )
+     assert process.stdout is not None
+     try:
+         while not process.stdout.at_eof():
+             data = await process.stdout.read(CHUNK)
+             if ws.connection.state != ConnectionState.OPEN:
+                 break
+             await ws.send_bytes(data)
+     except Exception as e:
+         print(e)
+     finally:
+         process.kill()
+
+
+ async def transcription_receiver(ws: AsyncWebSocketSession) -> None:
+     transcription = ""
+     notification_id: int | None = None
+     try:
+         while True:
+             data = await ws.receive_text()
+             if not data:
+                 break
+             transcription += data
+             await copy_to_clipboard(transcription)
+             notification_id = await notify(transcription, replaces_id=notification_id)
+     except WebSocketDisconnect:
+         pass
+     print(transcription)
+
+
+ async def copy_to_clipboard(text: str) -> None:
+     process = await asyncio.create_subprocess_shell(
+         COPY_TO_CLIPBOARD_CMD, stdin=asyncio.subprocess.PIPE
+     )
+     await process.communicate(input=text.encode("utf-8"))
+     await process.wait()
+
+
+ async def notify(text: str, replaces_id: int | None = None) -> int:
+     cmd = ["notify-desktop", "--app-name", "Speaches"]
+     if replaces_id is not None:
+         cmd.extend(["--replaces-id", str(replaces_id)])
+     cmd.append("'Speaches'")
+     cmd.append(f"'{text}'")
+     process = await asyncio.create_subprocess_shell(
+         " ".join(cmd),
+         stdout=asyncio.subprocess.PIPE,
+     )
+     await process.wait()
+     assert process.stdout is not None
+     notification_id = (await process.stdout.read()).decode("utf-8")
+     return int(notification_id)
+
+
+ async def main() -> None:
+     async with aconnect_ws("/v1/audio/transcriptions", client) as ws:
+         async with asyncio.TaskGroup() as tg:
+             sender_task = tg.create_task(audio_sender(ws))
+             receiver_task = tg.create_task(transcription_receiver(ws))
+
+             async def on_interrupt():
+                 sender_task.cancel()
+                 receiver_task.cancel()
+                 await asyncio.gather(sender_task, receiver_task)
+
+             asyncio.get_running_loop().add_signal_handler(
+                 signal.SIGINT,
+                 lambda: asyncio.create_task(on_interrupt()),
+             )
+
+
+ asyncio.run(main())
+ # poetry --directory /home/nixos/code/speaches run python /home/nixos/code/speaches/speaches/client.py
speaches/config.py ADDED
@@ -0,0 +1,176 @@
+ import enum
+
+ from pydantic import BaseModel, Field
+ from pydantic_settings import BaseSettings, SettingsConfigDict
+
+ SAMPLES_PER_SECOND = 16000
+ BYTES_PER_SAMPLE = 2
+ BYTES_PER_SECOND = SAMPLES_PER_SECOND * BYTES_PER_SAMPLE
+ # 2 BYTES = 16 BITS = 1 SAMPLE
+ # 1 SECOND OF AUDIO = 32000 BYTES = 16000 SAMPLES
+
+
+ # TODO: confirm names
+ class Model(enum.StrEnum):
+     TINY_EN = "tiny.en"
+     TINY = "tiny"
+     BASE_EN = "base.en"
+     BASE = "base"
+     SMALL_EN = "small.en"
+     SMALL = "small"
+     MEDIUM_EN = "medium.en"
+     MEDIUM = "medium"
+     LARGE = "large"
+     LARGE_V1 = "large-v1"
+     LARGE_V2 = "large-v2"
+     LARGE_V3 = "large-v3"
+     DISTIL_SMALL_EN = "distil-small.en"
+     DISTIL_MEDIUM_EN = "distil-medium.en"
+     DISTIL_LARGE_V2 = "distil-large-v2"
+     DISTIL_LARGE_V3 = "distil-large-v3"
+
+
+ class Device(enum.StrEnum):
+     CPU = "cpu"
+     CUDA = "cuda"
+     AUTO = "auto"
+
+
+ # https://github.com/OpenNMT/CTranslate2/blob/master/docs/quantization.md
+ class Quantization(enum.StrEnum):
+     INT8 = "int8"
+     INT8_FLOAT16 = "int8_float16"
+     INT8_BFLOAT16 = "int8_bfloat16"
+     INT8_FLOAT32 = "int8_float32"
+     INT16 = "int16"
+     FLOAT16 = "float16"
+     BFLOAT16 = "bfloat16"
+     FLOAT32 = "float32"
+     DEFAULT = "default"
+
+
+ class Language(enum.StrEnum):
+     AF = "af"
+     AM = "am"
+     AR = "ar"
+     AS = "as"
+     AZ = "az"
+     BA = "ba"
+     BE = "be"
+     BG = "bg"
+     BN = "bn"
+     BO = "bo"
+     BR = "br"
+     BS = "bs"
+     CA = "ca"
+     CS = "cs"
+     CY = "cy"
+     DA = "da"
+     DE = "de"
+     EL = "el"
+     EN = "en"
+     ES = "es"
+     ET = "et"
+     EU = "eu"
+     FA = "fa"
+     FI = "fi"
+     FO = "fo"
+     FR = "fr"
+     GL = "gl"
+     GU = "gu"
+     HA = "ha"
+     HAW = "haw"
+     HE = "he"
+     HI = "hi"
+     HR = "hr"
+     HT = "ht"
+     HU = "hu"
+     HY = "hy"
+     ID = "id"
+     IS = "is"
+     IT = "it"
+     JA = "ja"
+     JW = "jw"
+     KA = "ka"
+     KK = "kk"
+     KM = "km"
+     KN = "kn"
+     KO = "ko"
+     LA = "la"
+     LB = "lb"
+     LN = "ln"
+     LO = "lo"
+     LT = "lt"
+     LV = "lv"
+     MG = "mg"
+     MI = "mi"
+     MK = "mk"
+     ML = "ml"
+     MN = "mn"
+     MR = "mr"
+     MS = "ms"
+     MT = "mt"
+     MY = "my"
+     NE = "ne"
+     NL = "nl"
+     NN = "nn"
+     NO = "no"
+     OC = "oc"
+     PA = "pa"
+     PL = "pl"
+     PS = "ps"
+     PT = "pt"
+     RO = "ro"
+     RU = "ru"
+     SA = "sa"
+     SD = "sd"
+     SI = "si"
+     SK = "sk"
+     SL = "sl"
+     SN = "sn"
+     SO = "so"
+     SQ = "sq"
+     SR = "sr"
+     SU = "su"
+     SV = "sv"
+     SW = "sw"
+     TA = "ta"
+     TE = "te"
+     TG = "tg"
+     TH = "th"
+     TK = "tk"
+     TL = "tl"
+     TR = "tr"
+     TT = "tt"
+     UK = "uk"
+     UR = "ur"
+     UZ = "uz"
+     VI = "vi"
+     YI = "yi"
+     YO = "yo"
+     YUE = "yue"
+     ZH = "zh"
+
+
+ class WhisperConfig(BaseModel):
+     model: Model = Field(default=Model.DISTIL_SMALL_EN)
+     inference_device: Device = Field(default=Device.AUTO)
+     compute_type: Quantization = Field(default=Quantization.DEFAULT)
+
+
+ class Config(BaseSettings):
+     model_config = SettingsConfigDict(env_nested_delimiter="_")
+
+     log_level: str = "info"
+     whisper: WhisperConfig = WhisperConfig()
+     max_no_data_seconds: float = 1.0
+     """
+     Max duration to wait for the next audio chunk before finalizing the transcription and closing the connection.
+     """
+     min_duration: float = 1.0
+     word_timestamp_error_margin: float = 0.2
+     inactivity_window_seconds: float = 3.0
+     max_inactivity_seconds: float = 1.5
+
+
+ config = Config()
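Because `Config` sets `env_nested_delimiter="_"`, pydantic-settings should resolve nested fields from flat, case-insensitive variable names. A hedged sketch of the expected mapping (assumed pydantic-settings behavior, not verified against this app):

```python
import os

# Top-level field: Config.log_level <- LOG_LEVEL
os.environ["LOG_LEVEL"] = "debug"
# Nested field: Config.whisper.model <- WHISPER_MODEL ("whisper" + "_" + "model")
os.environ["WHISPER_MODEL"] = "distil-large-v3"

from speaches.config import Config

config = Config()
assert config.log_level == "debug"
assert config.whisper.model == "distil-large-v3"
```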
speaches/core.py ADDED
@@ -0,0 +1,207 @@
+ # TODO: rename module
+ from __future__ import annotations
+
+ import re
+ from dataclasses import dataclass
+
+ from speaches.config import config
+
+
+ # TODO: use the `Segment` from `faster-whisper.transcribe` instead
+ @dataclass
+ class Segment:
+     text: str
+     start: float = 0.0
+     end: float = 0.0
+
+     @property
+     def is_eos(self) -> bool:
+         if self.text.endswith("..."):
+             return False
+         for punctuation_symbol in ".?!":
+             if self.text.endswith(punctuation_symbol):
+                 return True
+         return False
+
+     def offset(self, seconds: float) -> None:
+         self.start += seconds
+         self.end += seconds
+
+
+ # TODO: use the `Word` from `faster-whisper.transcribe` instead
+ @dataclass
+ class Word(Segment):
+     probability: float = 0.0
+
+     @classmethod
+     def common_prefix(cls, a: list[Word], b: list[Word]) -> list[Word]:
+         i = 0
+         while (
+             i < len(a)
+             and i < len(b)
+             and canonicalize_word(a[i].text) == canonicalize_word(b[i].text)
+         ):
+             i += 1
+         return a[:i]
+
+
+ class Transcription:
+     def __init__(self, words: list[Word] = []) -> None:
+         self.words: list[Word] = []
+         self.extend(words)
+
+     @property
+     def text(self) -> str:
+         return " ".join(word.text for word in self.words).strip()
+
+     @property
+     def start(self) -> float:
+         return self.words[0].start if len(self.words) > 0 else 0.0
+
+     @property
+     def end(self) -> float:
+         return self.words[-1].end if len(self.words) > 0 else 0.0
+
+     @property
+     def duration(self) -> float:
+         return self.end - self.start
+
+     def after(self, seconds: float) -> Transcription:
+         return Transcription(
+             words=[word for word in self.words if word.start > seconds]
+         )
+
+     def extend(self, words: list[Word]) -> None:
+         self._ensure_no_word_overlap(words)
+         self.words.extend(words)
+
+     def _ensure_no_word_overlap(self, words: list[Word]) -> None:
+         if len(self.words) > 0 and len(words) > 0:
+             if (
+                 words[0].start + config.word_timestamp_error_margin
+                 <= self.words[-1].end
+             ):
+                 raise ValueError(
+                     f"Words overlap: {self.words[-1]} and {words[0]}. Error margin: {config.word_timestamp_error_margin}"
+                 )
+         for i in range(1, len(words)):
+             if words[i].start + config.word_timestamp_error_margin <= words[i - 1].end:
+                 raise ValueError(
+                     f"Words overlap: {words[i - 1]} and {words[i]}. All words: {words}"
+                 )
+
+
+ def test_segment_is_eos():
+     assert not Segment("Hello").is_eos
+     assert not Segment("Hello...").is_eos
+     assert Segment("Hello.").is_eos
+     assert Segment("Hello!").is_eos
+     assert Segment("Hello?").is_eos
+     assert not Segment("Hello. Yo").is_eos
+     assert not Segment("Hello. Yo...").is_eos
+     assert Segment("Hello. Yo.").is_eos
+
+
+ def to_full_sentences(words: list[Word]) -> list[Segment]:
+     sentences: list[Segment] = [Segment("")]
+     for word in words:
+         sentences[-1] = Segment(
+             start=sentences[-1].start,
+             end=word.end,
+             text=sentences[-1].text + word.text,
+         )
+         if word.is_eos:
+             sentences.append(Segment(""))
+     if len(sentences) > 0 and not sentences[-1].is_eos:
+         sentences.pop()
+     return sentences
+
+
+ def test_to_full_sentences():
+     assert to_full_sentences([]) == []
+     assert to_full_sentences([Word(text="Hello")]) == []
+     assert to_full_sentences([Word(text="Hello..."), Word(" world")]) == []
+     assert to_full_sentences([Word(text="Hello..."), Word(" world.")]) == [
+         Segment(text="Hello... world.")
+     ]
+     assert to_full_sentences(
+         [Word(text="Hello..."), Word(" world."), Word(" How")]
+     ) == [Segment(text="Hello... world.")]
+
+
+ def to_text(words: list[Word]) -> str:
+     return "".join(word.text for word in words)
+
+
+ def to_text_w_ts(words: list[Word]) -> str:
+     return "".join(f"{word.text}({word.start:.2f}-{word.end:.2f})" for word in words)
+
+
+ def canonicalize_word(text: str) -> str:
+     text = text.lower()
+     # Remove non-alphabetic characters using a regular expression
+     text = re.sub(r"[^a-z]", "", text)
+     return text.lower().strip().strip(".,?!")
+
+
+ def test_canonicalize_word():
+     assert canonicalize_word("ABC") == "abc"
+     assert canonicalize_word("...ABC?") == "abc"
+     assert canonicalize_word("... AbC ...") == "abc"
+
+
+ def common_prefix(a: list[Word], b: list[Word]) -> list[Word]:
+     i = 0
+     while (
+         i < len(a)
+         and i < len(b)
+         and canonicalize_word(a[i].text) == canonicalize_word(b[i].text)
+     ):
+         i += 1
+     return a[:i]
+
+
+ def test_common_prefix():
+     def word(text: str) -> Word:
+         return Word(text=text, start=0.0, end=0.0, probability=0.0)
+
+     a = [word("a"), word("b"), word("c")]
+     b = [word("a"), word("b"), word("c")]
+     assert common_prefix(a, b) == [word("a"), word("b"), word("c")]
+
+     a = [word("a"), word("b"), word("c")]
+     b = [word("a"), word("b"), word("d")]
+     assert common_prefix(a, b) == [word("a"), word("b")]
+
+     a = [word("a"), word("b"), word("c")]
+     b = [word("a")]
+     assert common_prefix(a, b) == [word("a")]
+
+     a = [word("a")]
+     b = [word("a"), word("b"), word("c")]
+     assert common_prefix(a, b) == [word("a")]
+
+     a = [word("a")]
+     b = []
+     assert common_prefix(a, b) == []
+
+     a = []
+     b = [word("a")]
+     assert common_prefix(a, b) == []
+
+     a = [word("a"), word("b"), word("c")]
+     b = [word("b"), word("c")]
+     assert common_prefix(a, b) == []
+
+
+ def test_common_prefix_and_canonicalization():
+     def word(text: str) -> Word:
+         return Word(text=text, start=0.0, end=0.0, probability=0.0)
+
+     a = [word("A...")]
+     b = [word("a?"), word("b"), word("c")]
+     assert common_prefix(a, b) == [word("A...")]
+
+     a = [word("A..."), word("B?"), word("C,")]
+     b = [word("a??"), word(" b"), word(" ,c")]
+     assert common_prefix(a, b) == [word("A..."), word("B?"), word("C,")]
speaches/logger.py ADDED
@@ -0,0 +1,13 @@
+ import logging
+
+ from speaches.config import config
+
+ # Disables all but the `speaches` logger
+
+ root_logger = logging.getLogger()
+ root_logger.setLevel(logging.CRITICAL)
+ logger = logging.getLogger(__name__)
+ logger.setLevel(config.log_level.upper())
+ logging.basicConfig(
+     format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s"
+ )
speaches/main.py ADDED
@@ -0,0 +1,162 @@
+ from __future__ import annotations
+
+ import asyncio
+ import time
+ from contextlib import asynccontextmanager
+ from io import BytesIO
+ from typing import Annotated
+
+ from fastapi import (
+     Depends,
+     FastAPI,
+     Response,
+     UploadFile,
+     WebSocket,
+     WebSocketDisconnect,
+ )
+ from fastapi.websockets import WebSocketState
+ from faster_whisper import WhisperModel
+ from faster_whisper.vad import VadOptions, get_speech_timestamps
+
+ from speaches.asr import FasterWhisperASR, TranscribeOpts
+ from speaches.audio import AudioStream, audio_samples_from_file
+ from speaches.config import SAMPLES_PER_SECOND, Language, config
+ from speaches.core import Transcription
+ from speaches.logger import logger
+ from speaches.server_models import (
+     ResponseFormat,
+     TranscriptionResponse,
+     TranscriptionVerboseResponse,
+ )
+ from speaches.transcriber import audio_transcriber
+
+ whisper: WhisperModel = None  # type: ignore
+
+
+ @asynccontextmanager
+ async def lifespan(_: FastAPI):
+     global whisper
+     logger.debug(f"Loading {config.whisper.model}")
+     start = time.perf_counter()
+     whisper = WhisperModel(
+         config.whisper.model,
+         device=config.whisper.inference_device,
+         compute_type=config.whisper.compute_type,
+     )
+     end = time.perf_counter()
+     logger.debug(f"Loaded {config.whisper.model} in {end - start:.2f} seconds")
+     yield
+
+
+ app = FastAPI(lifespan=lifespan)
+
+
+ @app.get("/health")
+ def health() -> Response:
+     return Response(status_code=200, content="Everything is peachy!")
+
+
+ async def transcription_parameters(
+     language: Language = Language.EN,
+     vad_filter: bool = True,
+     condition_on_previous_text: bool = False,
+ ) -> TranscribeOpts:
+     return TranscribeOpts(
+         language=language,
+         vad_filter=vad_filter,
+         condition_on_previous_text=condition_on_previous_text,
+     )
+
+
+ TranscribeParams = Annotated[TranscribeOpts, Depends(transcription_parameters)]
+
+
+ @app.post("/v1/audio/transcriptions")
+ async def transcribe_file(
+     file: UploadFile,
+     transcription_opts: TranscribeParams,
+     response_format: ResponseFormat = ResponseFormat.JSON,
+ ) -> str:
+     asr = FasterWhisperASR(whisper, transcription_opts)
+     audio_samples = audio_samples_from_file(file.file)
+     audio = AudioStream(audio_samples)
+     transcription, _ = await asr.transcribe(audio)
+     return format_transcription(transcription, response_format)
+
+
+ async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
+     try:
+         while True:
+             bytes_ = await asyncio.wait_for(
+                 ws.receive_bytes(), timeout=config.max_no_data_seconds
+             )
+             logger.debug(f"Received {len(bytes_)} bytes of audio data")
+             audio_samples = audio_samples_from_file(BytesIO(bytes_))
+             audio_stream.extend(audio_samples)
+             if audio_stream.duration - config.inactivity_window_seconds >= 0:
+                 audio = audio_stream.after(
+                     audio_stream.duration - config.inactivity_window_seconds
+                 )
+                 vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
+                 timestamps = get_speech_timestamps(audio.data, vad_opts)
+                 if len(timestamps) == 0:
+                     logger.info(
+                         f"No speech detected in the last {config.inactivity_window_seconds} seconds."
+                     )
+                     break
+                 elif (
+                     # last speech end time
+                     config.inactivity_window_seconds
+                     - timestamps[-1]["end"] / SAMPLES_PER_SECOND
+                     >= config.max_inactivity_seconds
+                 ):
+                     logger.info(
+                         f"Not enough speech in the last {config.inactivity_window_seconds} seconds."
+                     )
+                     break
+     except asyncio.TimeoutError:
+         logger.info(
+             f"No data received in {config.max_no_data_seconds} seconds. Closing the connection."
+         )
+     except WebSocketDisconnect as e:
+         logger.info(f"Client disconnected: {e}")
+     audio_stream.close()
+
+
+ def format_transcription(
+     transcription: Transcription, response_format: ResponseFormat
+ ) -> str:
+     if response_format == ResponseFormat.TEXT:
+         return transcription.text
+     elif response_format == ResponseFormat.JSON:
+         return TranscriptionResponse(text=transcription.text).model_dump_json()
+     elif response_format == ResponseFormat.VERBOSE_JSON:
+         return TranscriptionVerboseResponse(
+             duration=transcription.duration,
+             text=transcription.text,
+             words=transcription.words,
+         ).model_dump_json()
+
+
+ @app.websocket("/v1/audio/transcriptions")
+ async def transcribe_stream(
+     ws: WebSocket,
+     transcription_opts: TranscribeParams,
+     response_format: ResponseFormat = ResponseFormat.JSON,
+ ) -> None:
+     await ws.accept()
+     asr = FasterWhisperASR(whisper, transcription_opts)
+     audio_stream = AudioStream()
+     async with asyncio.TaskGroup() as tg:
+         tg.create_task(audio_receiver(ws, audio_stream))
+         async for transcription in audio_transcriber(asr, audio_stream):
+             logger.debug(f"Sending transcription: {transcription.text}")
+             if ws.client_state == WebSocketState.DISCONNECTED:
+                 break
+             await ws.send_text(format_transcription(transcription, response_format))
+
+     if ws.client_state != WebSocketState.DISCONNECTED:
+         # the client hasn't disconnected yet, so close the connection from our side
+         await ws.close()
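Both routes share the `/v1/audio/transcriptions` path: POST for whole files, WebSocket for streams. A sketch of exercising the POST route with curl, assuming `audio.raw` already contains 16 kHz mono s16le samples as required by `audio_samples_from_file`:

```bash
curl -s "http://localhost:8000/v1/audio/transcriptions?response_format=text" \
  -F "file=@audio.raw"
```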
speaches/server_models.py ADDED
@@ -0,0 +1,26 @@
+ import enum
+
+ from pydantic import BaseModel
+
+ from speaches.core import Word
+
+
+ class ResponseFormat(enum.StrEnum):
+     JSON = "json"
+     TEXT = "text"
+     VERBOSE_JSON = "verbose_json"
+
+
+ # https://platform.openai.com/docs/api-reference/audio/json-object
+ class TranscriptionResponse(BaseModel):
+     text: str
+
+
+ # Subset of https://platform.openai.com/docs/api-reference/audio/verbose-json-object
+ class TranscriptionVerboseResponse(BaseModel):
+     task: str = "transcribe"
+     duration: float
+     text: str
+     words: list[Word]  # Different from OpenAI's `words`: `Word.text` instead of `Word.word`
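A sketch of what a `verbose_json` payload should look like, given the fields above and the `Word` dataclass from core.py (the words and timestamps here are made up):

```python
from speaches.core import Word
from speaches.server_models import TranscriptionVerboseResponse

response = TranscriptionVerboseResponse(
    duration=1.2,
    text="Hello world.",
    words=[
        Word(text="Hello", start=0.1, end=0.5, probability=0.98),
        Word(text=" world.", start=0.6, end=1.2, probability=0.95),
    ],
)
# {"task": "transcribe", "duration": 1.2, "text": "Hello world.", "words": [...]}
print(response.model_dump_json())
```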
speaches/transcriber.py ADDED
@@ -0,0 +1,75 @@
+ from __future__ import annotations
+
+ from typing import AsyncGenerator
+
+ from speaches.asr import FasterWhisperASR
+ from speaches.audio import Audio, AudioStream
+ from speaches.config import config
+ from speaches.core import Transcription, Word, common_prefix, to_full_sentences
+ from speaches.logger import logger
+
+
+ class LocalAgreement:
+     def __init__(self) -> None:
+         self.unconfirmed = Transcription()
+
+     def merge(self, confirmed: Transcription, incoming: Transcription) -> list[Word]:
+         # https://github.com/ufal/whisper_streaming/blob/main/whisper_online.py#L264
+         incoming = incoming.after(confirmed.end - 0.1)
+         prefix = common_prefix(incoming.words, self.unconfirmed.words)
+         logger.debug(f"Confirmed: {confirmed.text}")
+         logger.debug(f"Unconfirmed: {self.unconfirmed.text}")
+         logger.debug(f"Incoming: {incoming.text}")
+
+         if len(incoming.words) > len(prefix):
+             self.unconfirmed = Transcription(incoming.words[len(prefix) :])
+         else:
+             self.unconfirmed = Transcription()
+
+         return prefix
+
+     @classmethod
+     def prompt(cls, confirmed: Transcription) -> str | None:
+         sentences = to_full_sentences(confirmed.words)
+         if len(sentences) == 0:
+             return None
+         return sentences[-1].text
+
+     # TODO: better name
+     @classmethod
+     def needs_audio_after(cls, confirmed: Transcription) -> float:
+         full_sentences = to_full_sentences(confirmed.words)
+         return full_sentences[-1].end if len(full_sentences) > 0 else 0.0
+
+
+ def needs_audio_after(confirmed: Transcription) -> float:
+     full_sentences = to_full_sentences(confirmed.words)
+     return full_sentences[-1].end if len(full_sentences) > 0 else 0.0
+
+
+ def prompt(confirmed: Transcription) -> str | None:
+     sentences = to_full_sentences(confirmed.words)
+     if len(sentences) == 0:
+         return None
+     return sentences[-1].text
+
+
+ async def audio_transcriber(
+     asr: FasterWhisperASR,
+     audio_stream: AudioStream,
+ ) -> AsyncGenerator[Transcription, None]:
+     local_agreement = LocalAgreement()
+     full_audio = Audio()
+     confirmed = Transcription()
+     async for chunk in audio_stream.chunks(config.min_duration):
+         full_audio.extend(chunk)
+         audio = full_audio.after(needs_audio_after(confirmed))
+         transcription, _ = await asr.transcribe(audio, prompt(confirmed))
+         new_words = local_agreement.merge(confirmed, transcription)
+         if len(new_words) > 0:
+             confirmed.extend(new_words)
+             yield confirmed
+     logger.debug("Flushing...")
+     confirmed.extend(local_agreement.unconfirmed.words)
+     yield confirmed
+     logger.info("Audio transcriber finished")
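A minimal sketch of the LocalAgreement policy in isolation: a word becomes confirmed only once two consecutive hypotheses agree on it (the words and timestamps below are made up; assumes the `speaches` package is importable):

```python
from speaches.core import Transcription, Word
from speaches.transcriber import LocalAgreement

agreement = LocalAgreement()
confirmed = Transcription()


def hypothesis(*texts: str) -> Transcription:
    # Fabricated timestamps: one word per half second.
    return Transcription(
        [Word(text=t, start=i * 0.5, end=i * 0.5 + 0.4) for i, t in enumerate(texts)]
    )


# First pass: nothing is unconfirmed yet, so nothing gets confirmed.
confirmed.extend(agreement.merge(confirmed, hypothesis("hello", "world")))
print(repr(confirmed.text))  # ""

# Second pass agrees with the first on the first two words, so they get confirmed.
confirmed.extend(agreement.merge(confirmed, hypothesis("hello", "world", "how")))
print(repr(confirmed.text))  # "hello world"
```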
tests/__init__.py ADDED
File without changes
tests/app_test.py ADDED
@@ -0,0 +1,84 @@
+ import json
+ import os
+ import threading
+ import time
+ from difflib import SequenceMatcher
+ from typing import Generator
+
+ import pytest
+ from fastapi import WebSocketDisconnect
+ from fastapi.testclient import TestClient
+ from starlette.testclient import WebSocketTestSession
+
+ from speaches.config import BYTES_PER_SECOND
+ from speaches.main import app
+ from speaches.server_models import TranscriptionVerboseResponse
+
+ SIMILARITY_THRESHOLD = 0.97
+
+
+ @pytest.fixture()
+ def client() -> Generator[TestClient, None, None]:
+     with TestClient(app) as client:
+         yield client
+
+
+ def get_audio_file_paths():
+     file_paths = []
+     directory = "tests/data"
+     for filename in reversed(os.listdir(directory)[5:6]):
+         if filename.endswith(".raw"):
+             file_paths.append(os.path.join(directory, filename))
+     return file_paths
+
+
+ file_paths = get_audio_file_paths()
+
+
+ def stream_audio_data(
+     ws: WebSocketTestSession, data: bytes, *, chunk_size: int = 4000, speed: float = 1.0
+ ):
+     for i in range(0, len(data), chunk_size):
+         ws.send_bytes(data[i : i + chunk_size])
+         delay = len(data[i : i + chunk_size]) / BYTES_PER_SECOND / speed
+         time.sleep(delay)
+
+
+ def transcribe_audio_data(
+     client: TestClient, data: bytes
+ ) -> TranscriptionVerboseResponse:
+     response = client.post(
+         "/v1/audio/transcriptions?response_format=verbose_json",
+         files={"file": ("audio.raw", data, "audio/raw")},
+     )
+     data = json.loads(response.json())  # TODO: figure this out
+     return TranscriptionVerboseResponse(**data)  # type: ignore
+
+
+ @pytest.mark.parametrize("file_path", file_paths)
+ def test_ws_audio_transcriptions(client: TestClient, file_path: str):
+     with open(file_path, "rb") as file:
+         data = file.read()
+     streaming_transcription: TranscriptionVerboseResponse = None  # type: ignore
+     with client.websocket_connect(
+         "/v1/audio/transcriptions?response_format=verbose_json"
+     ) as ws:
+         thread = threading.Thread(
+             target=stream_audio_data, args=(ws, data), kwargs={"speed": 4.0}
+         )
+         thread.start()
+         while True:
+             try:
+                 streaming_transcription = TranscriptionVerboseResponse(
+                     **ws.receive_json()
+                 )
+             except WebSocketDisconnect:
+                 break
+         ws.close()
+     file_transcription = transcribe_audio_data(client, data)
+     s = SequenceMatcher(
+         lambda x: x == " ", file_transcription.text, streaming_transcription.text
+     )
+     assert (
+         s.ratio() > SIMILARITY_THRESHOLD
+     ), f"\nExpected: {file_transcription.text}\nReceived: {streaming_transcription.text}"
tests/conftest.py ADDED
@@ -0,0 +1,9 @@
+ import logging
+
+ disable_loggers = ["multipart.multipart", "faster_whisper"]
+
+
+ def pytest_configure():
+     for logger_name in disable_loggers:
+         logger = logging.getLogger(logger_name)
+         logger.disabled = True