File size: 3,538 Bytes
2493d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Search a good noise schedule for WaveGrad for a given number of inferece iterations"""
import argparse
from itertools import product as cartesian_product

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.utils.generic_utils import setup_generator

def _str_to_bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``0`` to a bool.

    ``argparse`` with ``type=bool`` would call ``bool('False')``, which is ``True``
    for every non-empty string, so an explicit "false" could never disable a flag.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


# Command-line interface. Arguments without which the script cannot run are
# marked required so argparse reports them up front instead of the script
# failing later inside config/data loading.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--model_path', type=str, help='Path to model checkpoint.')
parser.add_argument('--config_path', type=str, required=True, help='Path to model config file.')
parser.add_argument('--data_path', type=str, required=True, help='Path to data directory.')
parser.add_argument('--output_path', type=str, required=True, help='path for output file including file name and extension.')
parser.add_argument('--num_iter', type=int, required=True, help='Number of model inference iterations that you like to optimize noise schedule for.')
# Accepts "--use_cuda", "--use_cuda True" and "--use_cuda False"; off by default.
parser.add_argument('--use_cuda', type=_str_to_bool, nargs='?', const=True, default=False, help='enable/disable CUDA.')
parser.add_argument('--num_samples', type=int, default=1, help='Number of datasamples used for inference.')
parser.add_argument('--search_depth', type=int, default=3, help='Search granularity. Increasing this increases the run-time exponentially.')

# Parse the command line and load the config file it points to.
args = parser.parse_args()
config = load_config(args.config_path)

# Audio processor configured from the `audio` section of the config.
ap = AudioProcessor(**config.audio)

# Load wav items with an eval split of size 0 and keep only the first
# `--num_samples` of the returned training items.
_, train_data = load_wav_data(args.data_path, 0)
train_data = train_data[:args.num_samples]

# Whole-clip dataset: no segmenting, no noise augmentation, no feature cache.
dataset_options = {
    'ap': ap,
    'items': train_data,
    'seq_len': -1,
    'hop_len': ap.hop_length,
    'pad_short': config.pad_short,
    'conv_pad': config.conv_pad,
    'is_training': True,
    'return_segments': False,
    'use_noise_augment': False,
    'use_cache': False,
    'verbose': True,
}
dataset = WaveGradDataset(**dataset_options)

# Iterate clips one at a time, in a fixed order, collated with the dataset's
# full-clip collate function.
loader = DataLoader(dataset,
                    batch_size=1,
                    shuffle=False,
                    drop_last=False,
                    pin_memory=False,
                    num_workers=config.num_loader_workers,
                    collate_fn=dataset.collate_full_clips)

# setup the model and restore its trained weights
model = setup_generator(config)
if args.model_path:
    # NOTE(review): assumes the checkpoint is a dict holding the generator
    # state_dict under the 'model' key (the layout TTS vocoder trainers save)
    # — confirm. Loading onto CPU first lets GPU-saved checkpoints load on
    # CPU-only machines.
    checkpoint = torch.load(args.model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    del checkpoint
else:
    print(" > WARNING: no --model_path given; searching with an untrained model.")
if args.use_cuda:
    model.cuda()
# Only inference is run below, so disable training-mode behaviour.
model.eval()

# setup optimization parameters
# `search_depth` random multipliers in [0, 10), sorted for a readable log;
# every inference step independently picks one of them.
base_values = sorted(10 * np.random.uniform(size=args.search_depth))
print(base_values)
# Log-spaced template from 1e-6 to 1e-1 with one value per inference step; a
# candidate beta schedule is this template scaled element-wise by the picks.
exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
best_error = float('inf')
best_schedule = None
total_search_iter = len(base_values)**args.num_iter
for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
    beta = exponents * base
    model.compute_noise_level(beta)

    # Score the candidate on every sample before comparing it with the best so
    # far, so the chosen schedule is the best over the whole set rather than
    # the winner on whichever single clip scored lowest.
    squared_error = 0.0
    num_elements = 0
    for mel, _ in loader:
        with torch.no_grad():
            y_hat = model.inference(mel.cuda() if args.use_cuda else mel)
        y_hat = y_hat.detach().cpu().numpy()

        # Recompute mels from the generated waveforms (channel 0 of each item);
        # the last frame is dropped, presumably to match the reference mel
        # length — confirm against the AudioProcessor framing.
        mel_hat = torch.stack([torch.from_numpy(ap.melspectrogram(wav[0])[:, :-1]) for wav in y_hat])
        diff = mel - mel_hat
        squared_error += torch.sum(diff ** 2).item()
        num_elements += diff.numel()

    if num_elements == 0:
        raise RuntimeError(f" [!] No data samples to evaluate in {args.data_path}.")
    # Mean squared error over all mel bins of all evaluated clips.
    mse = squared_error / num_elements
    if mse < best_error:
        best_error = mse
        best_schedule = {'beta': beta}
        print(f" > Found a better schedule. - MSE: {mse}")
        # Save on every improvement so an interrupted search keeps its best result.
        np.save(args.output_path, best_schedule)