# TTS_Mongolian / handler.py
# Added by Vijish in commit 9c20444 ("Create handler.py"); 3.33 kB.
from pydantic import BaseModel
from environs import Env
from typing import List, Dict, Any
import os
import base64
import numpy as np
import librosa
from scipy.io import wavfile
import asyncio
class EndpointHandler:
    """Inference-endpoint handler that synthesizes speech with TTS_Mongolian.

    On first use it clones the ``TTS_Mongolian`` repository (if absent),
    switches the working directory into it and imports its
    ``voice_processing`` module. That module's ``tts`` coroutine performs the
    synthesis; its ``voice_mapping`` and ``get_unique_filename`` are used for
    voice lookup and temporary WAV names.
    """

    # Clone location, resolved against the working directory at first use.
    REPO_DIR = "TTS_Mongolian"
    REPO_URL = "https://huggingface.co/mazalaai/TTS_Mongolian.git"
    # Payload keys that must all be present for a request to be processed.
    REQUIRED_KEYS = (
        "model_name",
        "tts_text",
        "selected_voice",
        "slang_rate",
        "use_uploaded_voice",
    )

    def __init__(self, model_dir=None):
        """Store *model_dir*; the voice_processing module is loaded lazily."""
        self.model_dir = model_dir
        # Cached voice_processing module, filled in by _get_voice_processing().
        self._vp = None

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one endpoint request.

        Args:
            data: Request body. The payload is ``data["inputs"]`` when that key
                exists, otherwise ``data`` itself.

        Returns:
            ``{"info": ..., "audio_data_uri": ...}`` on success, or
            ``{"error": message}`` if any step fails.
        """
        try:
            json_data = data["inputs"] if "inputs" in data else data
            return self.process_json_input(json_data)
        except Exception as e:  # endpoint boundary: report, never crash
            return {"error": str(e)}

    def _get_voice_processing(self):
        """Return the ``voice_processing`` module, cloning/importing it once."""
        vp = getattr(self, "_vp", None)
        if vp is not None:
            return vp

        import importlib
        import subprocess
        import sys

        # If setup already happened in this process (e.g. another handler
        # instance), the cwd is already the repo: re-running clone/chdir here
        # would nest a second copy inside it, so reuse the cached module.
        vp = sys.modules.get("voice_processing")
        if vp is None:
            repo_dir = os.path.abspath(self.REPO_DIR)
            if not os.path.exists(repo_dir):
                # Argument list (no shell); check=True surfaces clone failures
                # instead of silently continuing like os.system did.
                subprocess.run(["git", "clone", self.REPO_URL, repo_dir], check=True)
            # Keep the original chdir (voice_processing presumably resolves its
            # assets relative to the cwd -- confirm), and also put the repo on
            # sys.path so the import does not depend on cwd being importable.
            os.chdir(repo_dir)
            if repo_dir not in sys.path:
                sys.path.insert(0, repo_dir)
            vp = importlib.import_module("voice_processing")
        self._vp = vp
        return vp

    def process_json_input(self, json_data):
        """Synthesize speech for one request payload.

        Args:
            json_data: Dict with keys ``model_name``, ``tts_text``,
                ``selected_voice``, ``slang_rate`` and ``use_uploaded_voice``;
                ``voice_upload_file`` is optional (defaults to ``None``).

        Returns:
            ``{"info": info, "audio_data_uri": "data:audio/wav;base64,..."}``
            where ``info`` is the first value returned by ``tts``.

        Raises:
            ValueError: If the payload is not a dict with all required keys, or
                ``selected_voice`` has no entry in ``voice_mapping``.
            RuntimeError: If ``tts`` returns no audio output.
            Exception: If the produced audio file cannot be read.
        """
        vp = self._get_voice_processing()

        # The isinstance guard matters: ``key in some_str`` is a substring test.
        if not isinstance(json_data, dict) or not all(
            key in json_data for key in self.REQUIRED_KEYS
        ):
            raise ValueError("Invalid JSON structure.")

        selected_voice = json_data["selected_voice"]
        edge_tts_voice = vp.voice_mapping.get(selected_voice)
        if not edge_tts_voice:
            raise ValueError(f"Invalid voice '{selected_voice}'.")

        info, _edge_tts_output_path, tts_output_data, edge_output_file = asyncio.run(
            vp.tts(
                json_data["model_name"],
                json_data["tts_text"],
                edge_tts_voice,
                json_data["slang_rate"],
                json_data["use_uploaded_voice"],
                json_data.get("voice_upload_file", None),
            )
        )

        # The intermediate edge-tts file is never returned, so drop it now.
        if edge_output_file and os.path.exists(edge_output_file):
            os.remove(edge_output_file)

        if tts_output_data is None:
            # NOTE(review): assumes tts reports failure as a None output with
            # the reason in ``info`` -- confirm against voice_processing.tts.
            raise RuntimeError(f"TTS failed: {info}")

        # NOTE(review): the first element is discarded and ndarray output is
        # written at the 40 kHz default; if it is the model's sample rate,
        # consider passing it to save_audio_data_to_file.
        _, audio_output = tts_output_data
        if isinstance(audio_output, np.ndarray):
            audio_file_path = self.save_audio_data_to_file(audio_output)
        else:
            audio_file_path = audio_output  # presumably already a file path

        try:
            with open(audio_file_path, "rb") as file:
                audio_bytes = file.read()
        except Exception as e:
            raise Exception(f"Failed to read audio file: {e}") from e
        finally:
            # Temporary artifact either way. Guard the type so cleanup can
            # never raise and mask the original error.
            if isinstance(audio_file_path, (str, os.PathLike)) and os.path.exists(
                audio_file_path
            ):
                os.remove(audio_file_path)

        audio_data_uri = (
            f"data:audio/wav;base64,{base64.b64encode(audio_bytes).decode('utf-8')}"
        )
        return {"info": info, "audio_data_uri": audio_data_uri}

    def save_audio_data_to_file(self, audio_data, sample_rate=40000):
        """Write *audio_data* to a fresh WAV file and return its path.

        The filename comes from ``voice_processing.get_unique_filename('wav')``.
        """
        file_path = self._get_voice_processing().get_unique_filename('wav')
        wavfile.write(file_path, sample_rate, audio_data)
        return file_path