|
import importlib |
|
import re |
|
|
|
from coqpit import Coqpit |
|
|
|
|
|
def to_camel(text):
    """Convert a snake_case identifier into a CamelCase class name.

    The input is first passed through ``str.capitalize`` (first character
    upper-cased, every other character lower-cased). Each underscore that is
    not at the very start of the string and is directly followed by an ASCII
    letter is then dropped and that letter is upper-cased.

    Args:
        text: Identifier to convert, e.g. ``"hifigan_generator"``.

    Returns:
        The converted name, e.g. ``"HifiganGenerator"``.
    """

    def _upper_letter(match):
        # Group 1 is the letter that follows the underscore being removed.
        return match.group(1).upper()

    capitalized = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", _upper_letter, capitalized)
|
|
|
|
|
def setup_model(config: Coqpit):
    """Load a vocoder model directly from its configuration.

    A config that contains both ``discriminator_model`` and
    ``generator_model`` is built with the ``GAN`` class from
    ``TTS.vocoder.models.gan``. Otherwise the module
    ``TTS.vocoder.models.<config.model lower-cased>`` is imported and the
    model class is looked up inside it.

    Args:
        config: Model configuration. ``config.model`` is read unless both
            GAN keys are present.

    Returns:
        Whatever ``init_from_config(config)`` of the resolved class returns.

    Raises:
        ValueError: If the imported module does not define the expected
            model class.
        ModuleNotFoundError: Propagated from ``importlib.import_module`` when
            the module for ``config.model`` cannot be imported.
    """
    if "discriminator_model" in config and "generator_model" in config:
        module = importlib.import_module("TTS.vocoder.models.gan")
        model_class = getattr(module, "GAN")
    else:
        model_name = config.model.lower()
        module = importlib.import_module("TTS.vocoder.models." + model_name)
        # "GAN" is the only class name that to_camel() cannot produce;
        # "Wavernn" and "Wavegrad" are exactly what to_camel() yields.
        class_name = "GAN" if model_name == "gan" else to_camel(config.model)
        try:
            model_class = getattr(module, class_name)
        except (AttributeError, ModuleNotFoundError) as e:
            # getattr() signals a missing class with AttributeError; catching
            # only ModuleNotFoundError here left this handler unreachable.
            raise ValueError(f"Model {config.model} not exist!") from e
    print(" > Vocoder Model: {}".format(config.model))
    return model_class.init_from_config(config)
|
|
|
|
|
def setup_generator(c):
    """Build the generator network named by ``c.generator_model``.

    TODO: use config object as arguments

    The class is imported from
    ``TTS.vocoder.models.<generator_model lower-cased>`` under the name
    ``to_camel(c.generator_model)`` and instantiated with the constructor
    arguments hard-wired below for that architecture.

    Args:
        c: Config object exposing ``generator_model`` and, depending on the
            architecture, ``generator_model_params`` and ``audio["num_mels"]``.

    Returns:
        The instantiated generator.

    Raises:
        ValueError: If the retired name ``melgan_fb_generator`` is requested.
        NotImplementedError: If the module and class can be imported but no
            construction recipe exists for that name.
    """
    print(" > Generator Model: {}".format(c.generator_model))
    name = c.generator_model.lower()

    # Checked before the import so the rename hint is reported even when no
    # module with the old name can be imported any more.
    if name == "melgan_fb_generator":
        raise ValueError("melgan_fb_generator is now fullband_melgan_generator")

    module = importlib.import_module("TTS.vocoder.models." + name)
    generator_class = getattr(module, to_camel(c.generator_model))

    # The MelGAN variants share one constructor call and differ only in
    # (out_channels, base_channels).
    melgan_variants = {
        "melgan_generator": (1, 512),
        "multiband_melgan_generator": (4, 384),
        "fullband_melgan_generator": (1, 512),
    }

    # Names are compared with ==. The previous `name in "<model>"` test was a
    # substring check, so any fragment (even "") selected a branch.
    if name == "hifigan_generator":
        model = generator_class(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params)
    elif name in melgan_variants:
        out_channels, base_channels = melgan_variants[name]
        model = generator_class(
            in_channels=c.audio["num_mels"],
            out_channels=out_channels,
            proj_kernel=7,
            base_channels=base_channels,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )
    elif name == "parallel_wavegan_generator":
        model = generator_class(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
            stacks=c.generator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            aux_channels=c.audio["num_mels"],
            dropout=0.0,
            bias=True,
            use_weight_norm=True,
            upsample_factors=c.generator_model_params["upsample_factors"],
        )
    elif name == "univnet_generator":
        model = generator_class(**c.generator_model_params)
    else:
        raise NotImplementedError(f"Model {c.generator_model} not implemented!")
    return model
|
|
|
|
|
def setup_discriminator(c):
    """Build the discriminator network named by ``c.discriminator_model``.

    TODO: use config object as arguments

    Any name containing ``parallel_wavegan`` is loaded from
    ``TTS.vocoder.models.parallel_wavegan_discriminator``; every other name
    is loaded from ``TTS.vocoder.models.<name>``. The class name is
    ``to_camel(<name>)``. The name is lower-cased once and that form is used
    for the import, the class lookup and the architecture dispatch.

    Args:
        c: Config object exposing ``discriminator_model`` and, depending on
            the architecture, ``discriminator_model_params`` and ``audio``
            (keys ``"num_mels"`` and ``"hop_length"``).

    Returns:
        The instantiated discriminator.

    Raises:
        NotImplementedError: If the module and class can be imported but no
            construction recipe exists for that name.
    """
    print(" > Discriminator Model: {}".format(c.discriminator_model))
    name = c.discriminator_model.lower()
    if "parallel_wavegan" in name:
        module = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
    else:
        module = importlib.import_module("TTS.vocoder.models." + name)
    discriminator_class = getattr(module, to_camel(name))

    # Names are compared with == in one elif chain. The previous
    # `name in "<model>"` tests were substring checks, and an unmatched name
    # fell through to `return model` with `model` unbound.
    if name == "hifigan_discriminator":
        model = discriminator_class()
    elif name == "random_window_discriminator":
        model = discriminator_class(
            cond_channels=c.audio["num_mels"],
            hop_length=c.audio["hop_length"],
            # The "donwsample" spelling is both the keyword passed and the
            # config key read here; it is kept unchanged on purpose.
            uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"],
            cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"],
            cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"],
            window_sizes=c.discriminator_model_params["window_sizes"],
        )
    elif name == "melgan_multiscale_discriminator":
        model = discriminator_class(
            in_channels=1,
            out_channels=1,
            kernel_sizes=(5, 3),
            base_channels=c.discriminator_model_params["base_channels"],
            max_channels=c.discriminator_model_params["max_channels"],
            downsample_factors=c.discriminator_model_params["downsample_factors"],
        )
    elif name == "residual_parallel_wavegan_discriminator":
        model = discriminator_class(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=c.discriminator_model_params["num_layers"],
            stacks=c.discriminator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            dropout=0.0,
            bias=True,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
        )
    elif name == "parallel_wavegan_discriminator":
        model = discriminator_class(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=c.discriminator_model_params["num_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
            bias=True,
        )
    elif name == "univnet_discriminator":
        model = discriminator_class()
    else:
        raise NotImplementedError(f"Model {c.discriminator_model} not implemented!")
    return model
|
|