Spaces:
Runtime error
Runtime error
Commit
·
06d43f8
1
Parent(s):
ad010fe
Create parse_train.py
Browse files- MusicModel/parse/parse_train.py +273 -0
MusicModel/parse/parse_train.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from typing import Any
|
3 |
+
import tensorflow as tf
|
4 |
+
|
5 |
+
|
6 |
+
class EasyDict(dict):
    """Dictionary whose entries can also be accessed with attribute syntax.

    ``d.key`` reads ``d["key"]``, ``d.key = v`` writes ``d["key"]`` and
    ``del d.key`` removes ``d["key"]``.
    """

    def __getattr__(self, name: str) -> Any:
        """Return the entry stored under ``name``.

        Raises:
            AttributeError: if ``name`` is not a key of the dictionary.
        """
        try:
            value = self[name]
        except KeyError:
            # Report a missing key the way a missing attribute is reported,
            # so hasattr()/getattr(..., default) behave as expected.
            raise AttributeError(name)
        return value

    def __setattr__(self, name: str, value: Any) -> None:
        """Store ``value`` as the dictionary entry ``name``."""
        self[name] = value

    def __delattr__(self, name: str) -> None:
        """Remove the dictionary entry ``name``.

        Raises:
            KeyError: if ``name`` is not a key (the lookup error is not
                translated to AttributeError here).
        """
        del self[name]
|
18 |
+
|
19 |
+
|
20 |
+
def str2bool(v):
    """Convert a command-line style truth value to ``bool``.

    Intended as an ``argparse`` ``type=`` callable.

    Args:
        v: Either a ``bool`` (returned unchanged) or a string. Matching is
            case-insensitive and ignores surrounding whitespace.

    Returns:
        ``True`` for "yes", "true", "t", "y", "1";
        ``False`` for "no", "false", "f", "n", "0".

    Raises:
        argparse.ArgumentTypeError: if ``v`` is neither a bool nor one of
            the recognised strings.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        # Anything else has no .lower(); fail with the same error type as an
        # unrecognised string instead of leaking an AttributeError.
        raise argparse.ArgumentTypeError("Boolean value expected.")

    # Normalise once instead of lower-casing for every comparison.
    text = v.strip().lower()
    if text in ("yes", "true", "t", "y", "1"):
        return True
    if text in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
|
29 |
+
|
30 |
+
|
31 |
+
def _build_train_parser():
    """Return an ArgumentParser declaring every training option and its default."""
    parser = argparse.ArgumentParser()

    # One row per option: (flag, type, default, help). Keeping the options in
    # a single table means a new flag only has to be declared once.
    options = (
        ("--hop", int, 256, "Hop size (window size = 4*hop)"),
        ("--mel_bins", int, 256, "Mel bins in mel-spectrograms"),
        ("--sr", int, 44100, "Sampling Rate"),
        (
            "--small",
            str2bool,
            False,
            "If True, use model with shorter available context, useful for small datasets",
        ),
        ("--latdepth", int, 64, "Depth of generated latent vectors"),
        (
            "--coorddepth",
            int,
            64,
            "Dimension of latent coordinate and style random vectors",
        ),
        ("--max_lat_len", int, 512, "Length of .npy arrays used for training"),
        (
            "--base_channels",
            int,
            128,
            "Base channels for generator and discriminator architectures",
        ),
        ("--shape", int, 128, "Length of spectrograms time axis"),
        ("--window", int, 64, "Generator spectrogram window (must divide shape)"),
        ("--bs", int, 32, "Batch size"),
        ("--lr", float, 0.0001, "Learning Rate"),
        (
            "--gp_max_weight",
            float,
            10.0,
            "Maximum allowed R1 gradient penalty loss weight. The weight will self-adapt if high values are not needed for stable training.",
        ),
        ("--totsamples", int, 300000, "Max samples chosen per epoch"),
        ("--epochs", int, 250, "Number of epochs"),
        ("--save_every", int, 1, "Save after x epochs"),
        ("--mu_rescale", float, -25.0, "Spectrogram mu used to normalize"),
        ("--sigma_rescale", float, 75.0, "Spectrogram sigma used to normalize"),
        ("--save_path", str, "checkpoints", "Path where to save checkpoints"),
        ("--train_path", str, "training_samples", "Path of training samples"),
        ("--dec_path", str, "checkpoints/ae", "Path of pretrained decoders weights"),
        # NOTE: the default is the string "None", not the None object.
        ("--load_path", str, "None", "If not None, load models weights from this path"),
        (
            "--base_path",
            str,
            "checkpoints",
            "Path where pretrained models are downloaded",
        ),
        ("--log_path", str, "logs", "Path where to save tensorboard logs"),
        (
            "--testing",
            str2bool,
            False,
            "True if optimizers weight do not need to be loaded",
        ),
        ("--cpu", str2bool, False, "True if you wish to use cpu"),
        (
            "--mixed_precision",
            str2bool,
            True,
            "True if your GPU supports mixed precision",
        ),
        (
            "--xla",
            str2bool,
            True,
            "True if you wish to improve training speed with XLA",
        ),
        (
            "--share_gradio",
            str2bool,
            False,
            "True if you wish to create a public URL for the Gradio interface",
        ),
    )
    for flag, value_type, default, help_text in options:
        parser.add_argument(flag, type=value_type, default=default, help=help_text)

    return parser


def params_args(args):
    """Fill ``args`` with the training options parsed from the command line.

    The process command line is parsed (argparse's default input), every
    parsed option is copied onto ``args`` as an attribute, and a few derived
    settings are added:

    * ``latlen``   -- 128 when ``--small`` is true, otherwise 256.
    * ``coordlen`` -- ``(latlen // 2) * 3``.
    * ``datatype`` -- ``tf.float16`` when running on GPU with mixed
      precision, otherwise ``tf.float32``.

    When no GPU is found, or ``--cpu`` is true, ``cpu`` is forced to True,
    ``mixed_precision`` to False, and all GPUs are hidden from TensorFlow.

    Args:
        args: Object supporting attribute assignment (``parse_args`` passes
            an ``EasyDict``). It is modified in place.

    Returns:
        The same ``args`` object, populated.
    """
    parsed = _build_train_parser().parse_args()

    # Copy every parsed option in declaration order; this replaces one
    # hand-written assignment per flag, which had to be kept in sync manually.
    for option_name, option_value in vars(parsed).items():
        setattr(args, option_name, option_value)

    args.latlen = 128 if args.small else 256
    args.coordlen = (args.latlen // 2) * 3

    print()

    args.datatype = tf.float32
    gpuls = tf.config.list_physical_devices("GPU")
    if not gpuls or args.cpu:
        # CPU run: mixed precision is switched off and GPUs are hidden.
        args.cpu = True
        args.mixed_precision = False
        tf.config.set_visible_devices([], "GPU")
        print()
        print("Using CPU...")
        print()
    elif args.mixed_precision:
        args.datatype = tf.float16
        print()
        print("Using GPU with mixed precision enabled...")
        print()
    else:
        print()
        print("Using GPU without mixed precision...")
        print()

    return args
|
269 |
+
|
270 |
+
|
271 |
+
def parse_args():
    """Build a fresh ``EasyDict`` and return it populated by ``params_args``."""
    return params_args(EasyDict())
|