Vijish committed on
Commit
4fbbd3b
1 Parent(s): 40b21b2

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +63 -67
handler.py CHANGED
@@ -1,79 +1,75 @@
1
- from typing import Dict, List, Any, Optional
2
- from voice_processing import tts, get_model_names, voice_mapping, get_unique_filename
 
3
  import os
4
  import base64
5
  import numpy as np
 
6
  from scipy.io import wavfile
7
  import asyncio
8
- import time
9
 
10
class EndpointHandler:
    """Asynchronous text-to-speech handler that returns the audio as a data URI.

    On construction it records the model names reported by ``get_model_names()``
    and the voice names that are keys of ``voice_mapping``.
    """

    def __init__(self):
        self.models = get_model_names()
        self.voices = list(voice_mapping.keys())

    async def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Synthesize speech for one request.

        Args:
            data: Request payload. Reads ``model_name``, ``tts_text``,
                ``selected_voice``, ``slang_rate`` (converted with ``float``),
                ``use_uploaded_voice`` and ``voice_upload_file``; missing keys
                fall back to the defaults used below.

        Returns:
            A one-element list holding a dict with ``info`` (first value
            returned by ``tts``) and ``audio_data_uri`` (base64 WAV data URI).

        Raises:
            ValueError: If ``selected_voice`` has no entry in ``voice_mapping``,
                if ``slang_rate`` cannot be converted to ``float``, or if the
                produced audio file cannot be read.
        """
        model_name = data.get("model_name")
        tts_text = data.get("tts_text", "Текстыг оруулна уу.")
        selected_voice = data.get("selected_voice")
        slang_rate = float(data.get("slang_rate", 0))
        use_uploaded_voice = data.get("use_uploaded_voice", False)
        voice_upload_file = data.get("voice_upload_file")

        edge_tts_voice = voice_mapping.get(selected_voice)
        if not edge_tts_voice:
            raise ValueError(f"Invalid voice '{selected_voice}'.")

        # The second value returned by tts() is not needed here.
        info, _, tts_output_data, edge_output_file = await tts(
            model_name, tts_text, edge_tts_voice, slang_rate, use_uploaded_voice, voice_upload_file
        )

        # The intermediate file reported by tts() is no longer needed.
        if edge_output_file and os.path.exists(edge_output_file):
            os.remove(edge_output_file)

        # NOTE(review): the first element is discarded and arrays are written at
        # the default 40000 Hz -- confirm it is not the real sample rate.
        _, audio_output = tts_output_data

        # An ndarray is written to a fresh WAV file; anything else is used as a path.
        if isinstance(audio_output, np.ndarray):
            audio_file_path = self.save_audio_data_to_file(audio_output)
        else:
            audio_file_path = audio_output

        try:
            audio_data_uri = self.get_audio_data_uri(audio_file_path)
        finally:
            # Remove the audio file even when encoding fails so it does not leak.
            if os.path.exists(audio_file_path):
                os.remove(audio_file_path)

        return [{"info": info, "audio_data_uri": audio_data_uri}]

    def save_audio_data_to_file(self, audio_data, sample_rate=40000):
        """Write ``audio_data`` to a new WAV file and return the file's path.

        The path comes from ``get_unique_filename('wav')``.
        """
        file_path = get_unique_filename('wav')
        wavfile.write(file_path, sample_rate, audio_data)
        return file_path

    def get_audio_data_uri(self, audio_file_path):
        """Return the file's bytes as a ``data:audio/wav;base64,...`` URI.

        Raises:
            ValueError: If the file cannot be opened or read; the original
                exception is attached as the cause.
        """
        try:
            with open(audio_file_path, 'rb') as file:
                audio_bytes = file.read()
            return f"data:audio/wav;base64,{base64.b64encode(audio_bytes).decode('utf-8')}"
        except Exception as e:
            raise ValueError(f"Failed to read audio file: {e}") from e
57
-
58
async def periodic_cleanup(folder_path, interval_seconds, max_age_seconds):
    """Delete stale ``.mp3`` files from ``folder_path`` in an endless loop.

    Each pass removes every directly contained file whose name ends in
    ``.mp3`` and whose modification time is more than ``max_age_seconds`` in
    the past, then sleeps for ``interval_seconds``. The coroutine never
    returns on its own; cancel the task to stop it.

    Args:
        folder_path: Directory to scan (not recursive).
        interval_seconds: Pause between two passes.
        max_age_seconds: Files older than this are deleted.
    """
    while True:
        now = time.time()
        for filename in os.listdir(folder_path):
            if not filename.endswith(".mp3"):
                continue
            file_path = os.path.join(folder_path, filename)
            try:
                file_age = now - os.path.getmtime(file_path)
                if file_age > max_age_seconds:
                    os.remove(file_path)
                    print(f"Deleted: {file_path}")
            except FileNotFoundError:
                # The file vanished between listdir() and stat/remove (for
                # example another request deleted it); that must not kill the
                # whole cleanup loop.
                continue
        await asyncio.sleep(interval_seconds)
69
-
70
# Strong references to scheduled cleanup tasks. The event loop only keeps weak
# references to tasks, so an otherwise unreferenced task may be garbage
# collected before it finishes.
_background_tasks = set()


async def start_periodic_cleanup():
    """Schedule ``periodic_cleanup`` on the running event loop.

    The cleanup scans the current working directory once every 24 hours and
    removes files older than 3 hours.

    Returns:
        The scheduled ``asyncio.Task`` so callers can await or cancel it.
        Callers that ignore the return value keep working as before.
    """
    folder_path = os.getcwd()  # Use current working directory
    interval_seconds = 60 * 60 * 24  # Daily interval (24 hours)
    max_age_seconds = 3600 * 3  # Maximum age of files (3 hours)
    task = asyncio.create_task(periodic_cleanup(folder_path, interval_seconds, max_age_seconds))
    _background_tasks.add(task)
    # Drop the reference once the task is finished or cancelled.
    task.add_done_callback(_background_tasks.discard)
    return task
75
-
76
# Module-level handler instance, created at import time.
handler = EndpointHandler()


async def _run_cleanup_forever():
    """Start the periodic cleanup and keep the event loop alive for it."""
    await start_periodic_cleanup()
    # asyncio.run() cancels all still-pending tasks as soon as its coroutine
    # returns. Returning right after scheduling would therefore cancel the
    # cleanup task before it could ever repeat, so block here until the
    # process is interrupted.
    await asyncio.Event().wait()


if __name__ == "__main__":
    asyncio.run(_run_cleanup_forever())
 
1
+ from pydantic import BaseModel
2
+ from environs import Env
3
+ from typing import List, Dict, Any
4
  import os
5
  import base64
6
  import numpy as np
7
+ import librosa
8
  from scipy.io import wavfile
9
  import asyncio
10
+ from voice_processing import tts, get_model_names, voice_mapping, get_unique_filename
11
 
12
class EndpointHandler:
    """Synchronous text-to-speech handler that returns the audio as a data URI.

    The handler accepts either the bare request dict or the same dict wrapped
    under an ``"inputs"`` key (the Hugging Face JSON format). Every failure is
    reported as ``{"error": message}`` instead of being raised to the caller.
    """

    # Keys every request must carry; "voice_upload_file" is optional.
    _REQUIRED_KEYS = ("model_name", "tts_text", "selected_voice", "slang_rate", "use_uploaded_voice")

    def __init__(self, model_dir=None):
        self.model_dir = model_dir

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch one request and convert any exception into an error dict.

        Args:
            data: Request payload, optionally wrapped as ``{"inputs": {...}}``.

        Returns:
            ``{"info": ..., "audio_data_uri": ...}`` on success, otherwise
            ``{"error": <message>}``.
        """
        try:
            if not isinstance(data, dict):
                # Fail with the same message as any other malformed payload
                # rather than a TypeError from the membership test below.
                raise ValueError("Invalid JSON structure.")
            if "inputs" in data:  # Check if data is in Hugging Face JSON format
                return self.process_hf_input(data)
            return self.process_json_input(data)
        except Exception as e:
            # Endpoint boundary: report every failure to the client as JSON.
            return {"error": str(e)}

    def process_json_input(self, json_data):
        """Run TTS for an unwrapped request dict.

        Args:
            json_data: Dict containing ``model_name``, ``tts_text``,
                ``selected_voice``, ``slang_rate`` and ``use_uploaded_voice``;
                ``voice_upload_file`` is optional and defaults to ``None``.

        Returns:
            Dict with ``info`` (first value returned by ``tts``) and
            ``audio_data_uri`` (base64-encoded WAV data URI).

        Raises:
            ValueError: If ``json_data`` is not a dict, lacks a required key,
                or names a voice that is not in ``voice_mapping``.
            Exception: If the generated audio file cannot be read.
        """
        # A plain ``key in json_data`` test would do substring matching on a
        # string and raise TypeError on numbers, so require a real dict first.
        if not isinstance(json_data, dict) or any(
            key not in json_data for key in self._REQUIRED_KEYS
        ):
            raise ValueError("Invalid JSON structure.")

        model_name = json_data["model_name"]
        tts_text = json_data["tts_text"]
        selected_voice = json_data["selected_voice"]
        slang_rate = json_data["slang_rate"]
        use_uploaded_voice = json_data["use_uploaded_voice"]
        voice_upload_file = json_data.get("voice_upload_file", None)

        edge_tts_voice = voice_mapping.get(selected_voice)
        if not edge_tts_voice:
            raise ValueError(f"Invalid voice '{selected_voice}'.")

        # tts() is a coroutine function; drive it to completion from this
        # synchronous method. The second returned value is not needed here.
        info, _, tts_output_data, edge_output_file = asyncio.run(tts(
            model_name, tts_text, edge_tts_voice, slang_rate, use_uploaded_voice, voice_upload_file
        ))

        # The intermediate file reported by tts() is no longer needed.
        if edge_output_file and os.path.exists(edge_output_file):
            os.remove(edge_output_file)

        # NOTE(review): the first element is discarded and arrays are written at
        # the default 40000 Hz -- confirm it is not the real sample rate.
        _, audio_output = tts_output_data

        # An ndarray is written to a fresh WAV file; anything else is used as a path.
        if isinstance(audio_output, np.ndarray):
            audio_file_path = self.save_audio_data_to_file(audio_output)
        else:
            audio_file_path = audio_output

        try:
            with open(audio_file_path, 'rb') as file:
                audio_bytes = file.read()
        except Exception as e:
            raise Exception(f"Failed to read audio file: {e}") from e
        finally:
            # Always remove the audio file, whether or not reading succeeded.
            if os.path.exists(audio_file_path):
                os.remove(audio_file_path)

        audio_data_uri = f"data:audio/wav;base64,{base64.b64encode(audio_bytes).decode('utf-8')}"
        return {"info": info, "audio_data_uri": audio_data_uri}

    def process_hf_input(self, hf_data):
        """Unwrap a ``{"inputs": {...}}`` payload and process the inner dict.

        Returns an error dict (rather than raising) when ``"inputs"`` is missing.
        """
        if "inputs" not in hf_data:
            return {"error": "Invalid Hugging Face JSON structure."}
        return self.process_json_input(hf_data["inputs"])

    def save_audio_data_to_file(self, audio_data, sample_rate=40000):
        """Write ``audio_data`` to a new WAV file and return the file's path.

        The path comes from ``get_unique_filename('wav')``.
        """
        file_path = get_unique_filename('wav')
        wavfile.write(file_path, sample_rate, audio_data)
        return file_path