File size: 6,299 Bytes
4b03bb8
56c2e15
4b03bb8
 
 
 
 
 
 
 
56c2e15
4b03bb8
56c2e15
4b03bb8
457fdad
 
 
 
 
 
 
 
4b03bb8
 
 
56c2e15
4b03bb8
 
 
 
457fdad
56c2e15
457fdad
56c2e15
4b03bb8
457fdad
56c2e15
 
 
 
4b03bb8
 
 
 
 
 
 
 
 
56c2e15
4b03bb8
56c2e15
 
 
4b03bb8
d638e5c
56c2e15
 
457fdad
56c2e15
d638e5c
56c2e15
 
 
 
 
 
4b03bb8
 
56c2e15
4b03bb8
 
 
56c2e15
 
 
 
 
 
4b03bb8
 
 
56c2e15
 
4b03bb8
56c2e15
 
4b03bb8
56c2e15
4b03bb8
56c2e15
 
4b03bb8
 
 
 
457fdad
56c2e15
 
457fdad
 
 
 
 
 
 
 
 
 
56c2e15
 
 
4b03bb8
56c2e15
 
4b03bb8
 
56c2e15
4b03bb8
56c2e15
 
4b03bb8
 
 
56c2e15
 
 
 
4b03bb8
d638e5c
 
56c2e15
d638e5c
 
 
 
 
56c2e15
d638e5c
 
 
 
 
56c2e15
 
d638e5c
 
 
 
 
4b03bb8
 
56c2e15
 
 
4b03bb8
 
2994754
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import io
import os
import tempfile
import time
from typing import Optional, Dict

import numpy as np
import requests
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from transformers import AutoModel

# Probe for the optional flash-attn package; the flag is consulted at inference time.
try:
    from flash_attn import flash_attention
except ImportError:
    FLASH_ATTENTION_AVAILABLE = False
    print("Flash Attention not available. Install with 'pip install flash-attn' for better performance.")
else:
    FLASH_ATTENTION_AVAILABLE = True

# Create the FastAPI application object that the route decorators below attach to.
app = FastAPI(
    title="IndicF5 Text-to-Speech API",
    description="High-quality TTS for Indian languages with Kannada output",
)

# Load the TTS model once at import time and place it on the best available device.
repo_id = "ai4bharat/IndicF5"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).to(device)
model.eval()  # inference only: switch off training-mode behaviour
if device.type == "cuda":
    # Wait for outstanding GPU work (the weight transfer) before reporting readiness.
    torch.cuda.synchronize()
print("Device:", device)

# torch.compile only exists on PyTorch 2.0+, so wrap the model only when it is present.
if hasattr(torch, "compile"):
    model = torch.compile(model)

# Example Data
# Reference voice prompts available to the service. Each entry is a dict with:
#   audio_name - lookup key compared against the requested reference audio name
#   audio_url  - URL the reference audio file is downloaded from
#   ref_text   - text handed to the model as `ref_text` together with the reference audio
#                (presumably the transcript of that clip - confirm)
#   synth_text - sample sentence; not read anywhere in the visible code
EXAMPLES = [
    {
        "audio_name": "KAN_F (Happy)",
        "audio_url": "https://github.com/AI4Bharat/IndicF5/raw/refs/heads/main/prompts/KAN_F_HAPPY_00001.wav",
        "ref_text": "ನಮ್‌ ಫ್ರಿಜ್ಜಲ್ಲಿ  ಕೂಲಿಂಗ್‌ ಸಮಸ್ಯೆ ಆಗಿ ನಾನ್‌ ಭಾಳ ದಿನದಿಂದ ಒದ್ದಾಡ್ತಿದ್ದೆ, ಆದ್ರೆ ಅದ್ನೀಗ ಮೆಕಾನಿಕ್ ಆಗಿರೋ ನಿಮ್‌ ಸಹಾಯ್ದಿಂದ ಬಗೆಹರಿಸ್ಕೋಬೋದು ಅಂತಾಗಿ ನಿರಾಳ ಆಯ್ತು ನಂಗೆ.",
        "synth_text": "ಚೆನ್ನೈನ ಶೇರ್ ಆಟೋ ಪ್ರಯಾಣಿಕರ ನಡುವೆ ಆಹಾರವನ್ನು ಹಂಚಿಕೊಂಡು ತಿನ್ನುವುದು ನನಗೆ ಮನಸ್ಸಿಗೆ ತುಂಬಾ ಒಳ್ಳೆಯದೆನಿಸುವ ವಿಷಯ."
    },
]

# Pydantic models
class SynthesizeRequest(BaseModel):
    """Request body naming both the text and the reference voice to use.

    NOTE(review): no endpoint in the visible code accepts this model.
    """
    text: str  # presumably the text to synthesize - confirm against the consuming endpoint
    ref_audio_name: str  # presumably one of the EXAMPLES "audio_name" values - confirm
    ref_text: Optional[str] = None  # optional; defaults to None when the client omits it

class KannadaSynthesizeRequest(BaseModel):
    """Request body for the POST /audio/speech endpoint."""
    text: str  # text to synthesize; the endpoint rejects blank values with HTTP 400

# Cache for reference audio
# Maps a reference-audio URL to its decoded (sample_rate, audio_data) pair so that
# repeat requests skip the download. Entries are never evicted.
audio_cache: Dict[str, tuple] = {}

def load_audio_from_url(url: str) -> tuple:
    """Download and decode a reference audio file, caching the result per URL.

    Args:
        url: Location of an audio file that ``soundfile`` can decode.

    Returns:
        A ``(sample_rate, audio_data)`` tuple. Note the order is the reverse of
        what ``sf.read`` itself returns.

    Raises:
        HTTPException: status 500 when the request fails (connection error,
            timeout), the server replies with a non-200 status, or the payload
            cannot be decoded as audio.
    """
    if url in audio_cache:
        return audio_cache[url]

    failure_detail = "Failed to load reference audio from URL."

    # Network failures would otherwise escape as raw requests exceptions;
    # report them the same way as a bad HTTP status.
    try:
        response = requests.get(url, timeout=10)
    except requests.RequestException as exc:
        raise HTTPException(status_code=500, detail=failure_detail) from exc

    if response.status_code != 200:
        raise HTTPException(status_code=500, detail=failure_detail)

    # soundfile signals undecodable data with RuntimeError (or a subclass of it).
    try:
        audio_data, sample_rate = sf.read(io.BytesIO(response.content))
    except RuntimeError as exc:
        raise HTTPException(status_code=500, detail=failure_detail) from exc

    audio_cache[url] = (sample_rate, audio_data)
    return sample_rate, audio_data

def _run_inference(text: str, ref_audio_path: str, ref_text: str):
    """Call the global model once, preferring flash attention when it is usable."""
    with torch.no_grad():
        if FLASH_ATTENTION_AVAILABLE and torch.cuda.is_available():
            # Placeholder hook: whether the model accepts `attention_impl` depends
            # on its internals. An unsupported keyword argument surfaces as
            # TypeError in Python, so both error types trigger the fallback.
            try:
                return model(text, ref_audio_path=ref_audio_path, ref_text=ref_text, attention_impl="flash")
            except (AttributeError, TypeError):
                print("Warning: Model does not support custom attention_impl. Using default.")
        return model(text, ref_audio_path=ref_audio_path, ref_text=ref_text)


def synthesize_speech(text: str, ref_audio_name: str, ref_text: str) -> tuple[io.BytesIO, Dict[str, float]]:
    """Synthesize ``text`` in the voice of a named reference example.

    Args:
        text: Text to synthesize; must contain non-whitespace characters.
        ref_audio_name: ``audio_name`` of an entry in ``EXAMPLES``.
        ref_text: Text passed to the model together with the reference audio;
            must contain non-whitespace characters.

    Returns:
        A ``(buffer, timing)`` pair. ``buffer`` is a ``BytesIO`` positioned at
        offset 0 holding the WAV output. ``timing`` maps the stage names
        ``audio_load``, ``temp_file``, ``inference``, ``normalize``, ``buffer``
        and ``total`` to elapsed seconds.

    Raises:
        HTTPException: status 400 for an unknown reference name or blank text;
            status 500 (from ``load_audio_from_url``) when the reference audio
            cannot be loaded.
    """
    timing: Dict[str, float] = {}
    start_total = time.time()

    # Find matching example
    ref_audio_url = next((ex["audio_url"] for ex in EXAMPLES if ex["audio_name"] == ref_audio_name), None)
    if not ref_audio_url:
        raise HTTPException(status_code=400, detail="Invalid reference audio name.")

    if not text.strip() or (not ref_text or not ref_text.strip()):
        raise HTTPException(status_code=400, detail="Text fields cannot be empty.")

    # Load reference audio
    start_audio_load = time.time()
    sample_rate, audio_data = load_audio_from_url(ref_audio_url)
    timing["audio_load"] = time.time() - start_audio_load

    # The model is given the reference audio as a file path, so write it to disk.
    # A temporary directory is removed on exit even when inference raises, which
    # prevents one leaked .wav file per request.
    with tempfile.TemporaryDirectory() as temp_dir:
        start_temp = time.time()
        ref_audio_path = os.path.join(temp_dir, "reference.wav")
        sf.write(ref_audio_path, audio_data, samplerate=sample_rate, format='WAV')
        # Stop this timer before inference so it measures only the file write.
        timing["temp_file"] = time.time() - start_temp

        start_inference = time.time()
        audio = _run_inference(text, ref_audio_path, ref_text)
        timing["inference"] = time.time() - start_inference

    # Normalize audio: scale 16-bit PCM into float32 in [-1.0, 1.0)
    start_normalize = time.time()
    if audio.dtype == np.int16:
        audio = audio.astype(np.float32) / 32768.0
    timing["normalize"] = time.time() - start_normalize

    # Save to buffer
    start_buffer = time.time()
    buffer = io.BytesIO()
    # NOTE(review): 24000 Hz is hard-coded; presumably the model's output rate - confirm.
    sf.write(buffer, audio, 24000, format='WAV')
    buffer.seek(0)  # rewind so the caller reads from the start
    timing["buffer"] = time.time() - start_buffer

    timing["total"] = time.time() - start_total
    return buffer, timing

@app.post("/audio/speech", response_class=StreamingResponse)
async def synthesize_kannada(request: KannadaSynthesizeRequest):
    """Synthesize the request text using the built-in "KAN_F (Happy)" reference voice.

    Args:
        request: Body containing the ``text`` to synthesize.

    Returns:
        A ``StreamingResponse`` with ``audio/wav`` content, served as an
        attachment named ``synthesized_kannada_speech.wav``.

    Raises:
        HTTPException: status 400 when the text is blank; status 500 when the
            Kannada reference example is missing from ``EXAMPLES``; plus
            whatever ``synthesize_speech`` raises.
    """
    # Validate the cheap input check first.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text to synthesize cannot be empty.")

    # A bare next() would raise StopIteration on a missing entry, which a
    # coroutine turns into an opaque RuntimeError; report it explicitly instead.
    kannada_example = next((ex for ex in EXAMPLES if ex["audio_name"] == "KAN_F (Happy)"), None)
    if kannada_example is None:
        raise HTTPException(status_code=500, detail="Kannada reference example is not configured.")

    # NOTE(review): synthesize_speech is a synchronous call made inside an async
    # handler, so the event loop is blocked for the duration of inference.
    audio_buffer, timing = synthesize_speech(
        text=request.text,
        ref_audio_name=kannada_example["audio_name"],
        ref_text=kannada_example["ref_text"],
    )

    print(f"Synthesis completed in {timing['total']:.2f} seconds: {timing}")

    return StreamingResponse(
        audio_buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=synthesized_kannada_speech.wav"},
    )

@app.get("/")
async def home():
    """Redirect requests for the root path to the interactive docs at /docs."""
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect

if __name__ == "__main__":
    # uvicorn is only needed when the module is executed directly as a script.
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(
        app,
        port=7860,
        host="0.0.0.0",
    )