import io
import os
import time
import torch
import requests
import tempfile
import numpy as np
import soundfile as sf
from fastapi import FastAPI, HTTPException
from transformers import AutoModel
from pydantic import BaseModel
from typing import Optional, Dict
from starlette.responses import StreamingResponse
from fastapi.responses import RedirectResponse
# Check whether the flash-attn package is installed (import-only capability check)
try:
    import flash_attn  # noqa: F401
    FLASH_ATTENTION_AVAILABLE = True
except ImportError:
    FLASH_ATTENTION_AVAILABLE = False
    print("Flash Attention not available. Install with 'pip install flash-attn' for better performance.")
# Initialize the FastAPI app
app = FastAPI(title="IndicF5 Text-to-Speech API", description="High-quality TTS for Indian languages with Kannada output")
# Load the TTS model globally with optimizations
repo_id = "ai4bharat/IndicF5"
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()  # Set model to evaluation mode
if torch.cuda.is_available():
    torch.cuda.synchronize()
print("Device:", device)
# Precompile the model if possible (PyTorch 2.0+); remote-code models are not always
# compile-compatible, so fall back to the eager model on failure
if hasattr(torch, "compile"):
    try:
        model = torch.compile(model)
    except Exception as e:
        print(f"torch.compile failed, using eager model: {e}")
# Example reference data
EXAMPLES = [
    {
        "audio_name": "KAN_F (Happy)",
        "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
        "ref_text": "ನಮ್ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
        "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
    },
]
# Pydantic request models
class SynthesizeRequest(BaseModel):
    text: str
    ref_audio_name: str
    ref_text: Optional[str] = None

class KannadaSynthesizeRequest(BaseModel):
    text: str
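
# Illustrative request body for the Kannada endpoint (field name matches the model above):
#   {"text": "ನಮಸ್ಕಾರ"}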
# Cache for reference audio, keyed by URL, so repeated requests skip the download
audio_cache = {}

def load_audio_from_url(url: str) -> tuple:
    if url in audio_cache:
        return audio_cache[url]
    response = requests.get(url, timeout=10)
    if response.status_code == 200:
        audio_data, sample_rate = sf.read(io.BytesIO(response.content))
        audio_cache[url] = (sample_rate, audio_data)
        return sample_rate, audio_data
    raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
def synthesize_speech(text: str, ref_audio_name: str, ref_text: str) -> tuple[io.BytesIO, Dict[str, float]]:
    timing = {}
    start_total = time.time()

    # Find the matching reference example
    ref_audio_url = next((ex["audio_url"] for ex in EXAMPLES if ex["audio_name"] == ref_audio_name), None)
    if not ref_audio_url:
        raise HTTPException(status_code=400, detail="Invalid reference audio name.")
    if not text.strip() or not ref_text or not ref_text.strip():
        raise HTTPException(status_code=400, detail="Text fields cannot be empty.")

    # Load reference audio (cached after the first request)
    start_audio_load = time.time()
    sample_rate, audio_data = load_audio_from_url(ref_audio_url)
    timing["audio_load"] = time.time() - start_audio_load

    # Write the reference audio to a temp file for the model; cleaned up after inference
    start_temp = time.time()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
        sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
        temp_audio.flush()
    timing["temp_file"] = time.time() - start_temp

    try:
        # Inference; attention_impl="flash" is a placeholder whose support depends on the
        # model internals, so fall back to the default call if the kwarg is rejected
        start_inference = time.time()
        with torch.no_grad():
            if FLASH_ATTENTION_AVAILABLE and torch.cuda.is_available():
                try:
                    audio = model(text, ref_audio_path=temp_audio.name, ref_text=ref_text, attention_impl="flash")
                except TypeError:
                    print("Warning: Model does not support custom attention_impl. Using default.")
                    audio = model(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
            else:
                audio = model(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
        timing["inference"] = time.time() - start_inference
    finally:
        os.unlink(temp_audio.name)

    # Normalize int16 output to float32 in [-1, 1]
    start_normalize = time.time()
    if audio.dtype == np.int16:
        audio = audio.astype(np.float32) / 32768.0
    timing["normalize"] = time.time() - start_normalize

    # Write the waveform to an in-memory WAV buffer (IndicF5 outputs 24 kHz audio)
    start_buffer = time.time()
    buffer = io.BytesIO()
    sf.write(buffer, audio, 24000, format='WAV')
    buffer.seek(0)
    timing["buffer"] = time.time() - start_buffer

    timing["total"] = time.time() - start_total
    return buffer, timing
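
# Direct-call sketch (illustrative; bypasses the HTTP layer entirely):
#   buf, timing = synthesize_speech("ನಮಸ್ಕಾರ", "KAN_F (Happy)", EXAMPLES[0]["ref_text"])
#   with open("out.wav", "wb") as f:
#       f.write(buf.read())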
# Route path is assumed here; adjust it to match your deployment if it differs
@app.post("/synthesize/kannada")
async def synthesize_kannada(request: KannadaSynthesizeRequest):
    kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
    audio_buffer, timing = synthesize_speech(
        text=request.text,
        ref_audio_name="KAN_F (Happy)",
        ref_text=kannada_example["ref_text"]
    )
    print(f"Synthesis completed in {timing['total']:.2f} seconds: {timing}")
    return StreamingResponse(
        audio_buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
    )
@app.get("/")
async def home():
    return RedirectResponse(url="/docs")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
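
# Example client usage (a minimal sketch; assumes the server is running locally on
# port 7860 and the /synthesize/kannada route defined above):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/synthesize/kannada",
#       json={"text": "ನಮಸ್ಕಾರ"},
#       timeout=120,
#   )
#   resp.raise_for_status()
#   with open("synthesized_kannada_speech.wav", "wb") as f:
#       f.write(resp.content)
#
# Or with curl:
#   curl -X POST http://localhost:7860/synthesize/kannada \
#        -H "Content-Type: application/json" \
#        -d '{"text": "ನಮಸ್ಕಾರ"}' -o out.wav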