update
Browse files- bigvgan.py +34 -20
bigvgan.py
CHANGED
@@ -257,14 +257,18 @@ class BigVGAN(
|
|
257 |
return x
|
258 |
|
259 |
def remove_weight_norm(self):
|
260 |
-
|
261 |
-
|
262 |
-
for
|
263 |
-
|
264 |
-
|
265 |
-
l.
|
266 |
-
|
267 |
-
|
|
|
|
|
|
|
|
|
268 |
|
269 |
##################################################################
|
270 |
# additional methods for huggingface_hub support
|
@@ -304,17 +308,21 @@ class BigVGAN(
|
|
304 |
##################################################################
|
305 |
# download and load hyperparameters (h) used by BigVGAN
|
306 |
##################################################################
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
|
|
|
|
|
|
|
|
318 |
h = load_hparams_from_json(config_file)
|
319 |
|
320 |
##################################################################
|
@@ -347,6 +355,12 @@ class BigVGAN(
|
|
347 |
)
|
348 |
|
349 |
checkpoint_dict = torch.load(model_file, map_location=map_location)
|
350 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
351 |
|
352 |
return model
|
|
|
257 |
return x
|
258 |
|
259 |
def remove_weight_norm(self):
|
260 |
+
try:
|
261 |
+
print('Removing weight norm...')
|
262 |
+
for l in self.ups:
|
263 |
+
for l_i in l:
|
264 |
+
remove_weight_norm(l_i)
|
265 |
+
for l in self.resblocks:
|
266 |
+
l.remove_weight_norm()
|
267 |
+
remove_weight_norm(self.conv_pre)
|
268 |
+
remove_weight_norm(self.conv_post)
|
269 |
+
except ValueError:
|
270 |
+
print('[INFO] Model already removed weight norm. Skipping!')
|
271 |
+
pass
|
272 |
|
273 |
##################################################################
|
274 |
# additional methods for huggingface_hub support
|
|
|
308 |
##################################################################
|
309 |
# download and load hyperparameters (h) used by BigVGAN
|
310 |
##################################################################
|
311 |
+
if os.path.isdir(model_id):
|
312 |
+
print("Loading config.json from local directory")
|
313 |
+
config_file = os.path.join(model_id, 'config.json')
|
314 |
+
else:
|
315 |
+
config_file = hf_hub_download(
|
316 |
+
repo_id=model_id,
|
317 |
+
filename='config.json',
|
318 |
+
revision=revision,
|
319 |
+
cache_dir=cache_dir,
|
320 |
+
force_download=force_download,
|
321 |
+
proxies=proxies,
|
322 |
+
resume_download=resume_download,
|
323 |
+
token=token,
|
324 |
+
local_files_only=local_files_only,
|
325 |
+
)
|
326 |
h = load_hparams_from_json(config_file)
|
327 |
|
328 |
##################################################################
|
|
|
355 |
)
|
356 |
|
357 |
checkpoint_dict = torch.load(model_file, map_location=map_location)
|
358 |
+
|
359 |
+
try:
|
360 |
+
model.load_state_dict(checkpoint_dict['generator'])
|
361 |
+
except RuntimeError:
|
362 |
+
print(f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!")
|
363 |
+
model.remove_weight_norm()
|
364 |
+
model.load_state_dict(checkpoint_dict['generator'])
|
365 |
|
366 |
return model
|