Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,456 Bytes
b9d6819 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import json
def create_model_from_config(model_config):
    """Build the model described by ``model_config``.

    The ``model_type`` entry selects which factory is used. The chosen
    factory receives the complete config and its result is returned
    unchanged. Factory modules are imported lazily, so only the module
    backing the requested model type is loaded.

    Args:
        model_config: Config mapping; must provide a ``model_type`` entry.

    Returns:
        The object produced by the factory for the selected model type.

    Raises:
        AssertionError: If ``model_type`` is missing or ``None``.
        NotImplementedError: If ``model_type`` is not one of the handled values.
    """
    model_type = model_config.get('model_type', None)
    assert model_type is not None, 'model_type must be specified in model config'

    # Pick the factory first, then call it once at the bottom.
    if model_type == 'autoencoder':
        from .autoencoders import create_autoencoder_from_config as factory
    elif model_type == 'diffusion_uncond':
        from .diffusion import create_diffusion_uncond_from_config as factory
    elif model_type in ('diffusion_cond', 'diffusion_cond_inpaint', "diffusion_prior"):
        # Every conditional diffusion variant is built by the same factory.
        from .diffusion import create_diffusion_cond_from_config as factory
    elif model_type == 'diffusion_autoencoder':
        from .autoencoders import create_diffAE_from_config as factory
    elif model_type == 'lm':
        from .lm import create_audio_lm_from_config as factory
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')

    return factory(model_config)
def create_model_from_config_path(model_config_path):
    """Load a JSON model config from disk and build the model it describes.

    Args:
        model_config_path: Path of the JSON file holding the model config.

    Returns:
        Whatever ``create_model_from_config`` returns for the parsed config.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON is UTF-8 by specification; pin the encoding so parsing does not
    # depend on the platform's locale default (e.g. cp1252 on Windows).
    with open(model_config_path, encoding='utf-8') as f:
        model_config = json.load(f)

    # The file handle is already closed by the time the model is built.
    return create_model_from_config(model_config)
def create_pretransform_from_config(pretransform_config, sample_rate):
    """Instantiate the pretransform selected by ``pretransform_config['type']``.

    Handled types are ``autoencoder``, ``wavelet``, ``pqmf``,
    ``dac_pretrained`` and ``audiocraft_pretrained``. Each one reads its
    settings from ``pretransform_config['config']``. The implementing class
    is imported lazily inside its branch.

    After construction the pretransform gets an ``enable_grad`` attribute
    (config key ``enable_grad``, default ``False``), has ``.eval()`` called on
    it, and has ``requires_grad_`` called with that same flag.

    Args:
        pretransform_config: Config mapping with a ``type`` entry plus the
            type-specific settings.
        sample_rate: Only used by the ``autoencoder`` type, where it is
            forwarded to the autoencoder factory.

    Returns:
        The configured pretransform object.

    Raises:
        AssertionError: If ``type`` is missing or ``None``.
        NotImplementedError: If ``type`` is not one of the handled values.
    """
    pretransform_type = pretransform_config.get('type', None)
    assert pretransform_type is not None, 'type must be specified in pretransform config'

    if pretransform_type == 'autoencoder':
        from .autoencoders import create_autoencoder_from_config
        from .pretransforms import AutoencoderPretransform

        # The autoencoder factory takes a top-level style config, so wrap the
        # nested one in a synthetic config that carries the sample rate. This
        # saves users from repeating the sample rate inside the pretransform.
        autoencoder = create_autoencoder_from_config(
            {"sample_rate": sample_rate, "model": pretransform_config["config"]}
        )

        pretransform = AutoencoderPretransform(
            autoencoder,
            scale=pretransform_config.get("scale", 1.0),
            model_half=pretransform_config.get("model_half", False),
            iterate_batch=pretransform_config.get("iterate_batch", False),
            chunked=pretransform_config.get("chunked", False),
        )
    elif pretransform_type == 'wavelet':
        from .pretransforms import WaveletPretransform

        settings = pretransform_config["config"]
        pretransform = WaveletPretransform(
            settings["channels"],
            settings["levels"],
            settings["wavelet"],
        )
    elif pretransform_type == 'pqmf':
        from .pretransforms import PQMFPretransform

        pretransform = PQMFPretransform(**pretransform_config["config"])
    elif pretransform_type == 'dac_pretrained':
        from .pretransforms import PretrainedDACPretransform

        pretransform = PretrainedDACPretransform(**pretransform_config["config"])
    elif pretransform_type == "audiocraft_pretrained":
        from .pretransforms import AudiocraftCompressionPretransform

        pretransform = AudiocraftCompressionPretransform(**pretransform_config["config"])
    else:
        raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')

    # Record the gradient flag on the object, then apply it after switching
    # the pretransform to eval mode.
    pretransform.enable_grad = pretransform_config.get('enable_grad', False)
    pretransform.eval().requires_grad_(pretransform.enable_grad)

    return pretransform
def _default_quantizer_params():
    """Return a fresh dict of the default settings shared by the RVQ bottlenecks."""
    return {
        "dim": 128,
        "codebook_size": 1024,
        "num_quantizers": 8,
        "decay": 0.99,
        "kmeans_init": True,
        "kmeans_iters": 50,
        "threshold_ema_dead_code": 2,
    }


def create_bottleneck_from_config(bottleneck_config):
    """Instantiate the bottleneck selected by ``bottleneck_config['type']``.

    Handled types are ``tanh``, ``vae``, ``rvq``, ``dac_rvq``, ``rvq_vae``,
    ``dac_rvq_vae``, ``l2_norm``, ``wasserstein`` and ``fsq``. The
    implementing class is imported lazily inside its branch.

    * ``tanh``, ``vae`` and ``l2_norm`` are constructed without arguments.
    * ``rvq`` and ``rvq_vae`` start from the default quantizer settings and
      overlay ``bottleneck_config['config']`` on top of them.
    * ``dac_rvq``, ``dac_rvq_vae`` and ``fsq`` pass
      ``bottleneck_config['config']`` straight through as keyword arguments.
    * ``wasserstein`` does the same but treats ``config`` as optional.

    When the config sets ``requires_grad`` to a falsy value (default is
    ``True``), every parameter of the bottleneck has ``requires_grad``
    switched off.

    Args:
        bottleneck_config: Config mapping with a ``type`` entry plus the
            type-specific settings.

    Returns:
        The configured bottleneck object.

    Raises:
        AssertionError: If ``type`` is missing or ``None``.
        NotImplementedError: If ``type`` is not one of the handled values.
    """
    bottleneck_type = bottleneck_config.get('type', None)
    assert bottleneck_type is not None, 'type must be specified in bottleneck config'

    if bottleneck_type == 'tanh':
        from .bottleneck import TanhBottleneck

        bottleneck = TanhBottleneck()
    elif bottleneck_type == 'vae':
        from .bottleneck import VAEBottleneck

        bottleneck = VAEBottleneck()
    elif bottleneck_type == 'rvq':
        from .bottleneck import RVQBottleneck

        # User-supplied settings override the shared defaults.
        rvq_kwargs = _default_quantizer_params()
        rvq_kwargs.update(bottleneck_config["config"])
        bottleneck = RVQBottleneck(**rvq_kwargs)
    elif bottleneck_type == "dac_rvq":
        from .bottleneck import DACRVQBottleneck

        bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == 'rvq_vae':
        from .bottleneck import RVQVAEBottleneck

        # Same defaults and override rule as the plain 'rvq' bottleneck.
        rvq_vae_kwargs = _default_quantizer_params()
        rvq_vae_kwargs.update(bottleneck_config["config"])
        bottleneck = RVQVAEBottleneck(**rvq_vae_kwargs)
    elif bottleneck_type == 'dac_rvq_vae':
        from .bottleneck import DACRVQVAEBottleneck

        bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == 'l2_norm':
        from .bottleneck import L2Bottleneck

        bottleneck = L2Bottleneck()
    elif bottleneck_type == "wasserstein":
        from .bottleneck import WassersteinBottleneck

        # 'config' is optional for this type only.
        bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
    elif bottleneck_type == "fsq":
        from .bottleneck import FSQBottleneck

        bottleneck = FSQBottleneck(**bottleneck_config["config"])
    else:
        raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')

    # Turn gradients off on every parameter when the config asks for it.
    if not bottleneck_config.get('requires_grad', True):
        for parameter in bottleneck.parameters():
            parameter.requires_grad = False

    return bottleneck
|