Update handler.py
Browse files- handler.py +20 -2
handler.py
CHANGED
@@ -6,6 +6,7 @@ import base64
|
|
6 |
import numpy as np
|
7 |
import librosa
|
8 |
from scipy.io import wavfile
|
|
|
9 |
|
10 |
class EndpointHandler:
|
11 |
def __init__(self, model_dir=None):
|
@@ -17,9 +18,15 @@ class EndpointHandler:
|
|
17 |
repo_url = "https://huggingface.co/mazalaai/TTS_Mongolian.git"
|
18 |
os.system(f"git clone {repo_url}")
|
19 |
|
20 |
-
#
|
21 |
repo_dir = "TTS_Mongolian"
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
# Import the voice_processing module and functions
|
25 |
from voice_processing import tts, get_model_names, voice_mapping, get_unique_filename
|
@@ -34,6 +41,17 @@ class EndpointHandler:
|
|
34 |
return {"error": str(e)}
|
35 |
except Exception as e:
|
36 |
return {"error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def process_json_input(self, json_data):
|
39 |
if all(key in json_data for key in ["model_name", "tts_text", "selected_voice", "slang_rate", "use_uploaded_voice"]):
|
|
|
6 |
import numpy as np
|
7 |
import librosa
|
8 |
from scipy.io import wavfile
|
9 |
+
import shutil
|
10 |
|
11 |
class EndpointHandler:
|
12 |
def __init__(self, model_dir=None):
|
|
|
18 |
repo_url = "https://huggingface.co/mazalaai/TTS_Mongolian.git"
|
19 |
os.system(f"git clone {repo_url}")
|
20 |
|
21 |
+
# Copy all files from the cloned repository to the /repository directory
|
22 |
repo_dir = "TTS_Mongolian"
|
23 |
+
dest_dir = "/repository"
|
24 |
+
for item in os.listdir(repo_dir):
|
25 |
+
item_path = os.path.join(repo_dir, item)
|
26 |
+
if os.path.isfile(item_path):
|
27 |
+
shutil.copy(item_path, dest_dir)
|
28 |
+
elif os.path.isdir(item_path):
|
29 |
+
shutil.copytree(item_path, os.path.join(dest_dir, item))
|
30 |
|
31 |
# Import the voice_processing module and functions
|
32 |
from voice_processing import tts, get_model_names, voice_mapping, get_unique_filename
|
|
|
41 |
return {"error": str(e)}
|
42 |
except Exception as e:
|
43 |
return {"error": str(e)}
|
44 |
+
finally:
|
45 |
+
# Clean up the cloned repository and copied files/directories
|
46 |
+
if os.path.exists(repo_dir):
|
47 |
+
shutil.rmtree(repo_dir)
|
48 |
+
for item in os.listdir(dest_dir):
|
49 |
+
if item.startswith("TTS_Mongolian"):
|
50 |
+
item_path = os.path.join(dest_dir, item)
|
51 |
+
if os.path.isfile(item_path):
|
52 |
+
os.remove(item_path)
|
53 |
+
elif os.path.isdir(item_path):
|
54 |
+
shutil.rmtree(item_path)
|
55 |
|
56 |
def process_json_input(self, json_data):
|
57 |
if all(key in json_data for key in ["model_name", "tts_text", "selected_voice", "slang_rate", "use_uploaded_voice"]):
|