sachin committed on
Commit
457fdad
·
1 Parent(s): 56c2e15
Files changed (2) hide show
  1. requirements.txt +1 -0
  2. tts_api.py +23 -11
requirements.txt CHANGED
@@ -9,6 +9,7 @@ anyio==4.9.0
9
  async-timeout==5.0.1
10
  attrs==25.3.0
11
  audioread==3.0.1
 
12
  bitsandbytes==0.45.5
13
  boto3==1.37.29
14
  botocore==1.37.29
 
9
  async-timeout==5.0.1
10
  attrs==25.3.0
11
  audioread==3.0.1
12
+ flash-attn
13
  bitsandbytes==0.45.5
14
  boto3==1.37.29
15
  botocore==1.37.29
tts_api.py CHANGED
@@ -12,6 +12,14 @@ from typing import Optional, Dict
12
  from starlette.responses import StreamingResponse
13
  from fastapi.responses import RedirectResponse
14
 
 
 
 
 
 
 
 
 
15
  # Initialize FastAPI app
16
  app = FastAPI(title="IndicF5 Text-to-Speech API", description="High-quality TTS for Indian languages with Kannada output")
17
 
@@ -20,12 +28,12 @@ repo_id = "ai4bharat/IndicF5"
20
  model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  model = model.to(device)
23
- model.eval() # Set model to evaluation mode for inference
24
  if torch.cuda.is_available():
25
- torch.cuda.synchronize() # Ensure CUDA is ready
26
  print("Device:", device)
27
 
28
- # Precompile model if possible (for PyTorch 2.0+)
29
  if hasattr(torch, "compile"):
30
  model = torch.compile(model)
31
 
@@ -48,12 +56,7 @@ class SynthesizeRequest(BaseModel):
48
  class KannadaSynthesizeRequest(BaseModel):
49
  text: str
50
 
51
- # Response model with timing
52
- class SynthesisResponse(BaseModel):
53
- audio: bytes
54
- timing: Dict[str, float]
55
-
56
- # Cache for reference audio to avoid repeated downloads
57
  audio_cache = {}
58
 
59
  def load_audio_from_url(url: str) -> tuple:
@@ -91,10 +94,19 @@ def synthesize_speech(text: str, ref_audio_name: str, ref_text: str) -> tuple[io
91
  sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
92
  temp_audio.flush()
93
 
94
- # Inference with no_grad for optimization
95
  start_inference = time.time()
96
  with torch.no_grad():
97
- audio = model(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
 
 
 
 
 
 
 
 
 
98
  timing["inference"] = time.time() - start_inference
99
 
100
  timing["temp_file"] = time.time() - start_temp
 
12
  from starlette.responses import StreamingResponse
13
  from fastapi.responses import RedirectResponse
14
 
15
+ # Check if flash-attn is available
16
+ try:
17
+ from flash_attn import flash_attention
18
+ FLASH_ATTENTION_AVAILABLE = True
19
+ except ImportError:
20
+ FLASH_ATTENTION_AVAILABLE = False
21
+ print("Flash Attention not available. Install with 'pip install flash-attn' for better performance.")
22
+
23
  # Initialize FastAPI app
24
  app = FastAPI(title="IndicF5 Text-to-Speech API", description="High-quality TTS for Indian languages with Kannada output")
25
 
 
28
  model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  model = model.to(device)
31
+ model.eval() # Set model to evaluation mode
32
  if torch.cuda.is_available():
33
+ torch.cuda.synchronize()
34
  print("Device:", device)
35
 
36
+ # Precompile model if possible (PyTorch 2.0+)
37
  if hasattr(torch, "compile"):
38
  model = torch.compile(model)
39
 
 
56
  class KannadaSynthesizeRequest(BaseModel):
57
  text: str
58
 
59
+ # Cache for reference audio
 
 
 
 
 
60
  audio_cache = {}
61
 
62
  def load_audio_from_url(url: str) -> tuple:
 
94
  sf.write(temp_audio.name, audio_data, samplerate=sample_rate, format='WAV')
95
  temp_audio.flush()
96
 
97
+ # Inference with Flash Attention
98
  start_inference = time.time()
99
  with torch.no_grad():
100
+ if FLASH_ATTENTION_AVAILABLE and torch.cuda.is_available():
101
+ # Assuming model has an attention mechanism we can override
102
+ # This is a placeholder; actual implementation depends on model internals
103
+ try:
104
+ audio = model(text, ref_audio_path=temp_audio.name, ref_text=ref_text, attention_impl="flash")
105
+ except AttributeError:
106
+ print("Warning: Model does not support custom attention_impl. Using default.")
107
+ audio = model(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
108
+ else:
109
+ audio = model(text, ref_audio_path=temp_audio.name, ref_text=ref_text)
110
  timing["inference"] = time.time() - start_inference
111
 
112
  timing["temp_file"] = time.time() - start_temp