E2-F5-TTSi / test_infer_batch.py
mrfakename's picture
Super-squash branch 'main' using huggingface_hub
1646c30 verified
import argparse
import os
import random
import time

import torch
import torchaudio
from accelerate import Accelerator
from accelerate.utils import broadcast_object_list
from einops import rearrange
from ema_pytorch import EMA
from tqdm import tqdm
from vocos import Vocos

from model import CFM, UNetT, DiT
from model.utils import (
    get_tokenizer,
    get_seedtts_testset_metainfo,
    get_librispeech_test_clean_metainfo,
    get_inference_prompt,
)
# One Accelerator per process. It provides the distributed helpers used below:
# main-process checks, prompt sharding and barriers.
accelerator = Accelerator()
# Use the device Accelerate assigned to this process. Building
# f"cuda:{accelerator.process_index}" uses the *global* rank, which points at a
# GPU index that does not exist on any node after the first in multi-node runs.
device = accelerator.device
# --------------------- Dataset Settings -------------------- #
# Audio front-end parameters. They are shared by prompt construction and by the
# model's mel-spectrogram config further down.
target_sample_rate, n_mel_channels, hop_length = 24000, 100, 256
# RMS level for reference audio (presumably applied inside get_inference_prompt).
# Outputs for quieter references are scaled back down in the inference loop.
target_rms = 0.1
# Tokenizer type name, passed to get_tokenizer and get_inference_prompt.
tokenizer = "pinyin"
# ---------------------- infer setting ---------------------- #
parser = argparse.ArgumentParser(description="batch inference")
parser.add_argument('-s', '--seed', default=None, type=int,
                    help="sampling seed; a random seed is drawn when omitted")
parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN",
                    help="dataset name passed to get_tokenizer")
parser.add_argument('-n', '--expname', required=True,
                    help="experiment name (checkpoint dir under ckpts/, also selects the model config)")
parser.add_argument('-c', '--ckptstep', default=1200000, type=int,
                    help="checkpoint step to load (ckpts/<expname>/model_<step>.pt)")
parser.add_argument('-nfe', '--nfestep', default=32, type=int,
                    help="number of sampling steps passed to model.sample")
parser.add_argument('-o', '--odemethod', default="euler",
                    help="ODE solver method passed to CFM via odeint_kwargs")
parser.add_argument('-ss', '--swaysampling', default=-1, type=float,
                    help="sway sampling coefficient passed to model.sample")
# Restrict to the test sets this script knows how to load. A typo then fails
# here, before the checkpoint is loaded, instead of as a NameError on
# `metainfo` much later.
parser.add_argument('-t', '--testset', required=True,
                    choices=["ls_pc_test_clean", "seedtts_test_zh", "seedtts_test_en"],
                    help="evaluation test set")
args = parser.parse_args()

seed = args.seed
dataset_name = args.dataset
exp_name = args.expname
ckpt_step = args.ckptstep
checkpoint = torch.load(f"ckpts/{exp_name}/model_{ckpt_step}.pt", map_location=device)

nfe_step = args.nfestep
ode_method = args.odemethod
sway_sampling_coef = args.swaysampling

testset = args.testset

# Fixed inference settings (not exposed on the command line).
infer_batch_size = 1  # max frames. 1 for ddp single inference (recommended)
cfg_strength = 2.
speed = 1.
use_truth_duration = False
no_ref_audio = False
# Backbone class and hyper-parameters for each supported experiment name.
if exp_name == "F5TTS_Base":
    model_cls = DiT
    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
elif exp_name == "E2TTS_Base":
    model_cls = UNetT
    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
else:
    # Fail fast. Otherwise model_cls/model_cfg stay undefined and the script
    # dies later with a NameError when the model is built.
    raise ValueError(
        f"Unsupported experiment name {exp_name!r}; expected 'F5TTS_Base' or 'E2TTS_Base'."
    )
# Load the metadata list for the requested evaluation test set.
if testset == "ls_pc_test_clean":
    metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
    # Placeholder: must be edited to point at a local LibriSpeech test-clean copy.
    librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean"  # test-clean path
    metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
elif testset == "seedtts_test_zh":
    metalst = "data/seedtts_testset/zh/meta.lst"
    metainfo = get_seedtts_testset_metainfo(metalst)
elif testset == "seedtts_test_en":
    metalst = "data/seedtts_testset/en/meta.lst"
    metainfo = get_seedtts_testset_metainfo(metalst)
else:
    # Fail fast instead of hitting a NameError on `metainfo` further down.
    raise ValueError(
        f"Unsupported test set {testset!r}; expected one of "
        "'ls_pc_test_clean', 'seedtts_test_zh', 'seedtts_test_en'."
    )
# path to save generated wavs
if seed is None:
    seed = random.randint(-10000, 10000)
# With several processes, each one draws its own random seed above. Use rank
# 0's value everywhere so all processes sample with the same seed and write
# into the same output_dir (only the main process creates that directory).
# This is a no-op when running as a single process.
seed = broadcast_object_list([seed])[0]

output_dir = (
    f"results/{exp_name}_{ckpt_step}/{testset}/"
    f"seed{seed}_{ode_method}_nfe{nfe_step}"
    f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}"
    f"_cfg{cfg_strength}_speed{speed}"
    f"{'_gt-dur' if use_truth_duration else ''}"
    f"{'_no-ref-audio' if no_ref_audio else ''}"
)
# -------------------------------------------------#
# Weight source for the model built below:
#   True  -> checkpoint['ema_model_state_dict'] (EMA weights)
#   False -> checkpoint['model_state_dict']
use_ema = True

# Build every inference batch up front. The inference loop later unpacks each
# entry into (utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens,
# final_text_list).
prompts_all = get_inference_prompt(
    metainfo,
    speed=speed,
    tokenizer=tokenizer,
    target_sample_rate=target_sample_rate,
    n_mel_channels=n_mel_channels,
    hop_length=hop_length,
    target_rms=target_rms,
    use_truth_duration=use_truth_duration,
    infer_batch_size=infer_batch_size,
)
# Vocoder model (mel spectrogram -> waveform). It is never moved to `device`,
# and the inference loop feeds it `gen_mel_spec.cpu()`, so it runs on the CPU.
local = False
if local:
    vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    # The vocoder lives on the CPU, so load its weights there directly instead
    # of staging them on the GPU (map_location=device) only to copy them back.
    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location="cpu")
    vocos.load_state_dict(state_dict)
    vocos.eval()
else:
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
# Tokenizer: character map and vocabulary size for the text embedding.
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)

# Model: conditional flow-matching wrapper around the selected backbone.
model = CFM(
    transformer=model_cls(
        **model_cfg,
        text_num_embeds=vocab_size,
        mel_dim=n_mel_channels,
    ),
    mel_spec_kwargs=dict(
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
    ),
    odeint_kwargs=dict(
        method=ode_method,
    ),
    vocab_char_map=vocab_char_map,
).to(device)

if use_ema:
    # Load the EMA weights into a wrapper, then copy them into `model`, which is
    # the module the inference loop samples from.
    ema_model = EMA(model, include_online_model=False).to(device)
    ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
    ema_model.copy_params_from_ema_to_model()
else:
    model.load_state_dict(checkpoint['model_state_dict'])

# Inference only. CFM.sample may already switch to eval mode internally
# (NOTE(review): not visible here); calling it explicitly is harmless either
# way and guarantees dropout-style layers are disabled.
model.eval()
# Only the main process creates the output directory. The barrier that follows
# keeps the other processes from writing before it exists. exist_ok replaces
# the separate existence check.
if accelerator.is_main_process:
    os.makedirs(output_dir, exist_ok=True)
# start batch inference
accelerator.wait_for_everyone()
start = time.perf_counter()  # monotonic clock for elapsed-time measurement

# Each process iterates over its own shard of the prompt batches.
with accelerator.split_between_processes(prompts_all) as prompts:
    for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
        utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
        ref_mels = ref_mels.to(device)
        ref_mel_lens = torch.tensor(ref_mel_lens, dtype=torch.long).to(device)
        total_mel_lens = torch.tensor(total_mel_lens, dtype=torch.long).to(device)

        # Sampling AND vocoding run without autograd. Otherwise vocos.decode
        # would record a graph on every call (its parameters presumably still
        # require grad) and hand a grad-requiring tensor to torchaudio.save.
        with torch.inference_mode():
            # Inference
            generated, _ = model.sample(
                cond=ref_mels,
                text=final_text_list,
                duration=total_mel_lens,
                lens=ref_mel_lens,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
                no_ref_audio=no_ref_audio,
                seed=seed,
            )

            # Final result: drop the first ref_mel_lens[i] frames (presumably
            # the reference prompt), keep up to total_mel_lens[i], vocode, and
            # write one wav per utterance.
            for i, gen in enumerate(generated):
                gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
                gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
                generated_wave = vocos.decode(gen_mel_spec.cpu())
                # Match the output loudness of references quieter than target_rms.
                if ref_rms_list[i] < target_rms:
                    generated_wave = generated_wave * ref_rms_list[i] / target_rms
                torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)

accelerator.wait_for_everyone()
if accelerator.is_main_process:
    timediff = time.perf_counter() - start
    print(f"Done batch inference in {timediff / 60 :.2f} minutes.")