from transformers import PreTrainedModel
from audio_encoders_pytorch import TanhBottleneck
from audio_diffusion_pytorch import (
    DiffusionMAE1d,
    LinearSchedule,
    UniformDistribution,
    VSampler,
)

from .dmae_config import DMAE1dConfig

# Registry mapping config strings to bottleneck classes.
bottleneck = {'tanh': TanhBottleneck}


class DMAE1d(PreTrainedModel):
    """DiffusionMAE1d from audio_diffusion_pytorch wrapped as a Hugging Face PreTrainedModel."""

    config_class = DMAE1dConfig

    def __init__(self, config: DMAE1dConfig):
        super().__init__(config)
        self.model = DiffusionMAE1d(
            in_channels=config.in_channels,
            channels=config.channels,
            multipliers=config.multipliers,
            factors=config.factors,
            num_blocks=config.num_blocks,
            attentions=config.attentions,
            encoder_inject_depth=config.encoder_inject_depth,
            encoder_channels=config.encoder_channels,
            encoder_factors=config.encoder_factors,
            encoder_num_blocks=config.encoder_num_blocks,
            encoder_multipliers=config.encoder_multipliers,
            # The tanh bottleneck bounds latent values to [-1, 1].
            bottleneck=bottleneck[config.bottleneck](),
            stft_use_complex=config.stft_use_complex,
            stft_num_fft=config.stft_num_fft,
            stft_hop_length=config.stft_hop_length,
            # v-objective diffusion; sigmas are sampled uniformly during training.
            diffusion_type='v',
            diffusion_sigma_distribution=UniformDistribution(),
            resnet_groups=8,
            kernel_multiplier_downsample=2,
            use_nearest_upsample=False,
            use_skip_scale=True,
            use_context_time=True,
            patch_factor=1,
            patch_blocks=1,
        )

    def forward(self, *args, **kwargs):
        # Delegates to DiffusionMAE1d; during training this returns the diffusion loss.
        return self.model(*args, **kwargs)

    def encode(self, *args, **kwargs):
        # Encode audio into the compressed latent.
        return self.model.encode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        # Diffusion decoding; callers can override any of these defaults.
        default_kwargs = dict(
            sigma_schedule=LinearSchedule(),
            sampler=VSampler(),
            clamp=True,
        )
        return self.model.decode(*args, **{**default_kwargs, **kwargs})
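

# --- Usage sketch (illustrative only) ---
# A minimal round trip through the autoencoder. The checkpoint id, audio
# length, and step count below are hypothetical placeholders, not values
# from this file; and because of the relative import above, this module
# must be run as part of its package (e.g. python -m <package>.dmae).
if __name__ == '__main__':
    import torch

    model = DMAE1d.from_pretrained('some-user/dmae1d-checkpoint')  # hypothetical repo id
    model.eval()

    # Random stand-in waveform: [batch, channels, samples], where the
    # channel count must match config.in_channels.
    audio = torch.randn(1, model.config.in_channels, 2**18)

    with torch.no_grad():
        latent = model.encode(audio)                # compressed latent
        recon = model.decode(latent, num_steps=50)  # step count is an arbitrary example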