Spaces:
Runtime error
Runtime error
update train_unconditional for latent diffusion
Browse files- README.md +2 -2
- scripts/train_unconditional.py +16 -19
- scripts/train_vae.py +0 -2
README.md
CHANGED
@@ -89,7 +89,7 @@ accelerate launch --config_file config/accelerate_local.yaml \
|
|
89 |
scripts/train_unconditional.py \
|
90 |
--dataset_name teticio/audio-diffusion-256 \
|
91 |
--resolution 256 \
|
92 |
-
--output_dir
|
93 |
--num_epochs 100 \
|
94 |
--train_batch_size 2 \
|
95 |
--eval_batch_size 2 \
|
@@ -98,7 +98,7 @@ accelerate launch --config_file config/accelerate_local.yaml \
|
|
98 |
--lr_warmup_steps 500 \
|
99 |
--mixed_precision no \
|
100 |
--push_to_hub True \
|
101 |
-
--hub_model_id
|
102 |
--hub_token $(cat $HOME/.huggingface/token)
|
103 |
```
|
104 |
#### Run training on SageMaker.
|
|
|
89 |
scripts/train_unconditional.py \
|
90 |
--dataset_name teticio/audio-diffusion-256 \
|
91 |
--resolution 256 \
|
92 |
+
--output_dir audio-diffusion-256 \
|
93 |
--num_epochs 100 \
|
94 |
--train_batch_size 2 \
|
95 |
--eval_batch_size 2 \
|
|
|
98 |
--lr_warmup_steps 500 \
|
99 |
--mixed_precision no \
|
100 |
--push_to_hub True \
|
101 |
+
--hub_model_id audio-diffusion-256 \
|
102 |
--hub_token $(cat $HOME/.huggingface/token)
|
103 |
```
|
104 |
#### Run training on SageMaker.
|
scripts/train_unconditional.py
CHANGED
@@ -48,8 +48,9 @@ def main(args):
|
|
48 |
model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
|
49 |
else:
|
50 |
model = UNet2DModel(
|
51 |
-
|
52 |
-
|
|
|
53 |
layers_per_block=2,
|
54 |
block_out_channels=(128, 128, 256, 256, 512, 512),
|
55 |
down_block_types=(
|
@@ -114,7 +115,7 @@ def main(args):
|
|
114 |
def transforms(examples):
|
115 |
if args.vae is not None:
|
116 |
images = [
|
117 |
-
augmentations(image)
|
118 |
for image in examples["image"]
|
119 |
]
|
120 |
else:
|
@@ -173,6 +174,13 @@ def main(args):
|
|
173 |
model.train()
|
174 |
for step, batch in enumerate(train_dataloader):
|
175 |
clean_images = batch["input"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
# Sample noise that we'll add to the images
|
177 |
noise = torch.randn(clean_images.shape).to(clean_images.device)
|
178 |
bsz = clean_images.shape[0]
|
@@ -184,11 +192,6 @@ def main(args):
|
|
184 |
device=clean_images.device,
|
185 |
).long()
|
186 |
|
187 |
-
if args.vae is not None:
|
188 |
-
with torch.no_grad():
|
189 |
-
clean_images = vqvae.encode(
|
190 |
-
clean_images).latent_dist.sample()
|
191 |
-
|
192 |
# Add noise to the clean images according to the noise magnitude at each timestep
|
193 |
# (this is the forward diffusion process)
|
194 |
noisy_images = noise_scheduler.add_noise(clean_images, noise,
|
@@ -196,8 +199,7 @@ def main(args):
|
|
196 |
|
197 |
with accelerator.accumulate(model):
|
198 |
# Predict the noise residual
|
199 |
-
|
200 |
-
noise_pred = vqvae.decode(images)["sample"]
|
201 |
loss = F.mse_loss(noise_pred, noise)
|
202 |
accelerator.backward(loss)
|
203 |
|
@@ -209,13 +211,6 @@ def main(args):
|
|
209 |
ema_model.step(model)
|
210 |
optimizer.zero_grad()
|
211 |
|
212 |
-
if args.vae is not None:
|
213 |
-
with torch.no_grad():
|
214 |
-
images = [
|
215 |
-
image.convert('L')
|
216 |
-
for image in vqvae.decode(images)["sample"]
|
217 |
-
]
|
218 |
-
|
219 |
if accelerator.sync_gradients:
|
220 |
progress_bar.update(1)
|
221 |
global_step += 1
|
@@ -239,14 +234,16 @@ def main(args):
|
|
239 |
if args.vae is not None:
|
240 |
pipeline = LDMPipeline(
|
241 |
unet=accelerator.unwrap_model(
|
242 |
-
ema_model.averaged_model if args.use_ema else model
|
|
|
243 |
vqvae=vqvae,
|
244 |
scheduler=noise_scheduler,
|
245 |
)
|
246 |
else:
|
247 |
pipeline = DDPMPipeline(
|
248 |
unet=accelerator.unwrap_model(
|
249 |
-
ema_model.averaged_model if args.use_ema else model
|
|
|
250 |
scheduler=noise_scheduler,
|
251 |
)
|
252 |
|
|
|
48 |
model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
|
49 |
else:
|
50 |
model = UNet2DModel(
|
51 |
+
sample_size=args.resolution if args.vae is None else 64,
|
52 |
+
in_channels=1 if args.vae is None else 3,
|
53 |
+
out_channels=1 if args.vae is None else 3,
|
54 |
layers_per_block=2,
|
55 |
block_out_channels=(128, 128, 256, 256, 512, 512),
|
56 |
down_block_types=(
|
|
|
115 |
def transforms(examples):
|
116 |
if args.vae is not None:
|
117 |
images = [
|
118 |
+
augmentations(image.convert("RGB"))
|
119 |
for image in examples["image"]
|
120 |
]
|
121 |
else:
|
|
|
174 |
model.train()
|
175 |
for step, batch in enumerate(train_dataloader):
|
176 |
clean_images = batch["input"]
|
177 |
+
|
178 |
+
if args.vae is not None:
|
179 |
+
vqvae.to(clean_images.device)
|
180 |
+
with torch.no_grad():
|
181 |
+
clean_images = vqvae.encode(
|
182 |
+
clean_images).latent_dist.sample()
|
183 |
+
|
184 |
# Sample noise that we'll add to the images
|
185 |
noise = torch.randn(clean_images.shape).to(clean_images.device)
|
186 |
bsz = clean_images.shape[0]
|
|
|
192 |
device=clean_images.device,
|
193 |
).long()
|
194 |
|
|
|
|
|
|
|
|
|
|
|
195 |
# Add noise to the clean images according to the noise magnitude at each timestep
|
196 |
# (this is the forward diffusion process)
|
197 |
noisy_images = noise_scheduler.add_noise(clean_images, noise,
|
|
|
199 |
|
200 |
with accelerator.accumulate(model):
|
201 |
# Predict the noise residual
|
202 |
+
noise_pred = model(noisy_images, timesteps)["sample"]
|
|
|
203 |
loss = F.mse_loss(noise_pred, noise)
|
204 |
accelerator.backward(loss)
|
205 |
|
|
|
211 |
ema_model.step(model)
|
212 |
optimizer.zero_grad()
|
213 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
if accelerator.sync_gradients:
|
215 |
progress_bar.update(1)
|
216 |
global_step += 1
|
|
|
234 |
if args.vae is not None:
|
235 |
pipeline = LDMPipeline(
|
236 |
unet=accelerator.unwrap_model(
|
237 |
+
ema_model.averaged_model if args.use_ema else model
|
238 |
+
),
|
239 |
vqvae=vqvae,
|
240 |
scheduler=noise_scheduler,
|
241 |
)
|
242 |
else:
|
243 |
pipeline = DDPMPipeline(
|
244 |
unet=accelerator.unwrap_model(
|
245 |
+
ema_model.averaged_model if args.use_ema else model
|
246 |
+
),
|
247 |
scheduler=noise_scheduler,
|
248 |
)
|
249 |
|
scripts/train_vae.py
CHANGED
@@ -4,9 +4,7 @@
|
|
4 |
|
5 |
# TODO
|
6 |
# grayscale
|
7 |
-
# add vae to train_uncond (no_grad)
|
8 |
# update README
|
9 |
-
# merge in changes to train_unconditional
|
10 |
|
11 |
import os
|
12 |
import argparse
|
|
|
4 |
|
5 |
# TODO
|
6 |
# grayscale
|
|
|
7 |
# update README
|
|
|
8 |
|
9 |
import os
|
10 |
import argparse
|