import os
import re
import json
import torch
import librosa
import soundfile
import torchaudio
import numpy as np
import torch.nn as nn

from . import utils
from . import commons
from .models import SynthesizerTrn
from .split_utils import split_sentence
from .mel_processing import spectrogram_torch, spectrogram_torch_conv
from .download_utils import load_or_download_config, load_or_download_model

# Inserts a space at every lowercase->uppercase boundary ("helloWorld" -> "hello World");
# applied to EN / ZH_MIX_EN text before it is converted to model inputs.
_CAMEL_CASE_BOUNDARY = re.compile(r'([a-z])([A-Z])')


class TTS(nn.Module):
    """Text-to-speech wrapper around a pretrained ``SynthesizerTrn`` model.

    The config and checkpoint for the requested language are obtained through
    ``load_or_download_config`` / ``load_or_download_model``; the model is moved
    to ``device`` and put in eval mode.
    """

    def __init__(self, language, device='cuda:0'):
        """Build the synthesizer for ``language`` and load its weights.

        Args:
            language: Language key passed to the config/model loaders. The part
                before the first ``'_'`` becomes ``self.language``, with ``'ZH'``
                mapped to ``'ZH_MIX_EN'``.
            device: Device (string or ``torch.device``) for the model and for
                the tensors built at inference time.

        Raises:
            RuntimeError: if a CUDA device is requested but CUDA is not available.
        """
        super().__init__()
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if 'cuda' in str(device) and not torch.cuda.is_available():
            raise RuntimeError(
                f"CUDA device {device!r} requested but torch.cuda.is_available() is False"
            )

        hps = load_or_download_config(language)
        symbols = hps.symbols
        model = SynthesizerTrn(
            len(symbols),
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            num_tones=hps.num_tones,
            num_languages=hps.num_languages,
            **hps.model,
        ).to(device)
        model.eval()
        self.model = model
        # Lookup table from symbol to its index in ``hps.symbols``.
        self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
        self.hps = hps
        self.device = device

        # Load the pretrained weights; every key must match the model (strict=True).
        checkpoint_dict = load_or_download_model(language, device)
        self.model.load_state_dict(checkpoint_dict['model'], strict=True)

        language = language.split('_')[0]
        # 'ZH' is served by the ZH_MIX_EN model.
        self.language = 'ZH_MIX_EN' if language == 'ZH' else language

    @staticmethod
    def audio_numpy_concat(segment_data_list, sr, speed=1.):
        """Join audio segments into one float32 array, inserting silence after each.

        Every segment is flattened, then followed by ``int((sr * 0.05) / speed)``
        zero samples (0.05 s at sample rate ``sr``, scaled by ``1 / speed``),
        including after the last segment.

        Args:
            segment_data_list: Iterable of array-like audio segments.
            sr: Sample rate used to size the silence gap.
            speed: Divides the gap length; must be non-zero.

        Returns:
            A 1-D ``np.float32`` array; empty if ``segment_data_list`` is empty.
        """
        # A negative length (negative ``speed``) yields no padding, matching ``[0] * n`` for n < 0.
        gap = np.zeros(max(0, int((sr * 0.05) / speed)), dtype=np.float32)
        pieces = []
        for segment in segment_data_list:
            pieces.append(np.asarray(segment, dtype=np.float32).reshape(-1))
            pieces.append(gap)
        if not pieces:
            # np.concatenate rejects an empty sequence.
            return np.zeros(0, dtype=np.float32)
        return np.concatenate(pieces)

    @staticmethod
    def split_sentences_into_pieces(text, language):
        """Split ``text`` into sentence pieces with ``split_sentence`` for ``language``."""
        return split_sentence(text, language_str=language)

    def _synthesize_piece(self, text, language, speaker_id, sdp_ratio,
                          noise_scale, noise_scale_w, speed):
        """Run ``self.model.infer`` on one text piece and return its audio as a float32 numpy array."""
        if language in ('EN', 'ZH_MIX_EN'):
            text = _CAMEL_CASE_BOUNDARY.sub(r'\1 \2', text)
        device = self.device
        bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
            text, language, self.hps, device, self.symbol_to_id)
        with torch.no_grad():
            # Every model input gets a leading batch dimension of size 1.
            x_tst = phones.to(device).unsqueeze(0)
            tones = tones.to(device).unsqueeze(0)
            lang_ids = lang_ids.to(device).unsqueeze(0)
            bert = bert.to(device).unsqueeze(0)
            ja_bert = ja_bert.to(device).unsqueeze(0)
            x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
            speakers = torch.LongTensor([speaker_id]).to(device)
            # ``[0][0, 0]`` takes the first output and its [0, 0] slice
            # (presumably batch item 0, channel 0 — TODO confirm output layout).
            return self.model.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                ja_bert,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=1. / speed,
            )[0][0, 0].data.cpu().float().numpy()

    def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6,
                    noise_scale_w=0.8, speed=1.0, pbar=None, format=None):
        """Synthesize ``text`` and return the audio or write it to ``output_path``.

        The text is split into sentence pieces, each piece is synthesized
        separately, and the results are joined by ``audio_numpy_concat``.

        Args:
            text: Input text.
            speaker_id: Integer speaker index given to the model.
            output_path: If ``None`` the audio is returned; otherwise it is
                written to this path with ``soundfile.write``.
            sdp_ratio, noise_scale, noise_scale_w: Forwarded to ``self.model.infer``.
            speed: The model receives ``length_scale = 1 / speed``; also scales
                the silence inserted between pieces.
            pbar: Optional callable wrapping the list of pieces for iteration
                (e.g. ``tqdm``).
            format: Forwarded to ``soundfile.write`` as its ``format`` argument.

        Returns:
            A 1-D float32 numpy array at ``self.hps.data.sampling_rate`` when
            ``output_path`` is ``None``; otherwise ``None``.
        """
        language = self.language
        texts = self.split_sentences_into_pieces(text, language)
        pieces = pbar(texts) if pbar else texts

        audio_list = [
            self._synthesize_piece(piece, language, speaker_id, sdp_ratio,
                                   noise_scale, noise_scale_w, speed)
            for piece in pieces
        ]
        torch.cuda.empty_cache()

        sampling_rate = self.hps.data.sampling_rate
        audio = self.audio_numpy_concat(audio_list, sr=sampling_rate, speed=speed)
        if output_path is None:
            return audio
        # Pass ``format`` by keyword: soundfile.write's 4th positional parameter is ``subtype``.
        soundfile.write(output_path, audio, sampling_rate, format=format)