Spaces:
Paused
Paused
File size: 6,299 Bytes
4b03bb8 56c2e15 4b03bb8 56c2e15 4b03bb8 56c2e15 4b03bb8 457fdad 4b03bb8 56c2e15 4b03bb8 457fdad 56c2e15 457fdad 56c2e15 4b03bb8 457fdad 56c2e15 4b03bb8 56c2e15 4b03bb8 56c2e15 4b03bb8 d638e5c 56c2e15 457fdad 56c2e15 d638e5c 56c2e15 4b03bb8 56c2e15 4b03bb8 56c2e15 4b03bb8 56c2e15 4b03bb8 56c2e15 4b03bb8 56c2e15 4b03bb8 56c2e15 4b03bb8 457fdad 56c2e15 457fdad 56c2e15 4b03bb8 56c2e15 4b03bb8 56c2e15 4b03bb8 56c2e15 4b03bb8 56c2e15 4b03bb8 d638e5c 56c2e15 d638e5c 56c2e15 d638e5c 56c2e15 d638e5c 4b03bb8 56c2e15 4b03bb8 2994754 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import io
import os
import tempfile
import time
from typing import Dict, Optional

import numpy as np
import requests
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from transformers import AutoModel
# Check if flash-attn is available.
# The import is only an availability probe; the imported name itself is unused.
# NOTE(review): the probe asks for `flash_attn.flash_attention` specifically, so
# an installed flash-attn build that lacks that name is reported as unavailable
# -- confirm this is the intended target before relying on the flag.
try:
    from flash_attn import flash_attention  # noqa: F401
except ImportError:
    FLASH_ATTENTION_AVAILABLE = False
    print("Flash Attention not available. Install with 'pip install flash-attn' for better performance.")
else:
    FLASH_ATTENTION_AVAILABLE = True
# Initialize FastAPI app (title and description are passed to the FastAPI constructor)
app = FastAPI(
    title="IndicF5 Text-to-Speech API",
    description="High-quality TTS for Indian languages with Kannada output",
)
# Load the TTS model once at import time so every request reuses it.
# NOTE(review): trust_remote_code=True lets the hub repository supply the model
# code that runs in this process -- keep repo_id pinned to a trusted source.
repo_id = "ai4bharat/IndicF5"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model = model.to(device)
model.eval()  # Set model to evaluation mode

if torch.cuda.is_available():
    # Block until queued CUDA work has finished before reporting the device.
    torch.cuda.synchronize()
print("Device:", device)

# Precompile model if possible (torch.compile only exists on PyTorch 2.0+)
if hasattr(torch, "compile"):
    model = torch.compile(model)
# Example Data
# Each entry describes one reference voice:
#   audio_name - identifier used to select the voice
#   audio_url  - location of the reference WAV prompt
#   ref_text   - text paired with the reference prompt
#   synth_text - sample sentence to synthesize
EXAMPLES = [
    {
        "audio_name": "KAN_F (Happy)",
        "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
        "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
        "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ.",
    },
]
# Pydantic models
class SynthesizeRequest(BaseModel):
    """Request body that names the reference voice explicitly."""

    # Text to be spoken.
    text: str
    # Name of the reference voice to use.
    ref_audio_name: str
    # Text paired with the reference audio; optional, defaults to None.
    ref_text: Optional[str] = None
class KannadaSynthesizeRequest(BaseModel):
    """Request body for the Kannada endpoint: only the text to speak."""

    # Text to be spoken; the reference voice is chosen by the endpoint itself.
    text: str
# Cache for reference audio: maps a source URL to its decoded
# (sample_rate, audio_data) pair so each prompt is downloaded at most once.
audio_cache = {}


def load_audio_from_url(url: str) -> tuple:
    """Download and decode a reference audio file, caching the result by URL.

    Args:
        url: Location of the audio file to fetch.

    Returns:
        A ``(sample_rate, audio_data)`` tuple as decoded by ``soundfile.read``.

    Raises:
        HTTPException: Status 500 when the download fails (network error or
            non-200 response) or the payload cannot be decoded as audio.
    """
    cached = audio_cache.get(url)
    if cached is not None:
        return cached

    failure_detail = "Failed to load reference audio from URL."

    try:
        response = requests.get(url, timeout=10)
    except requests.RequestException as exc:
        # Timeouts and connection errors get the same response as a bad status.
        raise HTTPException(status_code=500, detail=failure_detail) from exc
    if response.status_code != 200:
        raise HTTPException(status_code=500, detail=failure_detail)

    try:
        audio_data, sample_rate = sf.read(io.BytesIO(response.content))
    except RuntimeError as exc:
        # soundfile signals undecodable input with RuntimeError (sub)classes.
        raise HTTPException(status_code=500, detail=failure_detail) from exc

    entry = (sample_rate, audio_data)
    audio_cache[url] = entry
    return entry
def _run_inference(text: str, ref_audio_path: str, ref_text: str):
    """Call the model once, trying the flash-attention keyword first when enabled."""
    with torch.no_grad():
        if FLASH_ATTENTION_AVAILABLE and torch.cuda.is_available():
            # NOTE(review): `attention_impl` is a placeholder keyword; whether
            # the model accepts it depends on the model internals.
            try:
                return model(text, ref_audio_path=ref_audio_path, ref_text=ref_text, attention_impl="flash")
            except (AttributeError, TypeError):
                # An unsupported keyword argument surfaces as TypeError, so it
                # must be caught here for the fallback below to be reachable.
                print("Warning: Model does not support custom attention_impl. Using default.")
        return model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)


def synthesize_speech(text: str, ref_audio_name: str, ref_text: str) -> tuple[io.BytesIO, Dict[str, float]]:
    """Synthesize `text` in the voice of a named reference example.

    Args:
        text: Text to synthesize.
        ref_audio_name: ``audio_name`` of an entry in ``EXAMPLES``.
        ref_text: Text paired with the reference audio.

    Returns:
        A ``(buffer, timing)`` pair: ``buffer`` is a WAV file in memory,
        rewound to the start; ``timing`` maps stage names (``audio_load``,
        ``temp_file``, ``inference``, ``normalize``, ``buffer``, ``total``)
        to elapsed seconds.

    Raises:
        HTTPException: Status 400 when the reference name is unknown or when
            `text` or `ref_text` is empty or whitespace-only.
    """
    timing: Dict[str, float] = {}
    start_total = time.time()

    # Find matching example
    ref_audio_url = next(
        (ex["audio_url"] for ex in EXAMPLES if ex["audio_name"] == ref_audio_name),
        None,
    )
    if not ref_audio_url:
        raise HTTPException(status_code=400, detail="Invalid reference audio name.")
    if not text.strip() or not ref_text or not ref_text.strip():
        raise HTTPException(status_code=400, detail="Text fields cannot be empty.")

    # Load reference audio
    start_audio_load = time.time()
    sample_rate, audio_data = load_audio_from_url(ref_audio_url)
    timing["audio_load"] = time.time() - start_audio_load

    # Reserve a temp path for the reference audio. The handle is closed right
    # away because the file is written and read by path, not through it.
    start_temp = time.time()
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
        ref_audio_path = temp_audio.name
    try:
        sf.write(ref_audio_path, audio_data, samplerate=sample_rate, format='WAV')
        # Covers only the temp-file write; inference is timed separately.
        timing["temp_file"] = time.time() - start_temp

        start_inference = time.time()
        audio = _run_inference(text, ref_audio_path, ref_text)
        timing["inference"] = time.time() - start_inference
    finally:
        # delete=False means nothing else removes the file; clean up here so
        # requests do not accumulate temp files. Best effort: a failed removal
        # must not mask the real outcome of the request.
        try:
            os.remove(ref_audio_path)
        except OSError as exc:
            print(f"Warning: could not remove temporary file {ref_audio_path}: {exc}")

    # Normalize audio: 16-bit integer samples are rescaled to float32
    start_normalize = time.time()
    if audio.dtype == np.int16:
        audio = audio.astype(np.float32) / 32768.0
    timing["normalize"] = time.time() - start_normalize

    # Save to buffer
    # NOTE(review): 24000 Hz is hard-coded; presumably the model's output rate -- confirm.
    start_buffer = time.time()
    buffer = io.BytesIO()
    sf.write(buffer, audio, 24000, format='WAV')
    buffer.seek(0)
    timing["buffer"] = time.time() - start_buffer

    timing["total"] = time.time() - start_total
    return buffer, timing
@app.post("/audio/speech", response_class=StreamingResponse)
async def synthesize_kannada(request: KannadaSynthesizeRequest):
    """Synthesize the request text using the "KAN_F (Happy)" reference voice.

    Returns:
        A StreamingResponse carrying the WAV audio as an attachment.

    Raises:
        HTTPException: Status 400 when the text is empty or whitespace-only;
            status 500 when the reference example is missing from EXAMPLES.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")

    reference_name = "KAN_F (Happy)"
    # A default keeps a missing entry from raising StopIteration in a coroutine.
    kannada_example = next(
        (ex for ex in EXAMPLES if ex["audio_name"] == reference_name),
        None,
    )
    if kannada_example is None:
        raise HTTPException(status_code=500, detail="Kannada reference example is not configured.")

    # NOTE(review): synthesize_speech is a synchronous call made directly from
    # this coroutine, so the event loop is occupied until it returns.
    audio_buffer, timing = synthesize_speech(
        text=request.text,
        ref_audio_name=reference_name,
        ref_text=kannada_example["ref_text"],
    )
    print(f"Synthesis completed in {timing['total']:.2f} seconds: {timing}")

    return StreamingResponse(
        audio_buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"},
    )
@app.get("/")
async def home():
    """Send visitors of the root path to the /docs page."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
if __name__ == "__main__":
    # Script entry point: uvicorn is imported only when the file is run
    # directly, then serves the app on all interfaces at port 7860.
    from uvicorn import run as serve

    serve(app, port=7860, host="0.0.0.0")