import torch
from torch import Tensor
from transformers import PreTrainedModel
from audio_encoders_pytorch import MelE1d, TanhBottleneck
from audio_diffusion_pytorch import DiffusionAE, UNetV0, LTPlugin, VDiffusion, VSampler

from .config import DMAE1dConfig


class DMAE1d(PreTrainedModel):
    """Diffusion autoencoder for stereo 48 kHz audio, wrapped as a Hugging Face
    PreTrainedModel. A mel-spectrogram encoder compresses the waveform into a
    32-channel latent that conditions a v-diffusion U-Net decoder."""

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        super().__init__(config)

        # Learned-transform wrapper: frames the raw waveform (window length 64,
        # hop 32, 128 filters) before it enters the 1D U-Net.
        UNet = LTPlugin(
            UNetV0,
            num_filters=128,
            window_length=64,
            stride=32,
        )

        self.model = DiffusionAE(
            net_t=UNet,
            dim=1,  # 1D U-Net over audio
            in_channels=2,  # stereo input
            channels=[256, 512, 512, 512, 512],  # channels per U-Net level
            factors=[1, 2, 2, 2, 2],  # downsampling factor per level
            items=[1, 2, 2, 2, 4],  # resnet blocks per level
            inject_depth=4,  # U-Net depth at which the encoder latent is injected
            encoder=MelE1d(  # mel-spectrogram encoder producing the latent
                in_channels=2,
                channels=512,
                multipliers=[1, 1],
                factors=[2],
                num_blocks=[12],
                out_channels=32,  # latent channels
                mel_channels=80,
                mel_sample_rate=48000,
                mel_normalize_log=True,
                bottleneck=TanhBottleneck(),  # bounds the latent to (-1, 1)
            ),
            diffusion_t=VDiffusion,  # v-objective diffusion
            sampler_t=VSampler,
        )
    def forward(self, *args, **kwargs) -> Tensor:
        # Training objective: returns the diffusion loss for the input audio.
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        # Compress audio into the latent produced by the mel encoder.
        return self.model.encode(*args, **kwargs)

    @torch.no_grad()
    def decode(self, *args, **kwargs):
        # Sample audio back from a latent with the diffusion sampler.
        return self.model.decode(*args, **kwargs)
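

# Minimal usage sketch (illustrative, not part of the model code). It assumes
# the DiffusionAE interface documented in audio-diffusion-pytorch, where calling
# the model returns the training loss, encode() returns the latent, and decode()
# takes a latent plus a num_steps sampling budget; the config instantiation is
# likewise an assumption:
#
#     model = DMAE1d(DMAE1dConfig())
#     audio = torch.randn(1, 2, 2**18)            # [batch, stereo, samples], ~5.5 s at 48 kHz
#     loss = model(audio)                         # v-diffusion training loss
#     latent = model.encode(audio)                # compressed 32-channel latent
#     recon = model.decode(latent, num_steps=10)  # decoded waveform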