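"""NaturalSpeech2 inference: encode a reference waveform into EnCodec codes,
convert the input text to phoneme IDs, run the NS2 diffusion model, and
decode the predicted latents back to a 24 kHz waveform."""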
import argparse
import os
import torch
import soundfile as sf
import numpy as np
from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from encodec import EncodecModel
from encodec.utils import convert_audio
from utils.util import load_config
from text import text_to_sequence
from text.cmudict import valid_symbols
from text.g2p import preprocess_english, read_lexicon
import torchaudio


class NS2Inference:
    def __init__(self, args, cfg):
        self.cfg = cfg
        self.args = args

        self.model = self.build_model()
        self.codec = self.build_codec()

        # Phone vocabulary: CMUdict symbols plus pause/silence tokens and
        # sequence boundary tokens; build both directions of the mapping.
        self.symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
        self.phone2id = {s: i for i, s in enumerate(self.symbols)}
        self.id2phone = {i: s for s, i in self.phone2id.items()}

    def build_model(self):
        # Load the NS2 checkpoint onto CPU first, then move it to the target
        # device and switch to eval mode for deterministic inference.
        model = NaturalSpeech2(self.cfg.model)
        model.load_state_dict(
            torch.load(
                os.path.join(self.args.checkpoint_path, "pytorch_model.bin"),
                map_location="cpu",
            )
        )
        model = model.to(self.args.device)
        model.eval()
        return model

    def build_codec(self):
        # EnCodec 24 kHz codec; a 12 kbps target bandwidth corresponds to
        # 16 codebooks per frame.
        encodec_model = EncodecModel.encodec_model_24khz()
        encodec_model = encodec_model.to(device=self.args.device)
        encodec_model.set_target_bandwidth(12.0)
        return encodec_model

    def get_ref_code(self):
        # Load the reference audio and resample/remix it to the codec's
        # expected sample rate and channel count.
        ref_wav_path = self.args.ref_audio
        ref_wav, sr = torchaudio.load(ref_wav_path)
        ref_wav = convert_audio(
            ref_wav, sr, self.codec.sample_rate, self.codec.channels
        )
        ref_wav = ref_wav.unsqueeze(0).to(device=self.args.device)

        with torch.no_grad():
            encoded_frames = self.codec.encode(ref_wav)
        # encode() returns a list of (codes, scale) tuples; concatenate the
        # codes along the time axis into a single (B, n_q, T) tensor.
        ref_code = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)

        # Every reference frame is valid, so the mask is all ones: (B, T).
        ref_mask = torch.ones(ref_code.shape[0], ref_code.shape[-1]).to(
            ref_code.device
        )

        return ref_code, ref_mask

    @torch.no_grad()
    def inference(self):
        ref_code, ref_mask = self.get_ref_code()

        # Convert the input text to an ARPAbet phone sequence, then map each
        # phone to its vocabulary ID (preprocess_english wraps phones in {}).
        lexicon = read_lexicon(self.cfg.preprocess.lexicon_path)
        phone_seq = preprocess_english(self.args.text, lexicon)
        print(phone_seq)

        phone_id = np.array(
            [
                *map(
                    self.phone2id.get,
                    phone_seq.replace("{", "").replace("}", "").split(),
                )
            ]
        )
        phone_id = torch.from_numpy(phone_id).unsqueeze(0).to(device=self.args.device)
        print(phone_id)

        # Run the diffusion model: x0 is the predicted continuous latent and
        # prior_out carries intermediate predictions such as phone durations.
        x0, prior_out = self.model.inference(
            ref_code, phone_id, ref_mask, self.args.inference_step
        )
        print(prior_out["dur_pred"])
        print(prior_out["dur_pred_round"])
        print(torch.sum(prior_out["dur_pred_round"]))

        # Decode the predicted latent back to a waveform with the EnCodec
        # decoder; latent_ref reconstructs the reference codes for debugging.
        latent_ref = self.codec.quantizer.vq.decode(ref_code.transpose(0, 1))
        rec_wav = self.codec.decoder(x0)
        # ref_wav = self.codec.decoder(latent_ref)

        os.makedirs(self.args.output_dir, exist_ok=True)
        sf.write(
            "{}/{}.wav".format(
                self.args.output_dir, self.args.text.replace(" ", "_", 100)
            ),
            rec_wav[0, 0].detach().cpu().numpy(),
            samplerate=24000,
        )

    @staticmethod
    def add_arguments(parser: argparse.ArgumentParser):
        parser.add_argument(
            "--ref_audio",
            type=str,
            default="",
            help="Reference audio path",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda",
            help="Device to run inference on",
        )
        parser.add_argument(
            "--inference_step",
            type=int,
            default=200,
            help="Total inference steps for the diffusion model",
        )
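

# A minimal, hypothetical driver sketch showing how NS2Inference might be
# wired up. The --config, --checkpoint_path, --text, and --output_dir flags
# are assumptions for illustration: they are consumed above but registered
# elsewhere (typically by the repo's shared inference entry point), not by
# add_arguments in this module.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="", help="Config file path")
    parser.add_argument("--checkpoint_path", type=str, default="")
    parser.add_argument("--text", type=str, default="")
    parser.add_argument("--output_dir", type=str, default="output")
    NS2Inference.add_arguments(parser)
    args = parser.parse_args()

    cfg = load_config(args.config)
    NS2Inference(args, cfg).inference()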