import os
import logging
import uvicorn
import json
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available
from utils.logger_config import setup_logging
from utils.device import get_device, get_torch_and_np_dtypes, cuda_version_check
from utils.turn_server import get_rtc_credentials

load_dotenv()
setup_logging(level=logging.DEBUG)
logger = logging.getLogger(__name__)
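
# APP_MODE controls WebRTC connectivity: "deployed" (the default) fetches
# Hugging Face TURN credentials for NAT traversal; any other value skips them
# for local development (see rtc_configuration below).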
APP_MODE = os.getenv("APP_MODE", "deployed")
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
UI_FILE = os.getenv("UI_FILE", "index.html")

device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")

cuda_version, device_name = cuda_version_check()
logger.info(f"CUDA Version: {cuda_version}, GPU Device: {device_name}")

attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")

logger.info(f"Loading Whisper model: {MODEL_ID}")
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation=attention,
)
model.to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID)

transcribe_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

# Warm up the model with silent audio so the first real request
# doesn't pay the one-time initialization cost.
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.zeros((16000,), dtype=np_dtype)  # 1 s of silence at 16 kHz
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")
async def transcribe(audio: tuple[int, np.ndarray]):
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")

    outputs = transcribe_pipeline(
        audio_to_bytes(audio),
        chunk_length_s=3,
        batch_size=1,
        generate_kwargs={
            "task": "transcribe",
            "language": "english",
        },
        # return_timestamps="word"
    )
    yield AdditionalOutputs(outputs["text"].strip())
logger.info("Initializing FastRTC stream")
stream = Stream(
handler=ReplyOnPause(
transcribe,
algo_options=AlgoOptions(
# Duration in seconds of audio chunks (default 0.6)
audio_chunk_duration=0.6,
# If the chunk has more than started_talking_threshold seconds of speech, the user started talking (default 0.2)
started_talking_threshold=0.2,
# If, after the user started speaking, there is a chunk with less than speech_threshold seconds of speech, the user stopped speaking. (default 0.1)
speech_threshold=0.1,
),
model_options=SileroVadOptions(
# Threshold for what is considered speech (default 0.5)
threshold=0.5,
# Final speech chunks shorter min_speech_duration_ms are thrown out (default 250)
min_speech_duration_ms=250,
# Max duration of speech chunks, longer will be split (default float('inf'))
max_speech_duration_s=3,
# Wait for ms at the end of each speech chunk before separating it (default 2000)
min_silence_duration_ms=2000,
# Chunk size for VAD model. Can be 512, 1024, 1536 for 16k s.r. (default 1024)
window_size_samples=1024,
# Final speech chunks are padded by speech_pad_ms each side (default 400)
speech_pad_ms=400,
),
),
# send-receive: bidirectional streaming (default)
# send: client to server only
# receive: server to client only
modality="audio",
mode="send",
additional_outputs=[
gr.Textbox(label="Transcript"),
],
additional_outputs_handler=lambda current, new: current + " " + new,
rtc_configuration=get_rtc_credentials(provider="hf") if APP_MODE == "deployed" else None,
concurrency_limit=6
)
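
# Mounting the stream registers FastRTC's WebRTC signaling routes (the
# offer/answer exchange) on the FastAPI app alongside the custom routes below.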
app = FastAPI()
stream.mount(app)
@app.get("/")
async def index():
html_content = open(UI_FILE).read()
rtc_config = get_rtc_credentials(provider="hf") if APP_MODE == "deployed" else None
return HTMLResponse(content=html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config)))
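
# Server-Sent Events endpoint: the browser subscribes with its webrtc_id and
# receives each transcript segment as an "output" event as it is produced.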
@app.get("/transcript")
def _(webrtc_id: str):
logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")
async def output_stream():
try:
async for output in stream.output_stream(webrtc_id):
transcript = output.args[0]
logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
yield f"event: output\ndata: {transcript}\n\n"
except Exception as e:
logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
raise
return StreamingResponse(output_stream(), media_type="text/event-stream") |
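
# Entry point for running the server directly; a minimal sketch assuming the
# Hugging Face Spaces default port 7860 -- adjust host/port for your setup.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)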