flavioschneider committed on
Commit 0788506 · 1 Parent(s): c22200e

Upload DMAE1d

Files changed (3)
  1. config.json +6 -1
  2. dmae.py +55 -0
  3. pytorch_model.bin +3 -0
config.json CHANGED
@@ -1,4 +1,7 @@
 {
+  "architectures": [
+    "DMAE1d"
+  ],
   "attentions": [
     0,
     0,
@@ -9,7 +12,8 @@
     0
   ],
   "auto_map": {
-    "AutoConfig": "dmae_config.DMAE1dConfig"
+    "AutoConfig": "dmae_config.DMAE1dConfig",
+    "AutoModel": "dmae.DMAE1d"
   },
   "bottleneck": "tanh",
   "channels": 512,
@@ -73,5 +77,6 @@
   "stft_hop_length": 256,
   "stft_num_fft": 1023,
   "stft_use_complex": true,
+  "torch_dtype": "float32",
   "transformers_version": "4.24.0"
 }
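
The new architectures and auto_map entries are what let the checkpoint be loaded through the Auto classes: AutoConfig resolves to dmae_config.DMAE1dConfig and AutoModel to dmae.DMAE1d inside this repository, so the custom code has to be trusted explicitly. A minimal loading sketch; the repository id is a placeholder, not taken from this commit:

    from transformers import AutoModel

    # auto_map points AutoModel at the DMAE1d class in the repo's dmae.py,
    # so trust_remote_code=True is required to import it.
    model = AutoModel.from_pretrained("<user>/<repo-name>", trust_remote_code=True)
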
dmae.py ADDED
@@ -0,0 +1,55 @@
+from transformers import PreTrainedModel
+from audio_encoders_pytorch import TanhBottleneck
+from audio_diffusion_pytorch import UniformDistribution, LinearSchedule, VSampler, DiffusionMAE1d
+from .dmae_config import DMAE1dConfig
+
+bottleneck = { 'tanh': TanhBottleneck }
+
+class DMAE1d(PreTrainedModel):
+
+    config_class = DMAE1dConfig
+
+    def __init__(self, config: DMAE1dConfig):
+        super().__init__(config)
+
+        self.model = DiffusionMAE1d(
+            in_channels = config.in_channels,
+            channels = config.channels,
+            multipliers = config.multipliers,
+            factors = config.factors,
+            num_blocks = config.num_blocks,
+            attentions = config.attentions,
+            encoder_inject_depth = config.encoder_inject_depth,
+            encoder_channels = config.encoder_channels,
+            encoder_factors = config.encoder_factors,
+            encoder_multipliers = config.encoder_multipliers,
+            encoder_num_blocks = config.encoder_num_blocks,
+            bottleneck = bottleneck[config.bottleneck](),
+            stft_use_complex = config.stft_use_complex,
+            stft_num_fft = config.stft_num_fft,
+            stft_hop_length = config.stft_hop_length,
+            diffusion_type = 'v',
+            diffusion_sigma_distribution = UniformDistribution(),
+            resnet_groups=8,
+            kernel_multiplier_downsample=2,
+            use_nearest_upsample=False,
+            use_skip_scale=True,
+            use_context_time=True,
+            patch_factor=1,
+            patch_blocks=1,
+        )
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def encode(self, *args, **kwargs):
+        return self.model.encode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        default_kwargs = dict(
+            sigma_schedule=LinearSchedule(),
+            sampler=VSampler(),
+            clamp=True
+        )
+        return self.model.decode(*args, **{**default_kwargs, **kwargs})
+
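
DMAE1d is a thin PreTrainedModel wrapper around audio_diffusion_pytorch's DiffusionMAE1d: encode runs the encoder with the tanh bottleneck, and decode samples audio back out with LinearSchedule and VSampler as defaults. A rough usage sketch; the repository id, the stereo input shape, and the num_steps keyword are assumptions based on audio_diffusion_pytorch rather than anything stated in this commit:

    import torch
    from transformers import AutoModel

    # Placeholder repo id; in_channels=2 and num_steps=50 are illustrative guesses.
    model = AutoModel.from_pretrained("<user>/<repo-name>", trust_remote_code=True)

    x = torch.randn(1, 2, 2**18)                # [batch, in_channels, samples]
    latent = model.encode(x)                     # encoder + tanh bottleneck
    audio = model.decode(latent, num_steps=50)   # v-diffusion decoding with the defaults above
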
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3fd060f59068eb7f78b358d929104b3b5d986469364742b6db24b31d72e2c853
+size 937207375
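
Only this Git LFS pointer (spec version, sha256 oid, and size, roughly 937 MB) is committed; the weight file itself is fetched when the repository is cloned with LFS or downloaded through the Hub. A small download sketch using huggingface_hub, with the repository id again a placeholder:

    from huggingface_hub import hf_hub_download

    # Resolves the LFS pointer and downloads the full ~937 MB weight file to the local cache.
    path = hf_hub_download(repo_id="<user>/<repo-name>", filename="pytorch_model.bin")
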