from __future__ import annotations
import asyncio
import logging
import time
from contextlib import asynccontextmanager
from io import BytesIO
from typing import Annotated
from fastapi import (
Depends,
FastAPI,
Response,
UploadFile,
WebSocket,
WebSocketDisconnect,
)
from fastapi.websockets import WebSocketState
from faster_whisper import WhisperModel
from faster_whisper.vad import VadOptions, get_speech_timestamps
from speaches.asr import FasterWhisperASR, TranscribeOpts
from speaches.audio import AudioStream, audio_samples_from_file
from speaches.config import SAMPLES_PER_SECOND, Language, config
from speaches.core import Transcription
from speaches.logger import logger
from speaches.server_models import (
ResponseFormat,
TranscriptionResponse,
TranscriptionVerboseResponse,
)
from speaches.transcriber import audio_transcriber
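# Global model handle; initialized once at startup by the `lifespan` context manager below.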
whisper: WhisperModel = None  # type: ignore
@asynccontextmanager
async def lifespan(_: FastAPI):
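    """Load the Whisper model once at application startup."""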
global whisper
    logger.debug(f"Loading {config.whisper.model}")
start = time.perf_counter()
whisper = WhisperModel(
config.whisper.model,
device=config.whisper.inference_device,
compute_type=config.whisper.compute_type,
)
end = time.perf_counter()
    logger.debug(f"Loaded {config.whisper.model} in {end - start:.2f} seconds")
yield
app = FastAPI(lifespan=lifespan)
@app.get("/health")
def health() -> Response:
return Response(status_code=200, content="Everything is peachy!")
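# Transcription options shared by the file and WebSocket endpoints, resolved
# through FastAPI's dependency injection so both accept the same query parameters.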
async def transcription_parameters(
language: Language = Language.EN,
vad_filter: bool = True,
condition_on_previous_text: bool = False,
) -> TranscribeOpts:
return TranscribeOpts(
language=language,
vad_filter=vad_filter,
condition_on_previous_text=condition_on_previous_text,
)
TranscribeParams = Annotated[TranscribeOpts, Depends(transcription_parameters)]
@app.post("/v1/audio/transcriptions")
async def transcribe_file(
file: UploadFile,
transcription_opts: TranscribeParams,
response_format: ResponseFormat = ResponseFormat.JSON,
) -> str:
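    """Transcribe an entire uploaded audio file in a single request.

    Example request (assuming the server is listening on localhost:8000):
        curl http://localhost:8000/v1/audio/transcriptions -F "file=@audio.wav"
    """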
asr = FasterWhisperASR(whisper, transcription_opts)
audio_samples = audio_samples_from_file(file.file)
audio = AudioStream(audio_samples)
transcription, _ = await asr.transcribe(audio)
return format_transcription(transcription, response_format)
async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
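    """Receive audio chunks from the WebSocket and append them to `audio_stream`.

    Returns (closing the stream) when the client disconnects, when no data
    arrives within `max_no_data_seconds`, or when VAD detects that speech
    has stopped.
    """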
try:
while True:
bytes_ = await asyncio.wait_for(
ws.receive_bytes(), timeout=config.max_no_data_seconds
)
logger.debug(f"Received {len(bytes_)} bytes of audio data")
audio_samples = audio_samples_from_file(BytesIO(bytes_))
audio_stream.extend(audio_samples)
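            # Once at least `inactivity_window_seconds` of audio has accumulated,
            # run VAD over the trailing window to check whether the speaker went quiet.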
if audio_stream.duration - config.inactivity_window_seconds >= 0:
audio = audio_stream.after(
audio_stream.duration - config.inactivity_window_seconds
)
vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
timestamps = get_speech_timestamps(audio.data, vad_opts)
if len(timestamps) == 0:
logger.info(
f"No speech detected in the last {config.inactivity_window_seconds} seconds."
)
break
elif (
                # Seconds elapsed since the last detected speech ended.
config.inactivity_window_seconds
- timestamps[-1]["end"] / SAMPLES_PER_SECOND
>= config.max_inactivity_seconds
):
logger.info(
f"Not enough speech in the last {config.inactivity_window_seconds} seconds."
)
break
except asyncio.TimeoutError:
logger.info(
f"No data received in {config.max_no_data_seconds} seconds. Closing the connection."
)
except WebSocketDisconnect as e:
logger.info(f"Client disconnected: {e}")
audio_stream.close()
def format_transcription(
transcription: Transcription, response_format: ResponseFormat
) -> str:
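    """Serialize a transcription into the requested response format."""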
if response_format == ResponseFormat.TEXT:
return transcription.text
elif response_format == ResponseFormat.JSON:
return TranscriptionResponse(text=transcription.text).model_dump_json()
elif response_format == ResponseFormat.VERBOSE_JSON:
return TranscriptionVerboseResponse(
duration=transcription.duration,
text=transcription.text,
words=transcription.words,
        ).model_dump_json()
    else:
        raise ValueError(f"Unsupported response format: {response_format}")
@app.websocket("/v1/audio/transcriptions")
async def transcribe_stream(
ws: WebSocket,
transcription_opts: TranscribeParams,
response_format: ResponseFormat = ResponseFormat.JSON,
) -> None:
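    """Transcribe audio streamed over a WebSocket.

    The client sends binary audio chunks; the server replies with incremental
    transcriptions serialized per `response_format`.
    """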
await ws.accept()
asr = FasterWhisperASR(whisper, transcription_opts)
audio_stream = AudioStream()
async with asyncio.TaskGroup() as tg:
tg.create_task(audio_receiver(ws, audio_stream))
async for transcription in audio_transcriber(asr, audio_stream):
logger.debug(f"Sending transcription: {transcription.text}")
            # Stop sending transcriptions once the client has disconnected.
            if ws.client_state == WebSocketState.DISCONNECTED:
                break
await ws.send_text(format_transcription(transcription, response_format))
    if ws.client_state != WebSocketState.DISCONNECTED:
        # The client hasn't disconnected; close the connection from the server side.
        await ws.close()