|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
|
from urllib.parse import urlparse, parse_qsl |
|
from transcribe.whisper_llm_serve import WhisperTranscriptionService |
|
from uuid import uuid1 |
|
from logging import getLogger |
|
import numpy as np |
|
from transcribe.translatepipes import TranslatePipes |
|
from contextlib import asynccontextmanager |
|
from multiprocessing import Process, freeze_support |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.responses import RedirectResponse |
|
import os |
|
from transcribe.utils import pcm_bytes_to_np_array |
|
from config import BASE_DIR |
|
# Module-level logger, named after this module so its records can be filtered.
logger = getLogger(__name__)
|
|
|
|
|
async def get_audio_from_websocket(websocket) -> "np.ndarray | bool":
    """Receive one binary frame from the websocket and decode it.

    Args:
        websocket: Connection object exposing an awaitable
            ``receive_bytes()``.

    Returns:
        The result of ``pcm_bytes_to_np_array`` applied to the received
        payload, or ``False`` when the payload is exactly the
        ``b"END_OF_AUDIO"`` marker. Callers must check for ``False``
        before treating the result as audio.
    """
    payload = await websocket.receive_bytes()

    # A frame equal to the marker is reported as False instead of being
    # decoded, so the caller can tell "no more audio" apart from samples.
    if payload == b"END_OF_AUDIO":
        return False

    return pcm_bytes_to_np_array(payload)
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook that prepares the shared pipeline before serving.

    Creates a ``TranslatePipes`` instance, stores it in the module-level
    ``pipe`` variable, and calls its ``wait_ready()`` before yielding
    control. Nothing runs after the ``yield``.

    Args:
        app: The FastAPI application this hook is attached to (unused).
    """
    global pipe

    pipe = TranslatePipes()
    # wait_ready() is called synchronously (it is not awaited), so nothing
    # else in this coroutine proceeds until it returns.
    pipe.wait_ready()
    logger.info("Pipeline is ready.")

    yield
|
|
|
|
|
# Shared pipeline handle; it stays None until the lifespan hook assigns it.
pipe = None

# Static frontend assets are served from <BASE_DIR>/frontend.
FRONTEND_DIR = os.path.join(BASE_DIR, "frontend")

app = FastAPI(lifespan=lifespan)
# html=True is passed so the mount behaves like a static site under /app.
app.mount(
    "/app",
    StaticFiles(directory=FRONTEND_DIR, html=True),
    name="frontend",
)
|
|
|
@app.get("/")
async def root():
    """Redirect requests for the site root to ``/app/``."""
    target = "/app/"
    return RedirectResponse(url=target)
|
|
|
@app.websocket("/ws")
async def translate(websocket: WebSocket):
    """Feed audio frames from a websocket client into a transcription service.

    The source and destination languages are read from the ``from`` and
    ``to`` query parameters of the connection URL. Each received frame is
    handed to a per-connection ``WhisperTranscriptionService`` through
    ``add_frames`` until the client disconnects or sends the end-of-audio
    marker.

    Args:
        websocket: The incoming websocket connection.
    """
    params = websocket.query_params
    from_lang = params.get('from')
    to_lang = params.get('to')

    # Reject the handshake up front when either language is missing, so no
    # per-connection service is created for a session that cannot run and
    # the socket is never read without having been accepted.
    if not from_lang or not to_lang:
        logger.warning(
            "Rejecting websocket: missing 'from' or 'to' query parameter"
        )
        await websocket.close(code=1008)  # 1008 = policy violation
        return

    client = WhisperTranscriptionService(
        websocket,
        pipe,
        language=from_lang,
        dst_lang=to_lang,
        client_uid=f"{uuid1()}",
    )

    logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
    await websocket.accept()

    # NOTE(review): the service is not explicitly torn down on exit --
    # confirm whether WhisperTranscriptionService needs a stop/cleanup call.
    try:
        while True:
            frame_data = await get_audio_from_websocket(websocket)
            # get_audio_from_websocket returns False for the END_OF_AUDIO
            # marker; that sentinel must never be queued as audio.
            if frame_data is False:
                break
            client.add_frames(frame_data)
    except WebSocketDisconnect:
        return
|
|
|
if __name__ == "__main__":
    # Per the multiprocessing docs, freeze_support() should be the first
    # statement under the main guard; it only matters for frozen Windows
    # executables and has no effect otherwise.
    freeze_support()

    # uvicorn is imported here, only when the file is run as a script.
    import uvicorn

    # Serve on every network interface (0.0.0.0), port 9191.
    uvicorn.run(app, host="0.0.0.0", port=9191)
|
|