import torch
from torch import Tensor
from transformers import PreTrainedModel
from audio_encoders_pytorch import MelE1d, TanhBottleneck
from audio_diffusion_pytorch import DiffusionAE, UNetV0, LTPlugin, VDiffusion, VSampler

from .config import DMAE1dConfig


class DMAE1d(PreTrainedModel):

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        super().__init__(config)

        # Wrap the UNet with the learned-transform plugin so it operates on
        # overlapping learned patches (64-sample windows, 32-sample stride).
        UNet = LTPlugin(
            UNetV0,
            num_filters=128,
            window_length=64,
            stride=32,
        )

        self.model = DiffusionAE(
            net_t=UNet,
            dim=1,
            in_channels=2,  # stereo audio
            channels=[256, 512, 512, 512, 512],
            factors=[1, 2, 2, 2, 2],
            items=[1, 2, 2, 2, 4],
            inject_depth=4,  # UNet depth at which the encoder latent is injected
            encoder=MelE1d(  # mel-spectrogram encoder producing the latent
                in_channels=2,
                channels=512,
                multipliers=[1, 1],
                factors=[2],
                num_blocks=[12],
                out_channels=32,
                mel_channels=80,
                mel_sample_rate=48000,
                mel_normalize_log=True,
                bottleneck=TanhBottleneck(),
            ),
            diffusion_t=VDiffusion,  # v-objective diffusion
            sampler_t=VSampler,
        )

    def forward(self, *args, **kwargs) -> Tensor:
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        return self.model.encode(*args, **kwargs)

    @torch.no_grad()
    def decode(self, *args, **kwargs):
        return self.model.decode(*args, **kwargs)
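
# --- Usage sketch (illustrative, not part of the checkpoint) ---
# A minimal round-trip, assuming the DiffusionAE API from
# audio_diffusion_pytorch: calling the model returns the diffusion
# training loss, encode() returns the compressed latent, and decode()
# runs the sampler for `num_steps`. Shapes here are assumptions: the
# model expects stereo 48 kHz audio with a power-of-two sample count.
if __name__ == "__main__":
    model = DMAE1d(DMAE1dConfig())
    audio = torch.randn(1, 2, 2**18)  # [batch, channels, samples]
    loss = model(audio)  # v-diffusion training loss
    latent = model.encode(audio)  # 32-channel compressed latent
    recon = model.decode(latent, num_steps=50)  # sampled reconstruction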