|
from flask import Flask, render_template, request, send_file, jsonify |
|
import requests |
|
import json |
|
import ssl |
|
import logging |
|
import sys |
|
import os |
|
import base64 |
|
import io |
|
|
|
sys.path.append("hifigan") |
|
|
|
import torch |
|
from espnet2.bin.tts_inference import Text2Speech |
|
from models import Generator |
|
from scipy.io.wavfile import write |
|
from meldataset import MAX_WAV_VALUE |
|
from env import AttrDict |
|
import json |
|
import yaml |
|
from text_preprocess_for_inference import TTSDurAlignPreprocessor |
|
|
|
|
|
logging.basicConfig(filename='access.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
|
|
|
SAMPLING_RATE = 22050 |
|
if torch.cuda.is_available(): |
|
device = "cuda" |
|
else: |
|
device = "cpu" |
|
|
|
preprocessor = TTSDurAlignPreprocessor() |
|
|
|
app = Flask(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_hifigan_vocoder(language, gender, device): |
|
|
|
vocoder_config = f"vocoder/{gender}/aryan/hifigan/config.json" |
|
vocoder_generator = f"vocoder/{gender}/aryan/hifigan/generator" |
|
|
|
with open(vocoder_config, 'r') as f: |
|
data = f.read() |
|
json_config = json.loads(data) |
|
h = AttrDict(json_config) |
|
torch.manual_seed(h.seed) |
|
|
|
device = torch.device(device) |
|
generator = Generator(h).to(device) |
|
state_dict_g = torch.load(vocoder_generator, device) |
|
generator.load_state_dict(state_dict_g['generator']) |
|
generator.eval() |
|
generator.remove_weight_norm() |
|
|
|
|
|
return generator |
|
|
|
def load_fastspeech2_model(language, gender, device): |
|
|
|
|
|
with open(f"{language}/{gender}/model/config.yaml", "r") as file: |
|
config = yaml.safe_load(file) |
|
|
|
current_working_directory = os.getcwd() |
|
feat="model/feats_stats.npz" |
|
pitch="model/pitch_stats.npz" |
|
energy="model/energy_stats.npz" |
|
|
|
feat_path=os.path.join(current_working_directory,language,gender,feat) |
|
pitch_path=os.path.join(current_working_directory,language,gender,pitch) |
|
energy_path=os.path.join(current_working_directory,language,gender,energy) |
|
|
|
|
|
config["normalize_conf"]["stats_file"] = feat_path |
|
config["pitch_normalize_conf"]["stats_file"] = pitch_path |
|
config["energy_normalize_conf"]["stats_file"] = energy_path |
|
|
|
with open(f"{language}/{gender}/model/config.yaml", "w") as file: |
|
yaml.dump(config, file) |
|
|
|
tts_model = f"{language}/{gender}/model/model.pth" |
|
tts_config = f"{language}/{gender}/model/config.yaml" |
|
|
|
|
|
return Text2Speech(train_config=tts_config, model_file=tts_model, device=device) |
|
|
|
def text_synthesis(language, gender, sample_text, vocoder, MAX_WAV_VALUE, device, alpha=1): |
|
|
|
with torch.no_grad(): |
|
|
|
|
|
model = load_fastspeech2_model(language, gender, device) |
|
|
|
|
|
out = model(sample_text, decode_conf={"alpha": alpha}) |
|
print("TTS Done") |
|
x = out["feat_gen_denorm"].T.unsqueeze(0) * 2.3262 |
|
x = x.to(device) |
|
|
|
|
|
y_g_hat = vocoder(x) |
|
audio = y_g_hat.squeeze() |
|
audio = audio * MAX_WAV_VALUE |
|
audio = audio.cpu().numpy().astype('int16') |
|
|
|
|
|
return audio |
|
|
|
def setup_app(): |
|
genders = ['male','female'] |
|
|
|
languages = {'hindi': "नमस्ते",'malayalam': "ഹലോ",'manipuri': "হ্যালো",'marathi': "हॅलो",'kannada': "ಹಲೋ",'bodo': "हॅलो",'english': "Hello",'assamese': "হ্যালো",'tamil': "ஹலோ",'odia': "ହେଲୋ",'rajasthani': "हॅलो",'telugu': "హలో",'bengali': "হ্যালো",'gujarati': "હલો"} |
|
|
|
vocoders = {} |
|
for gender in genders: |
|
vocoders[gender]={} |
|
for language,text in languages.items(): |
|
|
|
vocoder = load_hifigan_vocoder(language, gender, device) |
|
vocoders[gender][language] = vocoder |
|
|
|
print(f"making dummy calls for {language} - {gender}") |
|
try: |
|
out = text_synthesis(language, gender, text, vocoder, MAX_WAV_VALUE, device) |
|
except: |
|
message = f"cannot make dummy call for {gender} - {language} <===================" |
|
print(message.upper()) |
|
|
|
print("Server Started...") |
|
return vocoders |
|
vocoders = setup_app() |
|
|
|
@app.route('/', methods=['GET']) |
|
def main(): |
|
return "IITM_TTS_V2" |
|
|
|
@app.route('/tts', methods=['GET', 'POST'], strict_slashes=False) |
|
def tts(): |
|
try: |
|
json_data = request.get_json() |
|
text = json_data["input"] |
|
if not isinstance(text,str): |
|
input_type = type(text) |
|
ret = jsonify(status='failure', reason=f"Unsupported input type {input_type}. Input text should be in string format.") |
|
gender = json_data["gender"] |
|
language = json_data["lang"].lower() |
|
alpha = json_data["alpha"] |
|
|
|
preprocessed_text, phrases = preprocessor.preprocess(text, language, gender) |
|
preprocessed_text = " ".join(preprocessed_text) |
|
vocoder = vocoders[gender][language] |
|
out = text_synthesis(language, gender, preprocessed_text, vocoder, MAX_WAV_VALUE, device, alpha=alpha) |
|
|
|
|
|
|
|
|
|
|
|
|
|
output_stream = io.BytesIO() |
|
write(output_stream, SAMPLING_RATE, out) |
|
audio_wav_bytes = base64.b64encode(output_stream.getvalue()) |
|
|
|
ret = jsonify(status="success",audio=audio_wav_bytes.decode('utf-8')) |
|
|
|
except Exception as err: |
|
ret = jsonify(status="failure", reason=str(err)) |
|
return ret |
|
|
|
if __name__ == '__main__': |
|
|
|
|
|
app.run(host='0.0.0.0', port=4005, debug=True) |