SLAYEROFALL3050 commited on
Commit
06d43f8
·
1 Parent(s): ad010fe

Create parse_train.py

Browse files
Files changed (1) hide show
  1. MusicModel/parse/parse_train.py +273 -0
MusicModel/parse/parse_train.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from typing import Any
3
+ import tensorflow as tf
4
+
5
+
6
+ class EasyDict(dict):
7
+ def __getattr__(self, name: str) -> Any:
8
+ try:
9
+ return self[name]
10
+ except KeyError:
11
+ raise AttributeError(name)
12
+
13
+ def __setattr__(self, name: str, value: Any) -> None:
14
+ self[name] = value
15
+
16
+ def __delattr__(self, name: str) -> None:
17
+ del self[name]
18
+
19
+
20
+ def str2bool(v):
21
+ if isinstance(v, bool):
22
+ return v
23
+ if v.lower() in ("yes", "true", "t", "y", "1"):
24
+ return True
25
+ elif v.lower() in ("no", "false", "f", "n", "0"):
26
+ return False
27
+ else:
28
+ raise argparse.ArgumentTypeError("Boolean value expected.")
29
+
30
+
31
+ def params_args(args):
32
+ parser = argparse.ArgumentParser()
33
+
34
+ parser.add_argument(
35
+ "--hop",
36
+ type=int,
37
+ default=256,
38
+ help="Hop size (window size = 4*hop)",
39
+ )
40
+ parser.add_argument(
41
+ "--mel_bins",
42
+ type=int,
43
+ default=256,
44
+ help="Mel bins in mel-spectrograms",
45
+ )
46
+ parser.add_argument(
47
+ "--sr",
48
+ type=int,
49
+ default=44100,
50
+ help="Sampling Rate",
51
+ )
52
+ parser.add_argument(
53
+ "--small",
54
+ type=str2bool,
55
+ default=False,
56
+ help="If True, use model with shorter available context, useful for small datasets",
57
+ )
58
+ parser.add_argument(
59
+ "--latdepth",
60
+ type=int,
61
+ default=64,
62
+ help="Depth of generated latent vectors",
63
+ )
64
+ parser.add_argument(
65
+ "--coorddepth",
66
+ type=int,
67
+ default=64,
68
+ help="Dimension of latent coordinate and style random vectors",
69
+ )
70
+ parser.add_argument(
71
+ "--max_lat_len",
72
+ type=int,
73
+ default=512,
74
+ help="Length of .npy arrays used for training",
75
+ )
76
+ parser.add_argument(
77
+ "--base_channels",
78
+ type=int,
79
+ default=128,
80
+ help="Base channels for generator and discriminator architectures",
81
+ )
82
+ parser.add_argument(
83
+ "--shape",
84
+ type=int,
85
+ default=128,
86
+ help="Length of spectrograms time axis",
87
+ )
88
+ parser.add_argument(
89
+ "--window",
90
+ type=int,
91
+ default=64,
92
+ help="Generator spectrogram window (must divide shape)",
93
+ )
94
+ parser.add_argument(
95
+ "--bs",
96
+ type=int,
97
+ default=32,
98
+ help="Batch size",
99
+ )
100
+ parser.add_argument(
101
+ "--lr",
102
+ type=float,
103
+ default=0.0001,
104
+ help="Learning Rate",
105
+ )
106
+ parser.add_argument(
107
+ "--gp_max_weight",
108
+ type=float,
109
+ default=10.0,
110
+ help="Maximum allowed R1 gradient penalty loss weight. The weight will self-adapt if high values are not needed for stable training.",
111
+ )
112
+ parser.add_argument(
113
+ "--totsamples",
114
+ type=int,
115
+ default=300000,
116
+ help="Max samples chosen per epoch",
117
+ )
118
+ parser.add_argument(
119
+ "--epochs",
120
+ type=int,
121
+ default=250,
122
+ help="Number of epochs",
123
+ )
124
+ parser.add_argument(
125
+ "--save_every",
126
+ type=int,
127
+ default=1,
128
+ help="Save after x epochs",
129
+ )
130
+ parser.add_argument(
131
+ "--mu_rescale",
132
+ type=float,
133
+ default=-25.0,
134
+ help="Spectrogram mu used to normalize",
135
+ )
136
+ parser.add_argument(
137
+ "--sigma_rescale",
138
+ type=float,
139
+ default=75.0,
140
+ help="Spectrogram sigma used to normalize",
141
+ )
142
+ parser.add_argument(
143
+ "--save_path",
144
+ type=str,
145
+ default="checkpoints",
146
+ help="Path where to save checkpoints",
147
+ )
148
+ parser.add_argument(
149
+ "--train_path",
150
+ type=str,
151
+ default="training_samples",
152
+ help="Path of training samples",
153
+ )
154
+ parser.add_argument(
155
+ "--dec_path",
156
+ type=str,
157
+ default="checkpoints/ae",
158
+ help="Path of pretrained decoders weights",
159
+ )
160
+ parser.add_argument(
161
+ "--load_path",
162
+ type=str,
163
+ default="None",
164
+ help="If not None, load models weights from this path",
165
+ )
166
+ parser.add_argument(
167
+ "--base_path",
168
+ type=str,
169
+ default="checkpoints",
170
+ help="Path where pretrained models are downloaded",
171
+ )
172
+ parser.add_argument(
173
+ "--log_path",
174
+ type=str,
175
+ default="logs",
176
+ help="Path where to save tensorboard logs",
177
+ )
178
+ parser.add_argument(
179
+ "--testing",
180
+ type=str2bool,
181
+ default=False,
182
+ help="True if optimizers weight do not need to be loaded",
183
+ )
184
+ parser.add_argument(
185
+ "--cpu",
186
+ type=str2bool,
187
+ default=False,
188
+ help="True if you wish to use cpu",
189
+ )
190
+ parser.add_argument(
191
+ "--mixed_precision",
192
+ type=str2bool,
193
+ default=True,
194
+ help="True if your GPU supports mixed precision",
195
+ )
196
+ parser.add_argument(
197
+ "--xla",
198
+ type=str2bool,
199
+ default=True,
200
+ help="True if you wish to improve training speed with XLA",
201
+ )
202
+ parser.add_argument(
203
+ "--share_gradio",
204
+ type=str2bool,
205
+ default=False,
206
+ help="True if you wish to create a public URL for the Gradio interface",
207
+ )
208
+
209
+ tmp_args = parser.parse_args()
210
+
211
+ args.hop = tmp_args.hop
212
+ args.mel_bins = tmp_args.mel_bins
213
+ args.sr = tmp_args.sr
214
+ args.small = tmp_args.small
215
+ args.latdepth = tmp_args.latdepth
216
+ args.coorddepth = tmp_args.coorddepth
217
+ args.max_lat_len = tmp_args.max_lat_len
218
+ args.base_channels = tmp_args.base_channels
219
+ args.shape = tmp_args.shape
220
+ args.window = tmp_args.window
221
+ args.bs = tmp_args.bs
222
+ args.lr = tmp_args.lr
223
+ args.gp_max_weight = tmp_args.gp_max_weight
224
+ args.totsamples = tmp_args.totsamples
225
+ args.epochs = tmp_args.epochs
226
+ args.save_every = tmp_args.save_every
227
+ args.mu_rescale = tmp_args.mu_rescale
228
+ args.sigma_rescale = tmp_args.sigma_rescale
229
+ args.save_path = tmp_args.save_path
230
+ args.train_path = tmp_args.train_path
231
+ args.dec_path = tmp_args.dec_path
232
+ args.load_path = tmp_args.load_path
233
+ args.base_path = tmp_args.base_path
234
+ args.log_path = tmp_args.log_path
235
+ args.testing = tmp_args.testing
236
+ args.cpu = tmp_args.cpu
237
+ args.mixed_precision = tmp_args.mixed_precision
238
+ args.xla = tmp_args.xla
239
+ args.share_gradio = tmp_args.share_gradio
240
+
241
+ if args.small:
242
+ args.latlen = 128
243
+ else:
244
+ args.latlen = 256
245
+ args.coordlen = (args.latlen // 2) * 3
246
+
247
+ print()
248
+
249
+ args.datatype = tf.float32
250
+ gpuls = tf.config.list_physical_devices("GPU")
251
+ if len(gpuls) == 0 or args.cpu:
252
+ args.cpu = True
253
+ args.mixed_precision = False
254
+ tf.config.set_visible_devices([], "GPU")
255
+ print()
256
+ print("Using CPU...")
257
+ print()
258
+ if args.mixed_precision:
259
+ args.datatype = tf.float16
260
+ print()
261
+ print("Using GPU with mixed precision enabled...")
262
+ print()
263
+ if not args.mixed_precision and not args.cpu:
264
+ print()
265
+ print("Using GPU without mixed precision...")
266
+ print()
267
+
268
+ return args
269
+
270
+
271
+ def parse_args():
272
+ args = EasyDict()
273
+ return params_args(args)