"""FastAPI service exposing ai4bharat/IndicF5 text-to-speech with a Kannada endpoint."""

import io
import os
import tempfile
import time
from typing import Dict, Optional

import numpy as np
import requests
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from transformers import AutoModel

# Check if flash-attn is installed (probe the package; used only as a capability flag)
try:
    import flash_attn  # noqa: F401

    FLASH_ATTENTION_AVAILABLE = True
except ImportError:
    FLASH_ATTENTION_AVAILABLE = False
    print("Flash Attention not available. Install with 'pip install flash-attn' for better performance.")

# Initialize FastAPI app
app = FastAPI(
    title="IndicF5 Text-to-Speech API",
    description="High-quality TTS for Indian languages with Kannada output",
)

# Load TTS model globally with optimizations
repo_id = "ai4bharat/IndicF5"
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()  # Set model to evaluation mode
if torch.cuda.is_available():
    torch.cuda.synchronize()
print("Device:", device)

# Precompile model if possible (PyTorch 2.0+)
if hasattr(torch, "compile"):
    model = torch.compile(model)

# Example data: reference prompts the model can clone from
EXAMPLES = [
    {
        "audio_name": "KAN_F (Happy)",
        "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
        "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
        "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ.",
    },
]


# Pydantic models
class SynthesizeRequest(BaseModel):
    text: str
    ref_audio_name: str
    ref_text: Optional[str] = None


class KannadaSynthesizeRequest(BaseModel):
    text: str


# In-memory cache for downloaded reference audio
audio_cache: Dict[str, tuple] = {}


def load_audio_from_url(url: str) -> tuple:
    """Download reference audio, caching it so repeated requests skip the network."""
    if url in audio_cache:
        return audio_cache[url]
    response = requests.get(url, timeout=10)
    if response.status_code == 200:
        audio_data, sample_rate = sf.read(io.BytesIO(response.content))
        audio_cache[url] = (sample_rate, audio_data)
        return sample_rate, audio_data
    raise HTTPException(status_code=500, detail="Failed to load reference audio from URL.")


def synthesize_speech(text: str, ref_audio_name: str, ref_text: str) -> tuple[io.BytesIO, Dict[str, float]]:
    """Run IndicF5 on `text` with the named reference prompt; return a WAV buffer plus per-stage timings."""
    timing: Dict[str, float] = {}
    start_total = time.time()

    # Find matching example
    ref_audio_url = next((ex["audio_url"] for ex in EXAMPLES if ex["audio_name"] == ref_audio_name), None)
    if not ref_audio_url:
        raise HTTPException(status_code=400, detail="Invalid reference audio name.")
    if not text.strip() or not ref_text or not ref_text.strip():
        raise HTTPException(status_code=400, detail="Text fields cannot be empty.")

    # Load reference audio
    start_audio_load = time.time()
    sample_rate, audio_data = load_audio_from_url(ref_audio_url)
    timing["audio_load"] = time.time() - start_audio_load

    # Save reference audio to a temp file (the model expects a file path)
    start_temp = time.time()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
        sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format="WAV")
        temp_audio_path = temp_audio.name
    timing["temp_file"] = time.time() - start_temp

    try:
        # Inference, with Flash Attention when available
        start_inference = time.time()
        with torch.no_grad():
            if FLASH_ATTENTION_AVAILABLE and torch.cuda.is_available():
                # Assuming the model exposes a switchable attention implementation.
                # This is a placeholder; the actual hook depends on the model internals.
                try:
                    audio = model(text, ref_audio_path=temp_audio_path, ref_text=ref_text, attention_impl="flash")
                except (TypeError, AttributeError):
                    # An unexpected keyword argument raises TypeError; fall back to the default path.
                    print("Warning: Model does not support custom attention_impl. Using default.")
                    audio = model(text, ref_audio_path=temp_audio_path, ref_text=ref_text)
            else:
                audio = model(text, ref_audio_path=temp_audio_path, ref_text=ref_text)
        timing["inference"] = time.time() - start_inference
    finally:
        # delete=False leaves the file behind, so remove it explicitly
        os.remove(temp_audio_path)

    # Normalize audio (int16 output is rescaled to float32 in [-1, 1])
    start_normalize = time.time()
    if audio.dtype == np.int16:
        audio = audio.astype(np.float32) / 32768.0
    timing["normalize"] = time.time() - start_normalize

    # Save to an in-memory WAV buffer
    start_buffer = time.time()
    buffer = io.BytesIO()
    sf.write(buffer, audio, 24000, format="WAV")  # 24 kHz, matching the model's output rate
    buffer.seek(0)
    timing["buffer"] = time.time() - start_buffer

    timing["total"] = time.time() - start_total
    return buffer, timing


@app.post("/audio/speech", response_class=StreamingResponse)
async def synthesize_kannada(request: KannadaSynthesizeRequest):
    """Synthesize Kannada speech using the KAN_F (Happy) reference prompt."""
    kannada_example = next(ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)")
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")

    audio_buffer, timing = synthesize_speech(
        text=request.text,
        ref_audio_name="KAN_F (Happy)",
        ref_text=kannada_example["ref_text"],
    )
    print(f"Synthesis completed in {timing['total']:.2f} seconds: {timing}")

    return StreamingResponse(
        audio_buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"},
    )


@app.get("/")
async def home():
    return RedirectResponse(url="/docs")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
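
# --- Example client (a minimal sketch, not part of the server) ---
# Assumes the app above is already running locally on port 7860; run this from a
# separate process. The payload text below is the sample synth_text from EXAMPLES;
# substitute any Kannada sentence. The output filename mirrors the
# Content-Disposition header set by the endpoint.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/audio/speech",
#       json={"text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."},
#       timeout=300,  # synthesis can take a while on CPU
#   )
#   resp.raise_for_status()
#   with open("synthesized_kannada_speech.wav", "wb") as f:
#       f.write(resp.content)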