xVASynth-TTS / app.py
Pendrokar's picture
show errors
d0fa6f6 verified
raw
history blame
6.65 kB
import os
import sys
import requests
import json
from huggingface_hub import HfApi
# start xVASynth service (no HTTP)
import resources.app.no_server as xvaserver
from gr_client import BlocksDemo
# TODO: move snapshots to common folder & use a models array

# Root of the local Hugging Face hub cache in which model snapshots are looked up.
HF_HUB_CACHE_DIR = '/home/user/.cache/huggingface/hub'

model_repo = HfApi()


def _latest_snapshot_path(repo_id, token=None):
    """Return the hub-cache snapshot directory for the first listed commit of *repo_id*.

    The first entry returned by ``list_repo_commits`` is treated as the latest
    commit. Only the path string is built here; nothing is downloaded and the
    directory is not checked for existence.

    Args:
        repo_id: Hub repository id in ``owner/name`` form.
        token: Optional access token forwarded to ``list_repo_commits``.

    Returns:
        The snapshot directory path, with a trailing slash so that a model
        file name can be appended directly.

    Raises:
        IndexError: if the repository reports no commits.
        Exception: whatever ``list_repo_commits`` itself raises.
    """
    commits = model_repo.list_repo_commits(repo_id=repo_id, token=token)
    latest_commit_sha = commits[0].commit_id
    repo_dir = 'models--' + repo_id.replace('/', '--')
    return f'{HF_HUB_CACHE_DIR}/{repo_dir}/snapshots/{latest_commit_sha}/'


# NVIDIA NeMo models
hf_model_name = "Pendrokar/xvapitch_nvidia"
hf_cache_models_path = _latest_snapshot_path(hf_model_name)
models_path = hf_cache_models_path

# Expresso models
hf_cache_expresso_models_path = _latest_snapshot_path('Pendrokar/xvapitch_expresso')

# Lojban model
hf_cache_lojban_models_path = _latest_snapshot_path('Pendrokar/xvasynth_lojban')

# Robotic model: optional, looked up with HF_TOKEN; the path stays empty when
# the lookup fails so the rest of the app can still start.
hf_cache_robotic_models_path = ''
try:
    hf_cache_robotic_models_path = _latest_snapshot_path(
        'Pendrokar/xvasynth_cabal',
        token=os.getenv('HF_TOKEN'),
    )
except Exception as err:
    # Best effort: report why the voice is unavailable instead of hiding it,
    # while still letting KeyboardInterrupt/SystemExit propagate.
    print('Robotic voice not loaded!')
    print(f'Reason: {err!r}')

# State shared between load_model() and LocalBlocksDemo.predict().
current_voice_model = None
current_voice_type = None
base_speaker_emb = ''
def load_model(voice_model_name):
    """Load a voice model into xVASynth and record it as the current voice.

    The model path and type are chosen from the voice name:
    ``x_selpahi`` is the Lojban FastPitch1.1 model, names starting with
    ``x_ex0`` come from the Expresso snapshot, ``cnc_cabal`` from the robotic
    snapshot, and everything else from the default NVIDIA snapshot (all
    xVAPitch).

    Args:
        voice_model_name: File stem of the voice model inside its snapshot
            directory.

    Returns:
        The module-level ``base_speaker_emb``: the embedding read from the
        model's ``.json`` metadata on success, or the previous value when a
        ``RequestException`` was caught.

    Raises:
        OSError, KeyError, IndexError, json.JSONDecodeError: if the metadata
            file is missing or lacks the expected fields. In that case
            ``current_voice_model`` is left as ``None`` so the next call
            reloads the model.
    """
    global current_voice_model, current_voice_type, base_speaker_emb

    if voice_model_name == 'x_selpahi':
        # Lojban
        model_path = hf_cache_lojban_models_path + voice_model_name
        model_type = 'FastPitch1.1'
    else:
        model_type = 'xVAPitch'
        if voice_model_name.startswith('x_ex0'):
            model_path = hf_cache_expresso_models_path + voice_model_name
        elif voice_model_name == 'cnc_cabal':
            model_path = hf_cache_robotic_models_path + voice_model_name
        else:
            model_path = models_path + voice_model_name

    language = 'en'  # seems to have no effect if generated text is for a different language

    data = {
        'outputs': None,
        'version': '3.0',
        'model': model_path,
        'modelType': model_type,
        'base_lang': language,
        'pluginsContext': '{}',
    }

    print('Loading voice model...')

    # Invalidate the cached voice first: if anything below fails part-way,
    # callers comparing against current_voice_model will trigger a reload
    # instead of synthesizing with a half-updated model/embedding pair.
    current_voice_model = None

    try:
        xvaserver.loadModel(data)

        with open(model_path + '.json', 'r', encoding='utf-8') as f:
            voice_model_json = json.load(f)

        speaker_emb = base_speaker_emb
        if model_type == 'xVAPitch':
            speaker_emb = voice_model_json['games'][0]['base_speaker_emb']
        elif model_type == 'FastPitch1.1':
            speaker_emb = voice_model_json['games'][0]['resemblyzer']

        # Commit all three pieces of state together, only after full success.
        current_voice_model = voice_model_name
        current_voice_type = model_type
        base_speaker_emb = speaker_emb
    except requests.exceptions.RequestException as err:
        # NOTE(review): xvaserver is imported as a no-HTTP service, so this
        # handler presumably only fires if it issues HTTP requests
        # internally; kept so the exception contract is unchanged.
        print(f'FAILED to load voice model: {err}')

    return base_speaker_emb
class LocalBlocksDemo(BlocksDemo):
    """BlocksDemo variant whose predict step calls the in-process xvaserver module."""

    @staticmethod
    def _arpabet_table_html(arpabet, durations):
        """Render ARPAbet symbols as one table row with duration-proportional cells.

        Args:
            arpabet: ``|``-separated symbol string; ``<PAD>`` entries are skipped.
            durations: Per-symbol durations, aligned by index with the split
                symbols. Each value must be convertible with ``float``.

        Returns:
            An HTML fragment: a heading followed by a single-row table whose
            cell widths are each symbol's share of the total duration in
            percent (0 when the total duration is zero).
        """
        import html

        cells = [
            (symbol, float(duration))
            for symbol, duration in zip(arpabet.split('|'), durations)
            if symbol != '<PAD>'
        ]
        total_duration = sum(duration for _, duration in cells)

        parts = [
            '<h6>ARPAbet & Durations</h6>',
            '<table style="margin: 0 var(--size-2)"><tbody><tr>',
        ]
        for symbol, duration in cells:
            # Guard against a zero total (e.g. the failure placeholder).
            width = round(duration / total_duration * 100, 2) if total_duration else 0
            parts.append(
                '<td class="arpabet" style="width: '
                + str(width)
                + '%">'
                # Symbols derive from user text, so escape before embedding.
                + html.escape(symbol)
                + '</td> '
            )
        parts.append('</tr></tbody></table>')
        return ''.join(parts)

    def predict(
        self,
        input_text,
        voice,
        lang,
        pacing,
        pitch,
        energy,
        anger,
        happy,
        sad,
        surprise,
        use_deepmoji
    ):
        """Synthesize *input_text* with the selected voice.

        Args:
            input_text: Text to speak; only the first 1000 characters are used.
            voice: Voice model name; the model is (re)loaded when it differs
                from the currently loaded one.
            lang: Language code passed to the synthesizer as ``base_lang``.
            pacing: Pace multiplier; falsy values fall back to 1.0.
            pitch: Accepted for interface compatibility; not used here.
            energy: Accepted for interface compatibility; not used here.
            anger, happy, sad, surprise: Emotion values; negative values are
                sent to the plugin as 0.
            use_deepmoji: Forwarded as the plugin's ``run_model`` flag; when
                true the returned emotion values are taken from the
                synthesizer response instead of the inputs.

        Returns:
            A list ``[audio_path, arpabet_html, em_angry, em_happy, em_sad,
            em_surprise, json_data]``. ``audio_path`` is ``''`` when
            synthesis failed with a ``RequestException``.
        """
        # grab only the first 1000 characters
        input_text = input_text[:1000]

        # load voice model if not the current model
        # (load_model updates current_voice_type and base_speaker_emb)
        if current_voice_model != voice:
            load_model(voice)

        model_type = current_voice_type
        pace = pacing if pacing else 1.0
        save_path = '/tmp/xvapitch_audio_sample.wav'
        language = lang
        use_sr = 0
        use_cleanup = 0

        pluginsContext = {}
        pluginsContext["mantella_settings"] = {
            "emAngry": (anger if anger > 0 else 0),
            "emHappy": (happy if happy > 0 else 0),
            "emSad": (sad if sad > 0 else 0),
            "emSurprise": (surprise if surprise > 0 else 0),
            "run_model": use_deepmoji
        }

        data = {
            'pluginsContext': json.dumps(pluginsContext),
            'modelType': model_type,
            # pad with whitespaces as a workaround to avoid cutoffs
            'sequence': input_text.center(len(input_text) + 2, ' '),
            'pace': pace,
            'outfile': save_path,
            'vocoder': 'n/a',
            'base_lang': language,
            'base_emb': base_speaker_emb,
            'useSR': use_sr,
            'useCleanup': use_cleanup,
        }

        print('Synthesizing...')
        try:
            json_data = xvaserver.synthesize(data)
        except requests.exceptions.RequestException as err:
            print(f'FAILED to synthesize: {err}')
            save_path = ''
            # Placeholder shaped like the fields read from a successful
            # response below (string arpabet, list-valued emotions), so the
            # rendering code does not crash on the failure path.
            json_data = {
                'arpabet': 'Failed',
                'durations': [0],
                'em_angry': [anger],
                'em_happy': [happy],
                'em_sad': [sad],
                'em_surprise': [surprise],
            }

        # For debugging, the xVASynth log is at resources/app/server.log.

        if voice == 'x_selpahi':
            # Lojban voice: no ARPAbet table and no emotion values.
            arpabet_html = ''
            em_angry = 0
            em_happy = 0
            em_sad = 0
            em_surprise = 0
        else:
            arpabet_html = self._arpabet_table_html(
                json_data['arpabet'],
                json_data['durations'],
            )

            if use_deepmoji:
                em_angry = round(json_data['em_angry'][0], 2)
                em_happy = round(json_data['em_happy'][0], 2)
                em_sad = round(json_data['em_sad'][0], 2)
                em_surprise = round(json_data['em_surprise'][0], 2)
            else:
                em_angry = anger
                em_happy = happy
                em_sad = sad
                em_surprise = surprise

        return [
            save_path,
            arpabet_html,
            em_angry,
            em_happy,
            em_sad,
            em_surprise,
            json_data
        ]
if __name__ == "__main__":
    # Script entry point: build the demo from the four snapshot directories
    # (default, Lojban, robotic, Expresso - in that order) and launch it.
    print('running custom Gradio interface')
    demo = LocalBlocksDemo(
        models_path,
        hf_cache_lojban_models_path,
        hf_cache_robotic_models_path,
        hf_cache_expresso_models_path,
    )
    launch_options = {'show_api': True, 'show_error': True}
    demo.block.launch(**launch_options)