# dmae1d-ATC32-v3 / model.py
# Source: Hugging Face repo upload "Upload DMAE1d" by flavioschneider (commit 3d43b81)
import torch
from torch import Tensor
from transformers import PreTrainedModel
from audio_encoders_pytorch import MelE1d, TanhBottleneck
from audio_diffusion_pytorch import DiffusionAE, UNetV0, LTPlugin, VDiffusion, VSampler
from .config import DMAE1dConfig
class DMAE1d(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a ``DiffusionAE``.

    The wrapped model is a 1-D diffusion autoencoder operating on stereo
    (2-channel) audio, using a mel-spectrogram encoder (``MelE1d``) with a
    tanh bottleneck. Architecture hyperparameters are hard-coded here; the
    ``config`` object is only passed through to ``PreTrainedModel``.
    """

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        super().__init__(config)

        # UNetV0 wrapped in the learned-transform plugin (STFT-like framing:
        # window 64, hop 32, 128 filters).
        net_type = LTPlugin(
            UNetV0,
            num_filters=128,
            window_length=64,
            stride=32,
        )

        # Mel-spectrogram encoder producing a 32-channel latent, with a tanh
        # bottleneck bounding latent values.
        mel_encoder = MelE1d(
            in_channels=2,
            channels=512,
            multipliers=[1, 1],
            factors=[2],
            num_blocks=[12],
            out_channels=32,
            mel_channels=80,
            mel_sample_rate=48000,
            mel_normalize_log=True,
            bottleneck=TanhBottleneck(),
        )

        # Latent is injected at depth 4 of the 5-stage UNet (v-diffusion
        # objective and sampler).
        self.model = DiffusionAE(
            net_t=net_type,
            dim=1,
            in_channels=2,
            channels=[256, 512, 512, 512, 512],
            factors=[1, 2, 2, 2, 2],
            items=[1, 2, 2, 2, 4],
            inject_depth=4,
            encoder=mel_encoder,
            diffusion_t=VDiffusion,
            sampler_t=VSampler,
        )

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped ``DiffusionAE`` (training/loss path)."""
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        """Delegate latent encoding to the wrapped ``DiffusionAE``."""
        return self.model.encode(*args, **kwargs)

    @torch.no_grad()
    def decode(self, *args, **kwargs):
        """Delegate latent decoding; inference-only, so gradients are disabled."""
        return self.model.decode(*args, **kwargs)