|
import argparse |
|
import importlib |
|
import os |
|
|
|
import numpy as np |
|
import torch |
|
from torch.utils.data import DataLoader |
|
from tqdm import tqdm |
|
from argparse import RawTextHelpFormatter |
|
from TTS.tts.datasets.TTSDataset import MyDataset |
|
from TTS.tts.utils.generic_utils import setup_model |
|
from TTS.tts.utils.io import load_checkpoint |
|
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols |
|
from TTS.utils.audio import AudioProcessor |
|
from TTS.utils.io import load_config |
|
|
|
|
|
def str2bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string as ``True``
    (so ``--use_cuda False`` would enable CUDA). This parser understands the
    usual spellings instead.

    Args:
        value: The raw command-line value (or an actual ``bool``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept as "false" because it was the only way to
    # disable the flag explicitly while the option was declared `type=bool`.
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f"Expected a boolean value (e.g. True/False), got {value!r}.")


def attention_mask_path(wav_path):
    """Return the ``.npy`` path that sits next to ``wav_path``.

    Only the extension of the final path component is swapped; directory
    names are never touched, even when they contain the file name.
    """
    return os.path.splitext(wav_path)[0] + '.npy'


def build_arg_parser():
    """Create the command-line parser of the attention mask extraction script."""
    parser = argparse.ArgumentParser(
        description='''Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n'''
        '''Each attention mask is written to the same path as the input wav file with ".npy" file extension.
(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n'''
        '''
Example run:
    CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py
        --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
        --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
        --dataset_metafile /root/LJSpeech-1.1/metadata.csv
        --data_path /root/LJSpeech-1.1/
        --batch_size 32
        --dataset ljspeech
        --use_cuda True
''',
        formatter_class=RawTextHelpFormatter)
    parser.add_argument(
        '--model_path',
        type=str,
        required=True,
        help='Path to Tacotron/Tacotron2 model file ')
    parser.add_argument(
        '--config_path',
        type=str,
        required=True,
        help='Path to Tacotron/Tacotron2 config file.')
    # `required=True` makes a default meaningless, so none is declared.
    parser.add_argument(
        '--dataset',
        type=str,
        required=True,
        help='Target dataset processor name from TTS.tts.dataset.preprocess.')
    parser.add_argument(
        '--dataset_metafile',
        type=str,
        required=True,
        help='Dataset metafile including file paths with transcripts.')
    parser.add_argument(
        '--data_path',
        type=str,
        default='',
        help='Defines the data path. It overwrites config.json.')
    parser.add_argument(
        '--use_cuda',
        type=str2bool,
        default=False,
        help="enable/disable cuda.")
    parser.add_argument(
        '--batch_size',
        default=16,
        type=int,
        help='Batch size for the model. Use batch_size=1 if you have no CUDA.')
    return parser


def main():
    """Run the model over a dataset and save one attention mask per wav file.

    Each mask is stored as ``.npy`` next to its wav file, and a
    ``metadata_attn_mask.txt`` file listing ``wav_path|mask_path`` pairs is
    written into ``--data_path``.
    """
    args = build_arg_parser().parse_args()

    C = load_config(args.config_path)
    ap = AudioProcessor(**C.audio)

    # A custom character set in the config replaces the imported defaults.
    if 'characters' in C.keys():
        char_symbols, char_phonemes = make_symbols(**C.characters)
    else:
        char_symbols, char_phonemes = symbols, phonemes
    num_chars = len(char_phonemes) if C.use_phonemes else len(char_symbols)

    model = setup_model(num_chars, num_speakers=0, c=C)
    model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
    model.eval()
    reduction_factor = model.decoder.r

    # The dataset formatter is looked up by name in the preprocess module.
    preprocess_module = importlib.import_module('TTS.tts.datasets.preprocess')
    preprocessor = getattr(preprocess_module, args.dataset)
    meta_data = preprocessor(args.data_path, args.dataset_metafile)

    dataset = MyDataset(
        reduction_factor,
        C.text_cleaner,
        compute_linear_spec=False,
        ap=ap,
        meta_data=meta_data,
        tp=C.characters if 'characters' in C.keys() else None,
        add_blank=C['add_blank'] if 'add_blank' in C.keys() else False,
        use_phonemes=C.use_phonemes,
        phoneme_cache_path=C.phoneme_cache_path,
        phoneme_language=C.phoneme_language,
        enable_eos_bos=C.enable_eos_bos_chars)
    dataset.sort_items()

    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=4,
        collate_fn=dataset.collate_fn,
        shuffle=False,
        drop_last=False)

    file_paths = []
    with torch.no_grad():
        for data in tqdm(loader):
            # Positions follow the batch layout produced by
            # `dataset.collate_fn` (defined outside this file).
            text_input = data[0]
            text_lengths = data[1]
            mel_input = data[4]
            mel_lengths = data[5]
            item_idxs = data[7]

            if args.use_cuda:
                text_input = text_input.cuda()
                text_lengths = text_lengths.cuda()
                mel_input = mel_input.cuda()
                mel_lengths = mel_lengths.cuda()

            # Only the alignments are needed; other outputs are discarded.
            _, _, alignments, _ = model.forward(
                text_input, text_lengths, mel_input)
            alignments = alignments.detach()

            for idx, alignment in enumerate(alignments):
                item_idx = item_idxs[idx]

                # Nearest-neighbour upsampling by the reduction factor along
                # the first axis (the one later cut to `mel_lengths`).
                # `interpolate` scales trailing dims, hence the transposes.
                alignment = torch.nn.functional.interpolate(
                    alignment.transpose(0, 1).unsqueeze(0),
                    size=None,
                    scale_factor=reduction_factor,
                    mode='nearest',
                    align_corners=None,
                    recompute_scale_factor=None).squeeze(0).transpose(0, 1)

                # Drop the padded part of the batch entry.
                alignment = alignment[:mel_lengths[idx], :text_lengths[idx]]
                alignment = alignment.cpu().numpy()

                file_path = attention_mask_path(item_idx)
                file_paths.append([item_idx, file_path])
                np.save(file_path, alignment)

    # Write the "wav_path|mask_path" listing for later use.
    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
    with open(metafile, "w") as f:
        f.writelines(f"{wav}|{mask}\n" for wav, mask in file_paths)
    print(f" >> Metafile created: {metafile}")


if __name__ == '__main__':
    main()
|
|