Spaces:
Runtime error
Runtime error
log after x epochs, not after 1
Browse files
scripts/train_unconditional.py
CHANGED
@@ -244,7 +244,9 @@ def main(args):
|
|
244 |
|
245 |
# Generate sample images for visual inspection
|
246 |
if accelerator.is_main_process:
|
247 |
-
if
|
|
|
|
|
248 |
if vqvae is not None:
|
249 |
pipeline = LatentAudioDiffusionPipeline(
|
250 |
unet=accelerator.unwrap_model(
|
@@ -275,7 +277,9 @@ def main(args):
|
|
275 |
else:
|
276 |
pipeline.save_pretrained(output_dir)
|
277 |
|
278 |
-
if
|
|
|
|
|
279 |
generator = torch.manual_seed(42)
|
280 |
# run pipeline in inference (sample random noise and denoise)
|
281 |
images, (sample_rate, audios) = pipeline(
|
|
|
244 |
|
245 |
# Generate sample images for visual inspection
|
246 |
if accelerator.is_main_process:
|
247 |
+
if (
|
248 |
+
epoch + 1
|
249 |
+
) % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
|
250 |
if vqvae is not None:
|
251 |
pipeline = LatentAudioDiffusionPipeline(
|
252 |
unet=accelerator.unwrap_model(
|
|
|
277 |
else:
|
278 |
pipeline.save_pretrained(output_dir)
|
279 |
|
280 |
+
if (
|
281 |
+
epoch + 1
|
282 |
+
) % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
|
283 |
generator = torch.manual_seed(42)
|
284 |
# run pipeline in inference (sample random noise and denoise)
|
285 |
images, (sample_rate, audios) = pipeline(
|