import os
import logging
import uvicorn
import json
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available
from utils.logger_config import setup_logging
from utils.device import get_device, get_torch_and_np_dtypes, cuda_version_check
from utils.turn_server import get_rtc_credentials
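
# Load environment variables from .env and configure logging before any model setup.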
load_dotenv()
setup_logging(level=logging.DEBUG)
logger = logging.getLogger(__name__)
APP_MODE = os.getenv("APP_MODE", "deployed")
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
UI_FILE = os.getenv("UI_FILE", "index.html")
device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
cuda_version, device_name = cuda_version_check()
logger.info(f"CUDA Version: {cuda_version}, GPU Device: {device_name}")
attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")
logger.info(f"Loading Whisper model: {MODEL_ID}")
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation=attention,
)
model.to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID)
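
# Assemble the ASR pipeline from the model, tokenizer, and feature extractor on the selected device.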
transcribe_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)
# Warm up the model with one second of silence so the first real request isn't slow
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.zeros((16000,), dtype=np_dtype) # 1s of silence
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")
async def transcribe(audio: tuple[int, np.ndarray]):
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")

    outputs = transcribe_pipeline(
        audio_to_bytes(audio),
        chunk_length_s=3,
        batch_size=1,
        generate_kwargs={
            'task': 'transcribe',
            'language': 'english',
        },
        #return_timestamps="word"
    )
    yield AdditionalOutputs(outputs["text"].strip())
logger.info("Initializing FastRTC stream")
stream = Stream(
    handler=ReplyOnPause(
        transcribe,
        algo_options=AlgoOptions(
            # Duration in seconds of audio chunks (default 0.6)
            audio_chunk_duration=0.6,
            # If a chunk has more than started_talking_threshold seconds of speech, the user has started talking (default 0.2)
            started_talking_threshold=0.2,
            # If, after the user started speaking, a chunk has less than speech_threshold seconds of speech, the user has stopped speaking (default 0.1)
            speech_threshold=0.1,
        ),
        model_options=SileroVadOptions(
            # Threshold for what is considered speech (default 0.5)
            threshold=0.5,
            # Final speech chunks shorter than min_speech_duration_ms are thrown out (default 250)
            min_speech_duration_ms=250,
            # Max duration of speech chunks; longer chunks are split (default float('inf'))
            max_speech_duration_s=3,
            # Wait this many ms at the end of each speech chunk before separating it (default 2000)
            min_silence_duration_ms=2000,
            # Chunk size for the VAD model; can be 512, 1024, or 1536 for a 16 kHz sample rate (default 1024)
            window_size_samples=1024,
            # Final speech chunks are padded by speech_pad_ms on each side (default 400)
            speech_pad_ms=400,
        ),
    ),
    modality="audio",
    # send-receive: bidirectional streaming (default)
    # send: client to server only
    # receive: server to client only
    mode="send",
    additional_outputs=[
        gr.Textbox(label="Transcript"),
    ],
    additional_outputs_handler=lambda current, new: current + " " + new,
    rtc_configuration=get_rtc_credentials(provider="hf") if APP_MODE == "deployed" else None,
    concurrency_limit=6,
)
app = FastAPI()
stream.mount(app)
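
# Serve the single-page UI, replacing the __RTC_CONFIGURATION__ placeholder with real
# TURN/STUN credentials when running in deployed mode.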
@app.get("/")
async def index():
    with open(UI_FILE) as f:
        html_content = f.read()
    rtc_config = get_rtc_credentials(provider="hf") if APP_MODE == "deployed" else None
    return HTMLResponse(content=html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config)))
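
# Server-Sent Events endpoint: streams each new transcript chunk for the given webrtc_id to the browser.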
@app.get("/transcript")
def _(webrtc_id: str):
logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")
async def output_stream():
try:
async for output in stream.output_stream(webrtc_id):
transcript = output.args[0]
logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
yield f"event: output\ndata: {transcript}\n\n"
except Exception as e:
logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
raise
return StreamingResponse(output_stream(), media_type="text/event-stream")
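
# uvicorn is imported above but no entry point appears in this snippet. A minimal
# sketch for running the app directly; the host and port (7860, the usual Hugging
# Face Spaces default) are assumptions, not taken from the source.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))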