# tts_api.py (tts-indic-f5): IndicF5 Text-to-Speech API server
import io
import os
import time
import torch
import requests
import tempfile
import numpy as np
import soundfile as sf
from fastapi import FastAPI, HTTPException
from transformers import AutoModel
from pydantic import BaseModel
from typing import Optional, Dict
from starlette.responses import StreamingResponse
from fastapi.responses import RedirectResponse
# Check if flash-attn is available (CUDA-only package; imported purely as a capability probe)
try:
    import flash_attn  # noqa: F401
    FLASH_ATTENTION_AVAILABLE = True
except ImportError:
    FLASH_ATTENTION_AVAILABLE = False
    print("Flash Attention not available. Install with 'pip install flash-attn' for better performance.")
# Initialize FastAPI app
app = FastAPI(title="IndicF5 Text-to-Speech API", description="High-quality TTS for Indian languages with Kannada output")
# Load TTS model globally with optimizations
repo_id = "ai4bharat/IndicF5"
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval() # Set model to evaluation mode
if torch.cuda.is_available():
    torch.cuda.synchronize()
print("Device:", device)
# Precompile model if possible (PyTorch 2.0+)
if hasattr(torch, "compile"):
    model = torch.compile(model)
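# Note: torch.compile is lazy; the first forward pass triggers graph capture and
# compilation, so the first synthesis request after startup will be noticeably
# slower than later ones.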
# Example Data
EXAMPLES = [
    {
        "audio_name": "KAN_F (Happy)",
        "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
        "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
        "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
    },
]
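# Additional reference voices can be added by appending entries with the same
# keys ("audio_name", "audio_url", "ref_text", "synth_text") to EXAMPLES.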
# Pydantic models
class SynthesizeRequest(BaseModel):
    text: str
    ref_audio_name: str
    ref_text: Optional[str] = None

class KannadaSynthesizeRequest(BaseModel):
    text: str
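# Example JSON bodies (illustrative values):
#   SynthesizeRequest        -> {"text": "...", "ref_audio_name": "KAN_F (Happy)", "ref_text": null}
#   KannadaSynthesizeRequest -> {"text": "<Kannada text to synthesize>"}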
# Cache for reference audio
audio_cache = {}
def load_audio_from_url(url: str) -> tuple:
    # Serve from the per-process cache when possible; keyed by URL
    if url in audio_cache:
        return audio_cache[url]
    response = requests.get(url, timeout=10)
    if response.status_code == 200:
        audio_data, sample_rate = sf.read(io.BytesIO(response.content))
        audio_cache[url] = (sample_rate, audio_data)
        return sample_rate, audio_data
    raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")
def synthesize_speech(text: str, ref_audio_name: str, ref_text: str) -> tuple[io.BytesIO, Dict[str, float]]:
    timing = {}
    start_total = time.time()

    # Find the matching reference example
    ref_audio_url = next((ex["audio_url"] for ex in EXAMPLES if ex["audio_name"] == ref_audio_name), None)
    if not ref_audio_url:
        raise HTTPException(status_code=400, detail="Invalid reference audio name.")
    if not text.strip() or not ref_text or not ref_text.strip():
        raise HTTPException(status_code=400, detail="Text fields cannot be empty.")

    # Load (and cache) the reference audio
    start_audio_load = time.time()
    sample_rate, audio_data = load_audio_from_url(ref_audio_url)
    timing["audio_load"] = time.time() - start_audio_load

    # Write the reference audio to a temp file; the model expects a file path
    start_temp = time.time()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
        sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
        temp_audio.flush()
    timing["temp_file"] = time.time() - start_temp

    try:
        # Inference, optionally with Flash Attention
        start_inference = time.time()
        with torch.no_grad():
            if FLASH_ATTENTION_AVAILABLE and torch.cuda.is_available():
                # The attention_impl kwarg is a placeholder that depends on model
                # internals; fall back to the default call if the model rejects it.
                try:
                    audio = model(text, ref_audio_path=temp_audio.name, ref_text=ref_text, attention_impl="flash")
                except (TypeError, AttributeError):
                    print("Warning: Model does not support custom attention_impl. Using default.")
                    audio = model(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
            else:
                audio = model(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
        timing["inference"] = time.time() - start_inference
    finally:
        # The temp file was created with delete=False, so remove it explicitly
        os.unlink(temp_audio.name)

    # Normalize int16 output to float32 in [-1, 1)
    start_normalize = time.time()
    if audio.dtype == np.int16:
        audio = audio.astype(np.float32) / 32768.0
    timing["normalize"] = time.time() - start_normalize

    # Serialize to an in-memory WAV buffer (24 kHz output)
    start_buffer = time.time()
    buffer = io.BytesIO()
    sf.write(buffer, audio, 24000, format='WAV')
    buffer.seek(0)
    timing["buffer"] = time.time() - start_buffer

    timing["total"] = time.time() - start_total
    return buffer, timing
@app.post("/audio/speech", response_class=StreamingResponse)
async def synthesize_kannada(request: KannadaSynthesizeRequest):
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")
    kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
    audio_buffer, timing = synthesize_speech(
        text=request.text,
        ref_audio_name="KAN_F (Happy)",
        ref_text=kannada_example["ref_text"]
    )
    print(f"Synthesis completed in {timing['total']:.2f} seconds: {timing}")
    return StreamingResponse(
        audio_buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"}
    )
@app.get("/")
async def home():
    return RedirectResponse(url="/docs")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
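# Example client call (a sketch; assumes the server is running locally on the
# port configured above):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/audio/speech",
#       json={"text": "<Kannada text to synthesize>"},
#   )
#   with open("out.wav", "wb") as f:
#       f.write(resp.content)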