File size: 882 Bytes
2493d72 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import numpy as np
import torch
from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator
from TTS.vocoder.models.melgan_multiscale_discriminator import MelganMultiscaleDiscriminator
def test_melgan_discriminator():
    """Check the output shape of ``MelganDiscriminator`` on a random batch.

    A batch of 4 single-channel signals of length ``256 * 10`` is fed through
    a default-constructed discriminator. The first returned value is expected
    to have shape ``(4, 1, 10)``. The second returned value is ignored here.
    """
    model = MelganDiscriminator()
    print(model)  # exercises the module's repr
    dummy_input = torch.rand((4, 1, 256 * 10))
    output, _ = model(dummy_input)
    # Compare shapes as plain tuples: ``np.all`` over the single bool produced
    # by ``torch.Size == tuple`` added nothing and hid the actual shape when
    # the check failed.
    expected = (4, 1, 10)
    assert tuple(output.shape) == expected, f"got {tuple(output.shape)}, expected {expected}"
def test_melgan_multi_scale_discriminator():
    """Check scores and feature-map shapes of ``MelganMultiscaleDiscriminator``.

    Feeds a random ``(4, 1, 256 * 16)`` batch through a default-constructed
    model. It asserts three things:

    * exactly three score tensors are returned;
    * there is one feature entry per score;
    * the first score tensor and the first three feature maps of the first
      feature entry have the expected shapes.
    """
    model = MelganMultiscaleDiscriminator()
    print(model)  # exercises the module's repr
    dummy_input = torch.rand((4, 1, 256 * 16))
    scores, feats = model(dummy_input)
    assert len(scores) == 3, f"expected 3 score tensors, got {len(scores)}"
    assert len(scores) == len(feats), f"{len(scores)} scores vs {len(feats)} feature lists"

    # Compare shapes as plain tuples rather than wrapping a single bool in
    # ``np.all``; the message shows the actual shape on failure.
    expected_shapes = [
        (scores[0], (4, 1, 64)),
        (feats[0][0], (4, 16, 4096)),
        (feats[0][1], (4, 64, 1024)),
        (feats[0][2], (4, 256, 256)),
    ]
    for tensor, expected in expected_shapes:
        actual = tuple(tensor.shape)
        assert actual == expected, f"got {actual}, expected {expected}"
|