Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
All the functions to build the relevant models and modules | |
from the Hydra config. | |
""" | |
import typing as tp | |
import omegaconf | |
import torch | |
import audiocraft | |
from .. import quantization as qt | |
from ..modules.codebooks_patterns import (CoarseFirstPattern, | |
CodebooksPatternProvider, | |
DelayedPatternProvider, | |
MusicLMPattern, | |
ParallelPatternProvider, | |
UnrolledPatternProvider) | |
from ..modules.conditioners import (BaseConditioner, ChromaStemConditioner, | |
CLAPEmbeddingConditioner, ConditionFuser, | |
ConditioningProvider, LUTConditioner, | |
T5Conditioner) | |
from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor | |
from ..utils.utils import dict_from_config | |
from .encodec import (CompressionModel, EncodecModel, | |
InterleaveStereoCompressionModel) | |
from .flow import FlowModel | |
from .lm import LMModel | |
from .lm_magnet import MagnetLMModel | |
from .unet import DiffusionUnet | |
from .watermark import WMModel | |
def get_quantizer( | |
quantizer: str, cfg: omegaconf.DictConfig, dimension: int | |
) -> qt.BaseQuantizer: | |
klass = {"no_quant": qt.DummyQuantizer, "rvq": qt.ResidualVectorQuantizer}[ | |
quantizer | |
] | |
kwargs = dict_from_config(getattr(cfg, quantizer)) | |
if quantizer != "no_quant": | |
kwargs["dimension"] = dimension | |
return klass(**kwargs) | |
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): | |
if encoder_name == "seanet": | |
kwargs = dict_from_config(getattr(cfg, "seanet")) | |
encoder_override_kwargs = kwargs.pop("encoder") | |
decoder_override_kwargs = kwargs.pop("decoder") | |
encoder_kwargs = {**kwargs, **encoder_override_kwargs} | |
decoder_kwargs = {**kwargs, **decoder_override_kwargs} | |
encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs) | |
decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs) | |
return encoder, decoder | |
else: | |
raise KeyError(f"Unexpected compression model {cfg.compression_model}") | |
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: | |
"""Instantiate a compression model.""" | |
if cfg.compression_model == "encodec": | |
kwargs = dict_from_config(getattr(cfg, "encodec")) | |
encoder_name = kwargs.pop("autoencoder") | |
quantizer_name = kwargs.pop("quantizer") | |
encoder, decoder = get_encodec_autoencoder(encoder_name, cfg) | |
quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) | |
frame_rate = kwargs["sample_rate"] // encoder.hop_length | |
renormalize = kwargs.pop("renormalize", False) | |
# deprecated params | |
kwargs.pop("renorm", None) | |
return EncodecModel( | |
encoder, | |
decoder, | |
quantizer, | |
frame_rate=frame_rate, | |
renormalize=renormalize, | |
**kwargs, | |
).to(cfg.device) | |
else: | |
raise KeyError(f"Unexpected compression model {cfg.compression_model}") | |
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: | |
"""Instantiate a transformer LM.""" | |
if cfg.lm_model in ["transformer_lm", "transformer_lm_magnet"]: | |
kwargs = dict_from_config(getattr(cfg, "transformer_lm")) | |
n_q = kwargs["n_q"] | |
q_modeling = kwargs.pop("q_modeling", None) | |
codebooks_pattern_cfg = getattr(cfg, "codebooks_pattern") | |
attribute_dropout = dict_from_config(getattr(cfg, "attribute_dropout")) | |
cls_free_guidance = dict_from_config(getattr(cfg, "classifier_free_guidance")) | |
cfg_prob, cfg_coef = ( | |
cls_free_guidance["training_dropout"], | |
cls_free_guidance["inference_coef"], | |
) | |
fuser = get_condition_fuser(cfg) | |
condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device) | |
if len(fuser.fuse2cond["cross"]) > 0: # enforce cross-att programmatically | |
kwargs["cross_attention"] = True | |
if codebooks_pattern_cfg.modeling is None: | |
assert ( | |
q_modeling is not None | |
), "LM model should either have a codebook pattern defined or transformer_lm.q_modeling" | |
codebooks_pattern_cfg = omegaconf.OmegaConf.create( | |
{"modeling": q_modeling, "delay": {"delays": list(range(n_q))}} | |
) | |
pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg) | |
lm_class = MagnetLMModel if cfg.lm_model == "transformer_lm_magnet" else LMModel | |
return lm_class( | |
pattern_provider=pattern_provider, | |
condition_provider=condition_provider, | |
fuser=fuser, | |
cfg_dropout=cfg_prob, | |
cfg_coef=cfg_coef, | |
attribute_dropout=attribute_dropout, | |
dtype=getattr(torch, cfg.dtype), | |
device=cfg.device, | |
**kwargs, | |
).to(cfg.device) | |
else: | |
raise KeyError(f"Unexpected LM model {cfg.lm_model}") | |
def get_dit_model(cfg: omegaconf.DictConfig) -> FlowModel: | |
"""Instantiate a DiT""" | |
kwargs = dict_from_config(cfg.transformer_lm) | |
mask_cross_attention = kwargs.get("mask_cross_attention", False) | |
fuser = get_condition_fuser( | |
cfg, | |
).to(cfg.device) | |
condition_provider = get_conditioner_provider( | |
kwargs["dim"], | |
cfg, | |
).to(cfg.device) | |
kwargs["cross_attention"] = ( | |
True if len(fuser.fuse2cond["cross"]) > 0 else False | |
) # cross-att is dependent on fuser type | |
if not kwargs["cross_attention"] and mask_cross_attention: | |
kwargs["mask_cross_attention"] = False | |
fuser.mask_cross_attention = False | |
flow_model = FlowModel( | |
condition_provider, | |
fuser, | |
device=cfg.device, | |
**kwargs, | |
) | |
return flow_model | |
def get_conditioner_provider( | |
output_dim: int, cfg: omegaconf.DictConfig | |
) -> ConditioningProvider: | |
"""Instantiate a conditioning model.""" | |
device = cfg.device | |
duration = cfg.dataset.segment_duration | |
cfg = getattr(cfg, "conditioners") | |
dict_cfg = {} if cfg is None else dict_from_config(cfg) | |
conditioners: tp.Dict[str, BaseConditioner] = {} | |
condition_provider_args = dict_cfg.pop("args", {}) | |
condition_provider_args.pop("merge_text_conditions_p", None) | |
condition_provider_args.pop("drop_desc_p", None) | |
for cond, cond_cfg in dict_cfg.items(): | |
model_type = cond_cfg["model"] | |
model_args = cond_cfg[model_type] | |
if model_type == "t5": | |
conditioners[str(cond)] = T5Conditioner( | |
output_dim=output_dim, device=device, **model_args | |
) | |
elif model_type == "lut": | |
conditioners[str(cond)] = LUTConditioner( | |
output_dim=output_dim, **model_args | |
) | |
elif model_type == "chroma_stem": | |
conditioners[str(cond)] = ChromaStemConditioner( | |
output_dim=output_dim, duration=duration, device=device, **model_args | |
) | |
elif model_type == "clap": | |
conditioners[str(cond)] = CLAPEmbeddingConditioner( | |
output_dim=output_dim, device=device, **model_args | |
) | |
else: | |
raise ValueError(f"Unrecognized conditioning model: {model_type}") | |
conditioner = ConditioningProvider( | |
conditioners, device=device, **condition_provider_args | |
) | |
return conditioner | |
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: | |
"""Instantiate a condition fuser object.""" | |
fuser_cfg = getattr(cfg, "fuser") | |
fuser_methods = ["sum", "cross", "prepend", "input_interpolate"] | |
fuse2cond = {k: fuser_cfg[k] for k in fuser_methods} | |
kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods} | |
fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs) | |
return fuser | |
def get_codebooks_pattern_provider( | |
n_q: int, cfg: omegaconf.DictConfig | |
) -> CodebooksPatternProvider: | |
"""Instantiate a codebooks pattern provider object.""" | |
pattern_providers = { | |
"parallel": ParallelPatternProvider, | |
"delay": DelayedPatternProvider, | |
"unroll": UnrolledPatternProvider, | |
"coarse_first": CoarseFirstPattern, | |
"musiclm": MusicLMPattern, | |
} | |
name = cfg.modeling | |
kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {} | |
klass = pattern_providers[name] | |
return klass(n_q, **kwargs) | |
def get_debug_compression_model(device="cpu", sample_rate: int = 32000): | |
"""Instantiate a debug compression model to be used for unit tests.""" | |
assert sample_rate in [ | |
16000, | |
32000, | |
], "unsupported sample rate for debug compression model" | |
model_ratios = { | |
16000: [10, 8, 8], # 25 Hz at 16kHz | |
32000: [10, 8, 16], # 25 Hz at 32kHz | |
} | |
ratios: tp.List[int] = model_ratios[sample_rate] | |
frame_rate = 25 | |
seanet_kwargs: dict = { | |
"n_filters": 4, | |
"n_residual_layers": 1, | |
"dimension": 32, | |
"ratios": ratios, | |
} | |
encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs) | |
decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs) | |
quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4) | |
init_x = torch.randn(8, 32, 128) | |
quantizer(init_x, 1) # initialize kmeans etc. | |
compression_model = EncodecModel( | |
encoder, | |
decoder, | |
quantizer, | |
frame_rate=frame_rate, | |
sample_rate=sample_rate, | |
channels=1, | |
).to(device) | |
return compression_model.eval() | |
def get_diffusion_model(cfg: omegaconf.DictConfig): | |
# TODO Find a way to infer the channels from dset | |
channels = cfg.channels | |
num_steps = cfg.schedule.num_steps | |
return DiffusionUnet(chin=channels, num_steps=num_steps, **cfg.diffusion_unet) | |
def get_processor(cfg, sample_rate: int = 24000): | |
sample_processor = SampleProcessor() | |
if cfg.use: | |
kw = dict(cfg) | |
kw.pop("use") | |
kw.pop("name") | |
if cfg.name == "multi_band_processor": | |
sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw) | |
return sample_processor | |
def get_debug_lm_model(device="cpu"): | |
"""Instantiate a debug LM to be used for unit tests.""" | |
pattern = DelayedPatternProvider(n_q=4) | |
dim = 16 | |
providers = { | |
"description": LUTConditioner( | |
n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace" | |
), | |
} | |
condition_provider = ConditioningProvider(providers) | |
fuser = ConditionFuser( | |
{"cross": ["description"], "prepend": [], "sum": [], "input_interpolate": []} | |
) | |
lm = LMModel( | |
pattern, | |
condition_provider, | |
fuser, | |
n_q=4, | |
card=400, | |
dim=dim, | |
num_heads=4, | |
custom=True, | |
num_layers=2, | |
cross_attention=True, | |
causal=True, | |
) | |
return lm.to(device).eval() | |
def get_wrapped_compression_model( | |
compression_model: CompressionModel, cfg: omegaconf.DictConfig | |
) -> CompressionModel: | |
if hasattr(cfg, "interleave_stereo_codebooks"): | |
if cfg.interleave_stereo_codebooks.use: | |
kwargs = dict_from_config(cfg.interleave_stereo_codebooks) | |
kwargs.pop("use") | |
compression_model = InterleaveStereoCompressionModel( | |
compression_model, **kwargs | |
) | |
if hasattr(cfg, "compression_model_n_q"): | |
if cfg.compression_model_n_q is not None: | |
compression_model.set_num_codebooks(cfg.compression_model_n_q) | |
return compression_model | |
def get_watermark_model(cfg: omegaconf.DictConfig) -> WMModel: | |
"""Build a WMModel based by audioseal. This requires audioseal to be installed""" | |
import audioseal | |
from .watermark import AudioSeal | |
# Builder encoder and decoder directly using audiocraft API to avoid cyclic import | |
assert hasattr( | |
cfg, "seanet" | |
), "Missing required `seanet` parameters in AudioSeal config" | |
encoder, decoder = get_encodec_autoencoder("seanet", cfg) | |
# Build message processor | |
kwargs = ( | |
dict_from_config(getattr(cfg, "audioseal")) if hasattr(cfg, "audioseal") else {} | |
) | |
nbits = kwargs.get("nbits", 0) | |
hidden_size = getattr(cfg.seanet, "dimension", 128) | |
msg_processor = audioseal.MsgProcessor(nbits, hidden_size=hidden_size) | |
# Build detector using audioseal API | |
def _get_audioseal_detector(): | |
# We don't need encoder and decoder params from seanet, remove them | |
seanet_cfg = dict_from_config(cfg.seanet) | |
seanet_cfg.pop("encoder") | |
seanet_cfg.pop("decoder") | |
detector_cfg = dict_from_config(cfg.detector) | |
typed_seanet_cfg = audioseal.builder.SEANetConfig(**seanet_cfg) | |
typed_detector_cfg = audioseal.builder.DetectorConfig(**detector_cfg) | |
_cfg = audioseal.builder.AudioSealDetectorConfig( | |
nbits=nbits, seanet=typed_seanet_cfg, detector=typed_detector_cfg | |
) | |
return audioseal.builder.create_detector(_cfg) | |
detector = _get_audioseal_detector() | |
generator = audioseal.AudioSealWM( | |
encoder=encoder, decoder=decoder, msg_processor=msg_processor | |
) | |
model = AudioSeal(generator=generator, detector=detector, nbits=nbits) | |
device = torch.device(getattr(cfg, "device", "cpu")) | |
dtype = getattr(torch, getattr(cfg, "dtype", "float32")) | |
return model.to(device=device, dtype=dtype) | |