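"""Real-time speech-to-text Space.

FastRTC streams microphone audio over WebRTC to a Whisper ASR pipeline
(openai/whisper-large-v3-turbo by default) and pushes transcripts back to
the browser via server-sent events.
"""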
import os
import logging
import json

import uvicorn
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available

from utils.logger_config import setup_logging
from utils.device import get_device, get_torch_and_np_dtypes, cuda_version_check
from utils.turn_server import get_rtc_credentials
load_dotenv()
setup_logging(level=logging.DEBUG)
logger = logging.getLogger(__name__)

APP_MODE = os.getenv("APP_MODE", "deployed")
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
UI_FILE = os.getenv("UI_FILE", "index.html")

device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")

cuda_version, device_name = cuda_version_check()
logger.info(f"CUDA Version: {cuda_version}, GPU Device: {device_name}")

attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")
logger.info(f"Loading Whisper model: {MODEL_ID}") | |
model = AutoModelForSpeechSeq2Seq.from_pretrained( | |
MODEL_ID, | |
torch_dtype=torch_dtype, | |
low_cpu_mem_usage=True, | |
use_safetensors=True, | |
attn_implementation=attention | |
) | |
model.to(device) | |
processor = AutoProcessor.from_pretrained(MODEL_ID) | |
transcribe_pipeline = pipeline( | |
task="automatic-speech-recognition", | |
model=model, | |
tokenizer=processor.tokenizer, | |
feature_extractor=processor.feature_extractor, | |
torch_dtype=torch_dtype, | |
device=device, | |
) | |
# Warm up the model on silent audio so the first real request isn't slowed by lazy initialization
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.zeros((16000,), dtype=np_dtype)  # 1s of silence at 16 kHz
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")
async def transcribe(audio: tuple[int, np.ndarray]):
    """Transcribe one (sample_rate, samples) chunk and yield the text as an additional output."""
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")

    outputs = transcribe_pipeline(
        audio_to_bytes(audio),
        chunk_length_s=3,
        batch_size=1,
        generate_kwargs={
            "task": "transcribe",
            "language": "english",
        },
        # return_timestamps="word"
    )
    yield AdditionalOutputs(outputs["text"].strip())
logger.info("Initializing FastRTC stream") | |
stream = Stream( | |
handler=ReplyOnPause( | |
transcribe, | |
algo_options=AlgoOptions( | |
# Duration in seconds of audio chunks (default 0.6) | |
audio_chunk_duration=0.6, | |
# If the chunk has more than started_talking_threshold seconds of speech, the user started talking (default 0.2) | |
started_talking_threshold=0.2, | |
# If, after the user started speaking, there is a chunk with less than speech_threshold seconds of speech, the user stopped speaking. (default 0.1) | |
speech_threshold=0.1, | |
), | |
model_options=SileroVadOptions( | |
# Threshold for what is considered speech (default 0.5) | |
threshold=0.5, | |
# Final speech chunks shorter min_speech_duration_ms are thrown out (default 250) | |
min_speech_duration_ms=250, | |
# Max duration of speech chunks, longer will be split (default float('inf')) | |
max_speech_duration_s=3, | |
# Wait for ms at the end of each speech chunk before separating it (default 2000) | |
min_silence_duration_ms=2000, | |
# Chunk size for VAD model. Can be 512, 1024, 1536 for 16k s.r. (default 1024) | |
window_size_samples=1024, | |
# Final speech chunks are padded by speech_pad_ms each side (default 400) | |
speech_pad_ms=400, | |
), | |
), | |
# send-receive: bidirectional streaming (default) | |
# send: client to server only | |
# receive: server to client only | |
modality="audio", | |
mode="send", | |
additional_outputs=[ | |
gr.Textbox(label="Transcript"), | |
], | |
additional_outputs_handler=lambda current, new: current + " " + new, | |
rtc_configuration=get_rtc_credentials(provider="hf") if APP_MODE == "deployed" else None, | |
concurrency_limit=6 | |
) | |
app = FastAPI()
stream.mount(app)


@app.get("/")
async def index():
    # Serve the UI with the RTC configuration injected into the page
    with open(UI_FILE) as f:
        html_content = f.read()
    rtc_config = get_rtc_credentials(provider="hf") if APP_MODE == "deployed" else None
    return HTMLResponse(content=html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config)))
# SSE endpoint the UI subscribes to for transcript updates (route path assumed;
# the decorator was missing from the extracted source)
@app.get("/transcript")
def _(webrtc_id: str):
    logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")

    async def output_stream():
        try:
            async for output in stream.output_stream(webrtc_id):
                transcript = output.args[0]
                logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
                yield f"event: output\ndata: {transcript}\n\n"
        except Exception as e:
            logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
            raise

    return StreamingResponse(output_stream(), media_type="text/event-stream")
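
if __name__ == "__main__":
    # Minimal local entry point; the original file imports uvicorn but the
    # runner block was lost in extraction. This is a sketch: the SERVER_NAME
    # and PORT env var names are assumptions, with 7860 as the conventional
    # Hugging Face Spaces port.
    host = os.getenv("SERVER_NAME", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host=host, port=port)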