Spaces:
Paused
Paused
sachin
committed on
Commit
·
457fdad
1
Parent(s):
56c2e15
tet
Browse files
- requirements.txt +1 -0
- tts_api.py +23 -11
requirements.txt
CHANGED
@@ -9,6 +9,7 @@ anyio==4.9.0
|
|
9 |
async-timeout==5.0.1
|
10 |
attrs==25.3.0
|
11 |
audioread==3.0.1
|
|
|
12 |
bitsandbytes==0.45.5
|
13 |
boto3==1.37.29
|
14 |
botocore==1.37.29
|
|
|
9 |
async-timeout==5.0.1
|
10 |
attrs==25.3.0
|
11 |
audioread==3.0.1
|
12 |
+
flash-attn
|
13 |
bitsandbytes==0.45.5
|
14 |
boto3==1.37.29
|
15 |
botocore==1.37.29
|
tts_api.py
CHANGED
@@ -12,6 +12,14 @@ from typing import Optional, Dict
|
|
12 |
from starlette.responses import StreamingResponse
|
13 |
from fastapi.responses import RedirectResponse
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
# Initialize FastAPI app
|
16 |
app = FastAPI(title="IndicF5 Text-to-Speech API", description="High-quality TTS for Indian languages with Kannada output")
|
17 |
|
@@ -20,12 +28,12 @@ repo_id = "ai4bharat/IndicF5"
|
|
20 |
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
|
21 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
22 |
model = model.to(device)
|
23 |
-
model.eval() # Set model to evaluation mode
|
24 |
if torch.cuda.is_available():
|
25 |
-
torch.cuda.synchronize()
|
26 |
print("Device:", device)
|
27 |
|
28 |
-
# Precompile model if possible (
|
29 |
if hasattr(torch, "compile"):
|
30 |
model = torch.compile(model)
|
31 |
|
@@ -48,12 +56,7 @@ class SynthesizeRequest(BaseModel):
|
|
48 |
class KannadaSynthesizeRequest(BaseModel):
|
49 |
text: str
|
50 |
|
51 |
-
#
|
52 |
-
class SynthesisResponse(BaseModel):
|
53 |
-
audio: bytes
|
54 |
-
timing: Dict[str, float]
|
55 |
-
|
56 |
-
# Cache for reference audio to avoid repeated downloads
|
57 |
audio_cache = {}
|
58 |
|
59 |
def load_audio_from_url(url: str) -> tuple:
|
@@ -91,10 +94,19 @@ def synthesize_speech(text: str, ref_audio_name: str, ref_text: str) -> tuple[io
|
|
91 |
sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
|
92 |
temp_audio.flush()
|
93 |
|
94 |
-
# Inference with
|
95 |
start_inference = time.time()
|
96 |
with torch.no_grad():
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
timing["inference"] = time.time() - start_inference
|
99 |
|
100 |
timing["temp_file"] = time.time() - start_temp
|
|
|
12 |
from starlette.responses import StreamingResponse
|
13 |
from fastapi.responses import RedirectResponse
|
14 |
|
15 |
+
# Check if flash-attn is available
|
16 |
+
try:
|
17 |
+
from flash_attn import flash_attention
|
18 |
+
FLASH_ATTENTION_AVAILABLE = True
|
19 |
+
except ImportError:
|
20 |
+
FLASH_ATTENTION_AVAILABLE = False
|
21 |
+
print("Flash Attention not available. Install with 'pip install flash-attn' for better performance.")
|
22 |
+
|
23 |
# Initialize FastAPI app
|
24 |
app = FastAPI(title="IndicF5 Text-to-Speech API", description="High-quality TTS for Indian languages with Kannada output")
|
25 |
|
|
|
28 |
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
|
29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
model = model.to(device)
|
31 |
+
model.eval() # Set model to evaluation mode
|
32 |
if torch.cuda.is_available():
|
33 |
+
torch.cuda.synchronize()
|
34 |
print("Device:", device)
|
35 |
|
36 |
+
# Precompile model if possible (PyTorch 2.0+)
|
37 |
if hasattr(torch, "compile"):
|
38 |
model = torch.compile(model)
|
39 |
|
|
|
56 |
class KannadaSynthesizeRequest(BaseModel):
|
57 |
text: str
|
58 |
|
59 |
+
# Cache for reference audio
|
|
|
|
|
|
|
|
|
|
|
60 |
audio_cache = {}
|
61 |
|
62 |
def load_audio_from_url(url: str) -> tuple:
|
|
|
94 |
sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
|
95 |
temp_audio.flush()
|
96 |
|
97 |
+
# Inference with Flash Attention
|
98 |
start_inference = time.time()
|
99 |
with torch.no_grad():
|
100 |
+
if FLASH_ATTENTION_AVAILABLE and torch.cuda.is_available():
|
101 |
+
# Assuming model has an attention mechanism we can override
|
102 |
+
# This is a placeholder; actual implementation depends on model internals
|
103 |
+
try:
|
104 |
+
audio = model(text, ref_audio_path=temp_audio.name, ref_text=ref_text, attention_impl="flash")
|
105 |
+
except AttributeError:
|
106 |
+
print("Warning: Model does not support custom attention_impl. Using default.")
|
107 |
+
audio = model(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
|
108 |
+
else:
|
109 |
+
audio = model(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
|
110 |
timing["inference"] = time.time() - start_inference
|
111 |
|
112 |
timing["temp_file"] = time.time() - start_temp
|