Spaces:
Runtime error
Runtime error
File size: 4,243 Bytes
4300fed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import os
import re
import json
import torch
import librosa
import soundfile
import torchaudio
import numpy as np
import torch.nn as nn
from . import utils
from . import commons
from .models import SynthesizerTrn
from .split_utils import split_sentence
from .mel_processing import spectrogram_torch, spectrogram_torch_conv
from .download_utils import load_or_download_config, load_or_download_model
class TTS(nn.Module):
    """Text-to-speech wrapper around a ``SynthesizerTrn`` model.

    The constructor loads the hyper-parameters and checkpoint for a language
    through ``load_or_download_config`` / ``load_or_download_model`` and keeps
    the model in eval mode on the requested device. ``tts_to_file`` then
    synthesizes text sentence by sentence and joins the resulting audio.
    """

    def __init__(self,
                 language,
                 device='cuda:0'):
        """Build the synthesizer for ``language`` and load its weights.

        Args:
            language: Language key handed to ``load_or_download_config`` and
                ``load_or_download_model``. Only the part before the first
                ``'_'`` is kept as the runtime language; ``'ZH'`` is mapped
                to ``'ZH_MIX_EN'``.
            device: Torch device string the model is moved to.

        Raises:
            AssertionError: If a CUDA device is requested but
                ``torch.cuda.is_available()`` is false.
        """
        super().__init__()
        if 'cuda' in device:
            assert torch.cuda.is_available(), (
                f"device {device!r} was requested but CUDA is not available"
            )

        hps = load_or_download_config(language)

        num_languages = hps.num_languages
        num_tones = hps.num_tones
        symbols = hps.symbols

        model = SynthesizerTrn(
            len(symbols),
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            num_tones=num_tones,
            num_languages=num_languages,
            **hps.model,
        ).to(device)

        model.eval()
        self.model = model
        # Lookup table from symbol to its index in the config's symbol list.
        self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
        self.hps = hps
        self.device = device

        # Load the checkpoint; strict=True makes any key mismatch an error.
        checkpoint_dict = load_or_download_model(language, device)
        self.model.load_state_dict(checkpoint_dict['model'], strict=True)

        language = language.split('_')[0]
        # we support a ZH_MIX_EN model
        self.language = 'ZH_MIX_EN' if language == 'ZH' else language

    @staticmethod
    def audio_numpy_concat(segment_data_list, sr, speed=1.):
        """Flatten and join audio segments, appending a pause after each one.

        Args:
            segment_data_list: Sequence of array-likes holding audio samples;
                each one is flattened to 1-D.
            sr: Sampling rate in Hz, used to size the pause.
            speed: Speed factor; the pause is ``int((sr * 0.05) / speed)``
                zero samples long.

        Returns:
            A 1-D ``float32`` NumPy array. It is empty when
            ``segment_data_list`` is empty.
        """
        segments = list(segment_data_list)
        if not segments:
            # np.concatenate rejects an empty sequence, so return early.
            return np.zeros(0, dtype=np.float32)

        # One shared zero buffer is reused as the pause after every segment.
        # A negative count yields no padding, matching ``[0] * negative``.
        pause = np.zeros(max(int((sr * 0.05) / speed), 0), dtype=np.float32)

        pieces = []
        for segment in segments:
            pieces.append(np.asarray(segment, dtype=np.float32).reshape(-1))
            pieces.append(pause)
        return np.concatenate(pieces)

    @staticmethod
    def split_sentences_into_pieces(text, language):
        """Split ``text`` with ``split_sentence`` and print the pieces.

        Args:
            text: Input text to split.
            language: Passed to ``split_sentence`` as ``language_str``.

        Returns:
            The sequence of pieces returned by ``split_sentence``.
        """
        texts = split_sentence(text, language_str=language)
        print(" > Text splitted to sentences.")
        print('\n'.join(texts))
        print(" > ===========================")
        return texts

    def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0):
        """Synthesize ``text`` and either return the audio or write it to disk.

        The text is split into pieces, each piece is run through
        ``self.model.infer``, and the results are joined with
        ``audio_numpy_concat``.

        Args:
            text: Text to synthesize.
            speaker_id: Integer speaker index, wrapped in a ``LongTensor``.
            output_path: If ``None`` the audio array is returned; otherwise
                the audio is written there with ``soundfile.write``.
            sdp_ratio: Forwarded unchanged to ``self.model.infer``.
            noise_scale: Forwarded unchanged to ``self.model.infer``.
            noise_scale_w: Forwarded unchanged to ``self.model.infer``.
            speed: Speed factor; the model receives ``length_scale=1/speed``
                and the pauses between pieces are scaled by it too.

        Returns:
            The concatenated ``float32`` audio array when ``output_path`` is
            ``None``; otherwise ``None`` after the file has been written.
        """
        language = self.language
        device = self.device
        # For these languages a space is inserted at every lowercase ->
        # uppercase boundary before the text is converted to model inputs.
        split_camel_case = language in ['EN', 'ZH_MIX_EN']

        texts = self.split_sentences_into_pieces(text, language)
        audio_list = []
        for t in texts:
            if split_camel_case:
                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(t, language, self.hps, device, self.symbol_to_id)
            with torch.no_grad():
                # unsqueeze(0) adds a leading dimension of size 1 to each input.
                x_tst = phones.to(device).unsqueeze(0)
                tones = tones.to(device).unsqueeze(0)
                lang_ids = lang_ids.to(device).unsqueeze(0)
                bert = bert.to(device).unsqueeze(0)
                ja_bert = ja_bert.to(device).unsqueeze(0)
                x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
                del phones
                speakers = torch.LongTensor([speaker_id]).to(device)
                audio = self.model.infer(
                    x_tst,
                    x_tst_lengths,
                    speakers,
                    tones,
                    lang_ids,
                    bert,
                    ja_bert,
                    sdp_ratio=sdp_ratio,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=1. / speed,
                )[0][0, 0].data.cpu().float().numpy()
                # Drop the input tensors before the next piece is processed.
                del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers
            audio_list.append(audio)
        # NOTE(review): placed after the loop; the original indentation was
        # lost, so confirm this was not meant to run once per piece.
        torch.cuda.empty_cache()

        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)

        if output_path is None:
            return audio
        else:
            soundfile.write(output_path, audio, self.hps.data.sampling_rate)