diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 0000000000000000000000000000000000000000..c90cef80cdee682bb9b4dfd890bc590e845173de --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,9 @@ +cff-version: 1.2.0 +message: "If you use this code, please cite it as below." +authors: +- family-names: "Wu" + given-names: "Hecong" +title: "Pixel Guide Diffusion For Anime Colorization" +version: 1.0.0 +date-released: 2021-10-26 +url: "https://github.com/HighCWu/pixel-guide-diffusion-for-anime-colorization" diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..046b2c98cff6af9c1e6103ac9740debbc8214f12 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Wu Hecong + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 8365cdd81fd7dcafe28b7da133f408f7828ac23e..b416200f6309f6d3d578f0697668de8fe18c1f99 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,176 @@ ---- -title: Anime Colorization -emoji: 😻 -colorFrom: indigo -colorTo: pink -sdk: gradio -sdk_version: 3.0.5 -app_file: app.py -pinned: false -license: mit ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference +# Pixel Guide Diffusion For Anime Colorization + +![avatar](docs/imgs/sample.png) + +Use denoising diffusion probabilistic model to do the anime colorization task. + +v1 test result is in branch [v1_result](https://github.com/HighCWu/pixel-guide-diffusion-for-anime-colorization/tree/v1_result). + +The dataset is not clean enough and the sketch as the guide is generated using sketch2keras, so the generalization is not good. + +In the future, I may try to use only anime portraits as the target images, and look for some more diverse sketch models. + +# Introduction and Usage + +Pixel Guide Denoising Diffusion Probabilistic Models ( One Channel Guide Version ) + +This repo is modified from [improved-diffusion](https://github.com/openai/improved-diffusion). + +Use [danbooru-sketch-pair-128x](https://www.kaggle.com/wuhecong/danbooru-sketch-pair-128x) as the dataset. Maybe you should move folders in the dataset first to make guide-target pair dataset. + +Modify `train_danbooru*.sh`, `test_danbooru*.sh` to meet your needs. + +The model is divided into a 32px part and a super-divided part, which can be cascaded during testing to get the final result. But there is no cascade during training. 
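+
+A minimal sketch of how the 32px model and the super-resolution model can be cascaded at test time (illustrative variable names only; `app.py` and the `test_danbooru*.sh` scripts are the actual entry points):
+
+```python
+# Base model: 32px colorization conditioned on the 1-channel sketch guide.
+model_kwargs = {"guide": sketch}  # sketch: [N, 1, 128, 128] tensor scaled to [-1, 1]
+sample_32 = diffusion.p_sample_loop(
+    model, (batch_size, 3, 32, 32), clip_denoised=True, model_kwargs=model_kwargs
+)
+
+# Super-resolution model: 32px -> 128px, conditioned on the guide and the low-res sample.
+model_kwargs["low_res"] = sample_32
+sample_128 = diffusion2.p_sample_loop(
+    model2, (batch_size, 3, 128, 128), clip_denoised=True, model_kwargs=model_kwargs
+)
+```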
+ +QQ Group: 1044867291 + +Discord: https://discord.gg/YwWcAS47qb + +# Original README + +# improved-diffusion + +This is the codebase for [Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672). + +# Usage + +This section of the README walks through how to train and sample from a model. + +## Installation + +Clone this repository and navigate to it in your terminal. Then run: + +``` +pip install -e . +``` + +This should install the ~~`improved_diffusion`~~ `pixel_guide_diffusion` python package that the scripts depend on. + +## Preparing Data + +The training code reads images from a directory of image files. In the [datasets](datasets) folder, we have provided instructions/scripts for preparing these directories for ImageNet, LSUN bedrooms, and CIFAR-10. + +For creating your own dataset, simply dump all of your images into a directory with ".jpg", ".jpeg", or ".png" extensions. If you wish to train a class-conditional model, name the files like "mylabel1_XXX.jpg", "mylabel2_YYY.jpg", etc., so that the data loader knows that "mylabel1" and "mylabel2" are the labels. Subdirectories will automatically be enumerated as well, so the images can be organized into a recursive structure (although the directory names will be ignored, and the underscore prefixes are used as names). + +The images will automatically be scaled and center-cropped by the data-loading pipeline. Simply pass `--data_dir path/to/images` to the training script, and it will take care of the rest. + +## Training + +To train your model, you should first decide some hyperparameters. We will split up our hyperparameters into three groups: model architecture, diffusion process, and training flags. Here are some reasonable defaults for a baseline: + +``` +MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3" +DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear" +TRAIN_FLAGS="--lr 1e-4 --batch_size 128" +``` + +Here are some changes we experiment with, and how to set them in the flags: + + * **Learned sigmas:** add `--learn_sigma True` to `MODEL_FLAGS` + * **Cosine schedule:** change `--noise_schedule linear` to `--noise_schedule cosine` + * **Reweighted VLB:** add `--use_kl True` to `DIFFUSION_FLAGS` and add `--schedule_sampler loss-second-moment` to `TRAIN_FLAGS`. + * **Class-conditional:** add `--class_cond True` to `MODEL_FLAGS`. + +Once you have setup your hyper-parameters, you can run an experiment like so: + +``` +python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS +``` + +You may also want to train in a distributed manner. In this case, run the same command with `mpiexec`: + +``` +mpiexec -n $NUM_GPUS python scripts/image_train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS +``` + +When training in a distributed manner, you must manually divide the `--batch_size` argument by the number of ranks. In lieu of distributed training, you may use `--microbatch 16` (or `--microbatch 1` in extreme memory-limited cases) to reduce memory usage. + +The logs and saved models will be written to a logging directory determined by the `OPENAI_LOGDIR` environment variable. If it is not set, then a temporary directory will be created in `/tmp`. + +## Sampling + +The above training script saves checkpoints to `.pt` files in the logging directory. These checkpoints will have names like `ema_0.9999_200000.pt` and `model200000.pt`. You will likely want to sample from the EMA models, since those produce much better samples. 
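+
+For example (an illustrative sketch; the log directory and training step are placeholders that depend on your run), you can pin the logging directory before training and pick up the EMA checkpoint from it afterwards:
+
+```bash
+export OPENAI_LOGDIR=/path/to/logdir    # hypothetical path
+ls $OPENAI_LOGDIR/ema_0.9999_*.pt       # e.g. ema_0.9999_200000.pt
+```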
+ +Once you have a path to your model, you can generate a large batch of samples like so: + +``` +python scripts/image_sample.py --model_path /path/to/model.pt $MODEL_FLAGS $DIFFUSION_FLAGS +``` + +Again, this will save results to a logging directory. Samples are saved as a large `npz` file, where `arr_0` in the file is a large batch of samples. + +Just like for training, you can run `image_sample.py` through MPI to use multiple GPUs and machines. + +You can change the number of sampling steps using the `--timestep_respacing` argument. For example, `--timestep_respacing 250` uses 250 steps to sample. Passing `--timestep_respacing ddim250` is similar, but uses the uniform stride from the [DDIM paper](https://arxiv.org/abs/2010.02502) rather than our stride. + +To sample using [DDIM](https://arxiv.org/abs/2010.02502), pass `--use_ddim True`. + +## Models and Hyperparameters + +This section includes model checkpoints and run flags for the main models in the paper. + +Note that the batch sizes are specified for single-GPU training, even though most of these runs will not naturally fit on a single GPU. To address this, either set `--microbatch` to a small value (e.g. 4) to train on one GPU, or run with MPI and divide `--batch_size` by the number of GPUs. + +Unconditional ImageNet-64 with our `L_hybrid` objective and cosine noise schedule [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/imagenet64_uncond_100M_1500K.pt)]: + +```bash +MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True" +DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine" +TRAIN_FLAGS="--lr 1e-4 --batch_size 128" +``` + +Unconditional CIFAR-10 with our `L_hybrid` objective and cosine noise schedule [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/cifar10_uncond_50M_500K.pt)]: + +```bash +MODEL_FLAGS="--image_size 32 --num_channels 128 --num_res_blocks 3 --learn_sigma True --dropout 0.3" +DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine" +TRAIN_FLAGS="--lr 1e-4 --batch_size 128" +``` + +Class-conditional ImageNet-64 model (270M parameters, trained for 250K iterations) [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/imagenet64_cond_270M_250K.pt)]: + +```bash +MODEL_FLAGS="--image_size 64 --num_channels 192 --num_res_blocks 3 --learn_sigma True --class_cond True" +DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine --rescale_learned_sigmas False --rescale_timesteps False" +TRAIN_FLAGS="--lr 3e-4 --batch_size 2048" +``` + +Upsampling 256x256 model (280M parameters, trained for 500K iterations) [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/upsample_cond_500K.pt)]: + +```bash +MODEL_FLAGS="--num_channels 192 --num_res_blocks 2 --learn_sigma True --class_cond True" +DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False" +TRAIN_FLAGS="--lr 3e-4 --batch_size 256" +``` + +LSUN bedroom model (lr=1e-4) [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/lsun_uncond_100M_1200K_bs128.pt)]: + +```bash +MODEL_FLAGS="--image_size 256 --num_channels 128 --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16" +DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False" +TRAIN_FLAGS="--lr 1e-4 --batch_size 128" +``` + +LSUN bedroom model (lr=2e-5) 
[[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/lsun_uncond_100M_2400K_bs64.pt)]: + +```bash +MODEL_FLAGS="--image_size 256 --num_channels 128 --num_res_blocks 2 --num_heads 1 --learn_sigma True --use_scale_shift_norm False --attention_resolutions 16" +DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear --rescale_learned_sigmas False --rescale_timesteps False --use_scale_shift_norm False" +TRAIN_FLAGS="--lr 2e-5 --batch_size 128" +``` + +Unconditional ImageNet-64 with the `L_vlb` objective and cosine noise schedule [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/imagenet64_uncond_vlb_100M_1500K.pt)]: + +```bash +MODEL_FLAGS="--image_size 64 --num_channels 128 --num_res_blocks 3 --learn_sigma True" +DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine" +TRAIN_FLAGS="--lr 1e-4 --batch_size 128 --schedule_sampler loss-second-moment" +``` + +Unconditional CIFAR-10 with the `L_vlb` objective and cosine noise schedule [[checkpoint](https://openaipublic.blob.core.windows.net/diffusion/march-2021/cifar10_uncond_vlb_50M_500K.pt)]: + +```bash +MODEL_FLAGS="--image_size 32 --num_channels 128 --num_res_blocks 3 --learn_sigma True --dropout 0.3" +DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine" +TRAIN_FLAGS="--lr 1e-4 --batch_size 128 --schedule_sampler loss-second-moment" +``` diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..560fe908e4b0ea261d50fb36c72a53fa24f01953 --- /dev/null +++ b/app.py @@ -0,0 +1,246 @@ +""" +A Gradio Blocks Demo App. +Generate a large batch of samples from a super resolution model, given a batch +of samples from a regular model from image_sample.py. +""" + +import gradio as gr +import argparse +import os +import glob + +import blobfile as bf +import numpy as np +import torch as th +import torch.distributed as dist + +from PIL import Image, ImageDraw +from torchvision import utils +from pixel_guide_diffusion import dist_util, logger +from pixel_guide_diffusion.image_datasets import load_data +from pixel_guide_diffusion.script_util import ( + pg_model_and_diffusion_defaults, + pg_create_model_and_diffusion, + pgsr_model_and_diffusion_defaults, + pgsr_create_model_and_diffusion, + args_to_dict, + add_dict_to_argparser, +) + +MODEL_FLAGS="--image_size=32 --small_size=32 --large_size=128 --guide_size=128 --num_channels=128 --num_channels2=64 --num_res_blocks=3 --learn_sigma=True --dropout=0.0 --use_attention2=False" +DIFFUSION_FLAGS="--diffusion_steps=4000 --noise_schedule=cosine" +TEST_FLAGS="--batch_size=1 --seed=233 --num_samples=4" +OTHER_FLAGS = '''\ +--timestep_respacing=16 \ +--use_ddim=False \ +--model_path=./danbooru2017_guided_log/ema_0.9999_360000.pt \ +--model_path2=./danbooru2017_guided_sr_log/ema_0.9999_360000.pt''' +OTHER_FLAGS = OTHER_FLAGS.replace('\r\n', ' ').replace('\n', ' ') +flags = OTHER_FLAGS.split(' ') + MODEL_FLAGS.split(' ') + DIFFUSION_FLAGS.split(' ') + TEST_FLAGS.split(' ') + + +def norm_size(img, size=128, add_edges=True): + img = img.convert('L') + w, h = img.size + if w != h: + scale = 1024 / max(img.size) + img = img.resize([int(round(s*scale)) for s in img.size]) + w, h = img.size + max_size = max(w, h) + x0 = (max_size - w) // 2 + y0 = (max_size - h) // 2 + x1 = x0 + w + y1 = y0 + h + canvas = Image.new('L', (max_size,max_size), 255) + canvas.paste(img, (x0,y0,x1,y1)) + + if add_edges: + draw = ImageDraw.Draw(canvas) + draw.line((x0-5,0,x0-1,max_size), fill=0) + draw.line((0,y0-5,max_size,y0-1), 
fill=0) + draw.line((x1+1,0,x1+5,max_size), fill=0) + draw.line((0,y1+1,max_size,y1+5), fill=0) + + img = canvas + img = img.resize((size,size), resample=Image.LANCZOS) + + return img + + +def create_argparser(): + defaults = dict( + data_dir="", + guide_dir="", + clip_denoised=True, + num_samples=100, + batch_size=4, + use_ddim=False, + base_samples="", + model_path="", + seed=-1, + ) + defaults.update(pg_model_and_diffusion_defaults()) + defaults.update(pgsr_model_and_diffusion_defaults()) + defaults.update(dict( + num_channels2=128, + use_attention2=True, + model_path2="", + )) + parser = argparse.ArgumentParser() + add_dict_to_argparser(parser, defaults) + return parser + + +@th.inference_mode() +def main(): + args = create_argparser().parse_args(flags) + + dist_util.setup_dist() + logger.configure() + + logger.log("creating model...") + model, diffusion = pg_create_model_and_diffusion( + **args_to_dict(args, pg_model_and_diffusion_defaults().keys()) + ) + model.load_state_dict( + dist_util.load_state_dict(args.model_path, map_location="cpu") + ) + model.to(dist_util.dev()) + model.eval() + + logger.log("creating model2...") + args.num_channels = args.num_channels2 + args.use_attention = args.use_attention2 + model2, diffusion2 = pgsr_create_model_and_diffusion( + **args_to_dict(args, pgsr_model_and_diffusion_defaults().keys()) + ) + model2.load_state_dict( + dist_util.load_state_dict(args.model_path2, map_location="cpu") + ) + model2.to(dist_util.dev()) + model2.eval() + + def inference(img, seed, add_edges): + th.manual_seed(int(seed)) + sketch = sketch_out = norm_size(img, size=128, add_edges=add_edges) + sketch = np.asarray(sketch).astype(np.float32) / 127.5 - 1 + sketch = th.from_numpy(sketch).float()[None,None].to(dist_util.dev()) + model_kwargs = { "guide": sketch } + sample_fn = ( + diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop + ) + sample = sample_fn( + model, + (args.batch_size, 3, args.image_size, args.image_size), + clip_denoised=args.clip_denoised, + model_kwargs=model_kwargs, + ) + + model_kwargs["low_res"] = sample + sample_fn2 = ( + diffusion2.p_sample_loop if not args.use_ddim else diffusion2.ddim_sample_loop + ) + sample2 = sample_fn2( + model2, + (args.batch_size, 3, args.large_size, args.large_size), + clip_denoised=args.clip_denoised, + model_kwargs=model_kwargs, + ) + out = (sample2[0].clamp(-1,1).cpu().numpy() + 1) / 2 * 255 + out = np.uint8(out) + out = out.transpose([1,2,0]) + out = Image.fromarray(out) + + return sketch_out, out + + with gr.Blocks() as demo: + gr.Markdown('''

+# Anime-Colorization
+
+Colorize your anime sketches with this app.
+
+This is a Gradio Blocks app of
+[HighCWu/pixel-guide-diffusion-for-anime-colorization](https://github.com/HighCWu/pixel-guide-diffusion-for-anime-colorization).
+(PS: Training Datasets are made from +HighCWu/danbooru-sketch-pair-128x + which processed real anime images to sketches by +SketchKeras. +So the model is not very sensitive to some different styles of sketches, +and the colorized results of such sketches are not very good.) +''') + with gr.Row(): + with gr.Box(): + with gr.Column(): + with gr.Row(): + seed_in = gr.Number( + value=233, + label='Seed' + ) + with gr.Row(): + edges_in = gr.Checkbox( + label="Add Edges" + ) + with gr.Row(): + sketch_in = gr.Image( + type="pil", + label="Sketch" + ) + with gr.Row(): + generate_button = gr.Button('Generate') + with gr.Row(): + gr.Markdown('Click to add example as input.👇') + with gr.Row(): + example_sketch_paths = [[p] for p in sorted(glob.glob('docs/imgs/anime_sketch/*.png'))] + example_sketch = gr.Dataset( + components=[sketch_in], + samples=example_sketch_paths + ) + with gr.Row(): + gr.Markdown('These are expect real outputs.👇') + with gr.Row(): + example_real_paths = [[p] for p in sorted(glob.glob('docs/imgs/anime/*.png'))] + example_real = gr.Dataset( + components=[sketch_in], + samples=example_real_paths + ) + + with gr.Box(): + with gr.Column(): + with gr.Row(): + with gr.Column(): + sketch_out = gr.Image( + type="pil", + label="Input" + ) + with gr.Column(): + colorized_out = gr.Image( + type="pil", + label="Colorization Result" + ) + with gr.Row(): + gr.Markdown( + 'Here are some samples 👇 [top: sketch, center: generated, bottom: real]' + ) + with gr.Row(): + gr.Image( + value="docs/imgs/sample.png", + type="filepath", + interactive=False, + label="Samples" + ) + gr.Markdown( + '
visitor badge
' + ) + + generate_button.click( + inference, inputs=[sketch_in, seed_in, edges_in], outputs=[sketch_out, colorized_out] + ) + example_sketch.click( + fn=lambda examples: gr.Image.update(value=examples[0]), + inputs=example_sketch, + outputs=example_sketch.components + ) + + demo.launch() + +if __name__ == '__main__': + main() diff --git a/danbooru2017_guided_log/ema_0.9999_360000.pt b/danbooru2017_guided_log/ema_0.9999_360000.pt new file mode 100644 index 0000000000000000000000000000000000000000..b45f431cb5d709d7b51621a121f0f4fe6f591ac3 --- /dev/null +++ b/danbooru2017_guided_log/ema_0.9999_360000.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b614305acf2d30b7c63bcbc56f646a3ee06579a8430f9c55c8de5014c977f397 +size 210354744 diff --git a/danbooru2017_guided_sr_log/ema_0.9999_360000.pt b/danbooru2017_guided_sr_log/ema_0.9999_360000.pt new file mode 100644 index 0000000000000000000000000000000000000000..e0d6788b51d472a81a39e220637a7b343a66059f --- /dev/null +++ b/danbooru2017_guided_sr_log/ema_0.9999_360000.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc1c29c293ad2625cf616f0d1a33e2d708c061ac7f439d3df53e9b22cafe36d7 +size 48757368 diff --git a/data/.gitkeep b/data/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/datasets/README.md b/datasets/README.md new file mode 100644 index 0000000000000000000000000000000000000000..148cfea9a04f0361543b471772f94a9ce3d4c484 --- /dev/null +++ b/datasets/README.md @@ -0,0 +1,37 @@ +# Downloading datasets + +This directory includes instructions and scripts for downloading ImageNet, LSUN bedrooms, and CIFAR-10 for use in this codebase. + +## ImageNet-64 + +To download unconditional ImageNet-64, go to [this page on image-net.org](http://www.image-net.org/small/download.php) and click on "Train (64x64)". Simply download the file and unzip it, and use the resulting directory as the data directory (the `--data_dir` argument for the training script). + +## Class-conditional ImageNet + +For our class-conditional models, we use the official ILSVRC2012 dataset with manual center cropping and downsampling. To obtain this dataset, navigate to [this page on image-net.org](http://www.image-net.org/challenges/LSVRC/2012/downloads) and sign in (or create an account if you do not already have one). Then click on the link reading "Training images (Task 1 & 2)". This is a 138GB tar file containing 1000 sub-tar files, one per class. + +Once the file is downloaded, extract it and look inside. You should see 1000 `.tar` files. You need to extract each of these, which may be impractical to do by hand on your operating system. To automate the process on a Unix-based system, you can `cd` into the directory and run this short shell script: + +``` +for file in *.tar; do tar xf "$file"; rm "$file"; done +``` + +This will extract and remove each tar file in turn. + +Once all of the images have been extracted, the resulting directory should be usable as a data directory (the `--data_dir` argument for the training script). The filenames should all start with WNID (class ids) followed by underscores, like `n01440764_2708.JPEG`. Conveniently (but not by accident) this is how the automated data-loader expects to discover class labels. + +## CIFAR-10 + +For CIFAR-10, we created a script [cifar10.py](cifar10.py) that creates `cifar_train` and `cifar_test` directories. 
These directories contain files named like `truck_49997.png`, so that the class name is discernable to the data loader. + +The `cifar_train` and `cifar_test` directories can be passed directly to the training scripts via the `--data_dir` argument. + +## LSUN bedroom + +To download and pre-process LSUN bedroom, clone [fyu/lsun](https://github.com/fyu/lsun) on GitHub and run their download script `python3 download.py bedroom`. The result will be an "lmdb" database named like `bedroom_train_lmdb`. You can pass this to our [lsun_bedroom.py](lsun_bedroom.py) script like so: + +``` +python lsun_bedroom.py bedroom_train_lmdb lsun_train_output_dir +``` + +This creates a directory called `lsun_train_output_dir`. This directory can be passed to the training scripts via the `--data_dir` argument. diff --git a/datasets/cifar10.py b/datasets/cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..38c72f04d2eb18442fb1687b1f4daf51693419a3 --- /dev/null +++ b/datasets/cifar10.py @@ -0,0 +1,43 @@ +import os +import tempfile + +import torchvision +from tqdm.auto import tqdm + +CLASSES = ( + "plane", + "car", + "bird", + "cat", + "deer", + "dog", + "frog", + "horse", + "ship", + "truck", +) + + +def main(): + for split in ["train", "test"]: + out_dir = f"cifar_{split}" + if os.path.exists(out_dir): + print(f"skipping split {split} since {out_dir} already exists.") + continue + + print("downloading...") + with tempfile.TemporaryDirectory() as tmp_dir: + dataset = torchvision.datasets.CIFAR10( + root=tmp_dir, train=split == "train", download=True + ) + + print("dumping images...") + os.mkdir(out_dir) + for i in tqdm(range(len(dataset))): + image, label = dataset[i] + filename = os.path.join(out_dir, f"{CLASSES[label]}_{i:05d}.png") + image.save(filename) + + +if __name__ == "__main__": + main() diff --git a/datasets/lsun_bedroom.py b/datasets/lsun_bedroom.py new file mode 100644 index 0000000000000000000000000000000000000000..6a5be22eef8c7434331a76ef5ed7332a98a446ef --- /dev/null +++ b/datasets/lsun_bedroom.py @@ -0,0 +1,54 @@ +""" +Convert an LSUN lmdb database into a directory of images. 
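+
+Typical usage (example paths, matching the command in datasets/README.md):
+
+    python lsun_bedroom.py bedroom_train_lmdb lsun_train_output_dir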
+""" + +import argparse +import io +import os + +from PIL import Image +import lmdb +import numpy as np + + +def read_images(lmdb_path, image_size): + env = lmdb.open(lmdb_path, map_size=1099511627776, max_readers=100, readonly=True) + with env.begin(write=False) as transaction: + cursor = transaction.cursor() + for _, webp_data in cursor: + img = Image.open(io.BytesIO(webp_data)) + width, height = img.size + scale = image_size / min(width, height) + img = img.resize( + (int(round(scale * width)), int(round(scale * height))), + resample=Image.BOX, + ) + arr = np.array(img) + h, w, _ = arr.shape + h_off = (h - image_size) // 2 + w_off = (w - image_size) // 2 + arr = arr[h_off : h_off + image_size, w_off : w_off + image_size] + yield arr + + +def dump_images(out_dir, images, prefix): + if not os.path.exists(out_dir): + os.mkdir(out_dir) + for i, img in enumerate(images): + Image.fromarray(img).save(os.path.join(out_dir, f"{prefix}_{i:07d}.png")) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--image-size", help="new image size", type=int, default=256) + parser.add_argument("--prefix", help="class name", type=str, default="bedroom") + parser.add_argument("lmdb_path", help="path to an LSUN lmdb database") + parser.add_argument("out_dir", help="path to output directory") + args = parser.parse_args() + + images = read_images(args.lmdb_path, args.image_size) + dump_images(args.out_dir, images, args.prefix) + + +if __name__ == "__main__": + main() diff --git a/docs/imgs/anime/1000000.png b/docs/imgs/anime/1000000.png new file mode 100644 index 0000000000000000000000000000000000000000..13d1bac479cd7a970b93e6a0baa8e03fd6821be0 Binary files /dev/null and b/docs/imgs/anime/1000000.png differ diff --git a/docs/imgs/anime/1002000.png b/docs/imgs/anime/1002000.png new file mode 100644 index 0000000000000000000000000000000000000000..3eafe1504c1b38dd3db8105ece055b3023527a75 Binary files /dev/null and b/docs/imgs/anime/1002000.png differ diff --git a/docs/imgs/anime/1003000.png b/docs/imgs/anime/1003000.png new file mode 100644 index 0000000000000000000000000000000000000000..afa6d2f12b742a1afc262c2ee2111f23a7c45d91 Binary files /dev/null and b/docs/imgs/anime/1003000.png differ diff --git a/docs/imgs/anime/1004000.png b/docs/imgs/anime/1004000.png new file mode 100644 index 0000000000000000000000000000000000000000..24eebed2ec5a257004140a7e361273b4352b329f Binary files /dev/null and b/docs/imgs/anime/1004000.png differ diff --git a/docs/imgs/anime/1006000.png b/docs/imgs/anime/1006000.png new file mode 100644 index 0000000000000000000000000000000000000000..fd61f1f211ba7dc9c6a2bfe30284db8346d9e75a Binary files /dev/null and b/docs/imgs/anime/1006000.png differ diff --git a/docs/imgs/anime/1012000.png b/docs/imgs/anime/1012000.png new file mode 100644 index 0000000000000000000000000000000000000000..31219c49a9c0e244c12e7e488a89251e1e328709 Binary files /dev/null and b/docs/imgs/anime/1012000.png differ diff --git a/docs/imgs/anime_sketch/1000000.png b/docs/imgs/anime_sketch/1000000.png new file mode 100644 index 0000000000000000000000000000000000000000..2a15ae479eff46b951f735517de1dd8be6916d33 Binary files /dev/null and b/docs/imgs/anime_sketch/1000000.png differ diff --git a/docs/imgs/anime_sketch/1002000.png b/docs/imgs/anime_sketch/1002000.png new file mode 100644 index 0000000000000000000000000000000000000000..6545cfd1f7ce0539cdd94bddcdeda4b924fd07ef Binary files /dev/null and b/docs/imgs/anime_sketch/1002000.png differ diff --git a/docs/imgs/anime_sketch/1003000.png 
b/docs/imgs/anime_sketch/1003000.png new file mode 100644 index 0000000000000000000000000000000000000000..d5f4a0ed7f2698facbd4da7c25bb2dcc4d490b89 Binary files /dev/null and b/docs/imgs/anime_sketch/1003000.png differ diff --git a/docs/imgs/anime_sketch/1004000.png b/docs/imgs/anime_sketch/1004000.png new file mode 100644 index 0000000000000000000000000000000000000000..354b84b4374226f7b3e7bd9184ee7fe401566ca6 Binary files /dev/null and b/docs/imgs/anime_sketch/1004000.png differ diff --git a/docs/imgs/anime_sketch/1006000.png b/docs/imgs/anime_sketch/1006000.png new file mode 100644 index 0000000000000000000000000000000000000000..27cea614fd5143aea10aca4f9da8f749885d8850 Binary files /dev/null and b/docs/imgs/anime_sketch/1006000.png differ diff --git a/docs/imgs/anime_sketch/1012000.png b/docs/imgs/anime_sketch/1012000.png new file mode 100644 index 0000000000000000000000000000000000000000..67b8cc380c7247045bbe0cb70b36cae0a95a2c0a Binary files /dev/null and b/docs/imgs/anime_sketch/1012000.png differ diff --git a/docs/imgs/sample.png b/docs/imgs/sample.png new file mode 100644 index 0000000000000000000000000000000000000000..f7fe032a08c1e960218e0a616acbeabf401d2e10 Binary files /dev/null and b/docs/imgs/sample.png differ diff --git a/openai.LICENSE b/openai.LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..9e84fcbc4d81a1f433c90caf9f1cef373c12edae --- /dev/null +++ b/openai.LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 OpenAI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..cdf993adc3fe9a401a84dcdd5a3b7bfa1012e85f --- /dev/null +++ b/packages.txt @@ -0,0 +1 @@ +libopenmpi-dev \ No newline at end of file diff --git a/pixel_guide_diffusion/__init__.py b/pixel_guide_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9665a0d63f695eab303318d824dad14041c7cde9 --- /dev/null +++ b/pixel_guide_diffusion/__init__.py @@ -0,0 +1,3 @@ +""" +Codebase for "Improved Denoising Diffusion Probabilistic Models". +""" diff --git a/pixel_guide_diffusion/dist_util.py b/pixel_guide_diffusion/dist_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f665604d6baaf5df6008f131c86cf0779c8b208a --- /dev/null +++ b/pixel_guide_diffusion/dist_util.py @@ -0,0 +1,82 @@ +""" +Helpers for distributed training. 
+""" + +import io +import os +import socket + +import blobfile as bf +from mpi4py import MPI +import torch as th +import torch.distributed as dist + +# Change this to reflect your cluster layout. +# The GPU for a given rank is (rank % GPUS_PER_NODE). +GPUS_PER_NODE = 8 + +SETUP_RETRY_COUNT = 3 + + +def setup_dist(): + """ + Setup a distributed process group. + """ + if dist.is_initialized(): + return + + comm = MPI.COMM_WORLD + backend = "gloo" if not th.cuda.is_available() else "nccl" + + if backend == "gloo": + hostname = "localhost" + else: + hostname = socket.gethostbyname(socket.getfqdn()) + os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) + os.environ["RANK"] = str(comm.rank) + os.environ["WORLD_SIZE"] = str(comm.size) + + port = comm.bcast(_find_free_port(), root=0) + os.environ["MASTER_PORT"] = str(port) + dist.init_process_group(backend=backend, init_method="env://") + + +def dev(): + """ + Get the device to use for torch.distributed. + """ + if th.cuda.is_available(): + return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}") + return th.device("cpu") + + +def load_state_dict(path, **kwargs): + """ + Load a PyTorch file without redundant fetches across MPI ranks. + """ + if MPI.COMM_WORLD.Get_rank() == 0: + with bf.BlobFile(path, "rb") as f: + data = f.read() + else: + data = None + data = MPI.COMM_WORLD.bcast(data) + return th.load(io.BytesIO(data), **kwargs) + + +def sync_params(params): + """ + Synchronize a sequence of Tensors across ranks from rank 0. + """ + for p in params: + with th.no_grad(): + dist.broadcast(p, 0) + + +def _find_free_port(): + try: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + finally: + s.close() diff --git a/pixel_guide_diffusion/fp16_util.py b/pixel_guide_diffusion/fp16_util.py new file mode 100644 index 0000000000000000000000000000000000000000..23e0418153143200a718f56077b3360f30f4c663 --- /dev/null +++ b/pixel_guide_diffusion/fp16_util.py @@ -0,0 +1,76 @@ +""" +Helpers to train with 16-bit precision. +""" + +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + l.bias.data = l.bias.data.float() + + +def make_master_params(model_params): + """ + Copy model parameters into a (differently-shaped) list of full-precision + parameters. + """ + master_params = _flatten_dense_tensors( + [param.detach().float() for param in model_params] + ) + master_params = nn.Parameter(master_params) + master_params.requires_grad = True + return [master_params] + + +def model_grads_to_master_grads(model_params, master_params): + """ + Copy the gradients from the model parameters into the master parameters + from make_master_params(). + """ + master_params[0].grad = _flatten_dense_tensors( + [param.grad.data.detach().float() for param in model_params] + ) + + +def master_params_to_model_params(model_params, master_params): + """ + Copy the master parameter data back into the model parameters. 
+ """ + # Without copying to a list, if a generator is passed, this will + # silently not copy any parameters. + model_params = list(model_params) + + for param, master_param in zip( + model_params, unflatten_master_params(model_params, master_params) + ): + param.detach().copy_(master_param) + + +def unflatten_master_params(model_params, master_params): + """ + Unflatten the master parameters to look like model_params. + """ + return _unflatten_dense_tensors(master_params[0].detach(), model_params) + + +def zero_grad(model_params): + for param in model_params: + # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group + if param.grad is not None: + param.grad.detach_() + param.grad.zero_() diff --git a/pixel_guide_diffusion/gaussian_diffusion.py b/pixel_guide_diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..403d474f3bc3486dff7618d262f6437b2ab43e5c --- /dev/null +++ b/pixel_guide_diffusion/gaussian_diffusion.py @@ -0,0 +1,841 @@ +""" +This code started out as a PyTorch port of Ho et al's diffusion models: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py + +Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. +""" + +import enum +import math + +import numpy as np +import torch as th + +from .nn import mean_flat +from .losses import normal_kl, discretized_gaussian_log_likelihood + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. 
+ """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. 
+ self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + ) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None + ): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. 
+ """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape + ) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + else: + raise NotImplementedError(self.model_mean_type) + + assert ( + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + ) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape + ) + * x_t + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def p_sample( + self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None + ): + """ + Sample x_{t-1} from the model at the given timestep. 
+ + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). 
+ """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_next) + + th.sqrt(1 - alpha_bar_next) * eps + ) + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. 
+ from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. 
+ terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + + This term can't be optimized, as it only depends on the encoder. + + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
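+
+    For example, gathering alphas_cumprod for a batch of 4 timesteps with
+    broadcast_shape (4, 3, 64, 64) first gives a (4, 1, 1, 1) tensor, which is
+    then expanded to (4, 3, 64, 64) for element-wise use against an image batch.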
+ """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) diff --git a/pixel_guide_diffusion/image_datasets.py b/pixel_guide_diffusion/image_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..2eec69426004e2f325960df7d0ccef79be0453c3 --- /dev/null +++ b/pixel_guide_diffusion/image_datasets.py @@ -0,0 +1,173 @@ +from PIL import Image +import blobfile as bf +from mpi4py import MPI +import numpy as np +from torch.utils.data import DataLoader, Dataset + +import PIL.ImageFile +PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def load_data( + *, data_dir, batch_size, image_size, class_cond=False, guide_size=0, guide_dir=None, crop_size=0, deterministic=False +): + """ + For a dataset, create a generator over (images, kwargs) pairs. + + Each images is an NCHW float tensor, and the kwargs dict contains zero or + more keys, each of which map to a batched Tensor of their own. + The kwargs dict can be used for class labels, in which case the key is "y" + and the values are integer tensors of class labels. + + :param data_dir: a dataset directory. + :param batch_size: the batch size of each returned pair. + :param image_size: the size to which images are resized. + :param class_cond: if True, include a "y" key in returned dicts for class + label. If classes are not available and this is true, an + exception will be raised. + :param guide_size: the size to which images are resized for guide tensors. + :param guide_dir: a dataset directory for guide tensors. + :param crop_size: the size to which images are resized and cropped. + :param deterministic: if True, yield results in a deterministic order. + """ + if not data_dir: + raise ValueError("unspecified data directory") + all_files = _list_image_files_recursively(data_dir) + guide_files = None + if guide_dir: + guide_files = _list_image_files_recursively(guide_dir) + guide_files2 = _list_image_files_recursively('data/danbooru2017/anime_sketch_noise') + classes = None + if class_cond: + # Assume classes are the first part of the filename, + # before an underscore. + class_names = [bf.basename(path).split("_")[0] for path in all_files] + sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} + classes = [sorted_classes[x] for x in class_names] + dataset = ImageDataset( + image_size, + all_files, + guide_resolution=guide_size, + guide_paths=guide_files, + guide_paths2=guide_files2, + crop_resolution=crop_size, + classes=classes, + shard=MPI.COMM_WORLD.Get_rank(), + num_shards=MPI.COMM_WORLD.Get_size(), + ) + if deterministic: + loader = DataLoader( + dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True + ) + else: + loader = DataLoader( + dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True + ) + while True: + yield from loader + + +def _list_image_files_recursively(data_dir): + results = [] + for entry in sorted(bf.listdir(data_dir)): + full_path = bf.join(data_dir, entry) + ext = entry.split(".")[-1] + if "." 
in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: + results.append(full_path) + elif bf.isdir(full_path): + results.extend(_list_image_files_recursively(full_path)) + return sorted(results) + + +class ImageDataset(Dataset): + def __init__(self, resolution, image_paths, guide_resolution=0, guide_paths=None, guide_paths2=None, crop_resolution=0, classes=None, shard=0, num_shards=1): + super().__init__() + self.resolution = resolution + self.guide_resolution = guide_resolution + self.local_images = image_paths[shard:][::num_shards] + self.local_guides = guide_paths[shard:][::num_shards] if guide_paths else None + self.local_guides2 = guide_paths2[shard:][::num_shards] if guide_paths else None + self.crop_resolution = crop_resolution if crop_resolution > 0 else resolution + self.local_classes = None if classes is None else classes[shard:][::num_shards] + + def __len__(self): + return len(self.local_images) * 1000000 + + def __getitem__(self, idx): + idx = idx % len(self.local_images) + path = self.local_images[idx] + with bf.BlobFile(path, "rb") as f: + pil_image = Image.open(f) + pil_image.load() + + # We are not on a new enough PIL to support the `reducing_gap` + # argument, which uses BOX downsampling at powers of two first. + # Thus, we do it by hand to improve downsample quality. + while min(*pil_image.size) >= 2 * self.resolution: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = self.resolution / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image.convert("RGB")) + crop_y = (arr.shape[0] - self.crop_resolution) // 2 + crop_x = (arr.shape[1] - self.crop_resolution) // 2 + arr = arr[crop_y : crop_y + self.crop_resolution, crop_x : crop_x + self.crop_resolution] + arr = arr.astype(np.float32) / 127.5 - 1 + + out_dict = {} + + if self.local_guides: + path = self.local_guides[idx] if np.random.rand() < 0.5 else self.local_guides2[idx] + with bf.BlobFile(path, "rb") as f: + pil_image = Image.open(f) + pil_image.load() + + # We are not on a new enough PIL to support the `reducing_gap` + # argument, which uses BOX downsampling at powers of two first. + # Thus, we do it by hand to improve downsample quality. 
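+            # (Same BOX halving followed by bicubic resize as for the target image
+            # above, here applied to the sketch guide at guide_resolution.)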
+ while min(*pil_image.size) >= 2 * self.guide_resolution: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = self.guide_resolution / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + crop_resolution = self.guide_resolution // self.resolution * self.crop_resolution + + guide_arr = np.array(pil_image.convert("L"))[...,None] # np.array(pil_image.convert("RGB")) + + # extra noise + if np.random.rand() < 0.5: + w, h = guide_arr.shape[:2][::-1] + a = np.random.randint(2,12) + mean = np.asarray( + Image.fromarray( + np.random.randint(0,255,[a,a],dtype='uint8') + ).resize([w,h], Image.NEAREST) + ).astype('float32') / 255.0 * 2 - 1 + std = np.asarray( + Image.fromarray( + np.random.randint(0,255,[a,a],dtype='uint8') + ).resize([w, h], Image.NEAREST) + ).astype('float32') / 255.0 * 7.5 + 0.125 + guide_arr = (guide_arr - mean[...,None]) * std[...,None] + + crop_y = (guide_arr.shape[0] - crop_resolution) // 2 + crop_x = (guide_arr.shape[1] - crop_resolution) // 2 + guide_arr = guide_arr[crop_y : crop_y + crop_resolution, crop_x : crop_x + crop_resolution] + guide_arr = guide_arr.astype(np.float32) / 127.5 - 1 + + out_dict["guide"] = np.transpose(guide_arr, [2, 0, 1]).astype('float32') + + if self.local_classes is not None: + out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) + + return np.transpose(arr, [2, 0, 1]), out_dict diff --git a/pixel_guide_diffusion/logger.py b/pixel_guide_diffusion/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d856dcfea6b56a2ee8d37b286887430dbfac30 --- /dev/null +++ b/pixel_guide_diffusion/logger.py @@ -0,0 +1,495 @@ +""" +Logger copied from OpenAI baselines to avoid extra RL-based dependencies: +https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py +""" + +import os +import sys +import shutil +import os.path as osp +import json +import time +import datetime +import tempfile +import warnings +from collections import defaultdict +from contextlib import contextmanager + +DEBUG = 10 +INFO = 20 +WARN = 30 +ERROR = 40 + +DISABLED = 50 + + +class KVWriter(object): + def writekvs(self, kvs): + raise NotImplementedError + + +class SeqWriter(object): + def writeseq(self, seq): + raise NotImplementedError + + +class HumanOutputFormat(KVWriter, SeqWriter): + def __init__(self, filename_or_file): + if isinstance(filename_or_file, str): + self.file = open(filename_or_file, "wt") + self.own_file = True + else: + assert hasattr(filename_or_file, "read"), ( + "expected file or str, got %s" % filename_or_file + ) + self.file = filename_or_file + self.own_file = False + + def writekvs(self, kvs): + # Create strings for printing + key2str = {} + for (key, val) in sorted(kvs.items()): + if hasattr(val, "__float__"): + valstr = "%-8.3g" % val + else: + valstr = str(val) + key2str[self._truncate(key)] = self._truncate(valstr) + + # Find max widths + if len(key2str) == 0: + print("WARNING: tried to write empty key-value dict") + return + else: + keywidth = max(map(len, key2str.keys())) + valwidth = max(map(len, key2str.values())) + + # Write out the data + dashes = "-" * (keywidth + valwidth + 7) + lines = [dashes] + for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): + lines.append( + "| %s%s | %s%s |" + % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) + ) + lines.append(dashes) + self.file.write("\n".join(lines) + 
"\n") + + # Flush the output to the file + self.file.flush() + + def _truncate(self, s): + maxlen = 30 + return s[: maxlen - 3] + "..." if len(s) > maxlen else s + + def writeseq(self, seq): + seq = list(seq) + for (i, elem) in enumerate(seq): + self.file.write(elem) + if i < len(seq) - 1: # add space unless this is the last one + self.file.write(" ") + self.file.write("\n") + self.file.flush() + + def close(self): + if self.own_file: + self.file.close() + + +class JSONOutputFormat(KVWriter): + def __init__(self, filename): + self.file = open(filename, "wt") + + def writekvs(self, kvs): + for k, v in sorted(kvs.items()): + if hasattr(v, "dtype"): + kvs[k] = float(v) + self.file.write(json.dumps(kvs) + "\n") + self.file.flush() + + def close(self): + self.file.close() + + +class CSVOutputFormat(KVWriter): + def __init__(self, filename): + self.file = open(filename, "w+t") + self.keys = [] + self.sep = "," + + def writekvs(self, kvs): + # Add our current row to the history + extra_keys = list(kvs.keys() - self.keys) + extra_keys.sort() + if extra_keys: + self.keys.extend(extra_keys) + self.file.seek(0) + lines = self.file.readlines() + self.file.seek(0) + for (i, k) in enumerate(self.keys): + if i > 0: + self.file.write(",") + self.file.write(k) + self.file.write("\n") + for line in lines[1:]: + self.file.write(line[:-1]) + self.file.write(self.sep * len(extra_keys)) + self.file.write("\n") + for (i, k) in enumerate(self.keys): + if i > 0: + self.file.write(",") + v = kvs.get(k) + if v is not None: + self.file.write(str(v)) + self.file.write("\n") + self.file.flush() + + def close(self): + self.file.close() + + +class TensorBoardOutputFormat(KVWriter): + """ + Dumps key/value pairs into TensorBoard's numeric format. + """ + + def __init__(self, dir): + os.makedirs(dir, exist_ok=True) + self.dir = dir + self.step = 1 + prefix = "events" + path = osp.join(osp.abspath(dir), prefix) + import tensorflow as tf + from tensorflow.python import pywrap_tensorflow + from tensorflow.core.util import event_pb2 + from tensorflow.python.util import compat + + self.tf = tf + self.event_pb2 = event_pb2 + self.pywrap_tensorflow = pywrap_tensorflow + self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) + + def writekvs(self, kvs): + def summary_val(k, v): + kwargs = {"tag": k, "simple_value": float(v)} + return self.tf.Summary.Value(**kwargs) + + summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) + event = self.event_pb2.Event(wall_time=time.time(), summary=summary) + event.step = ( + self.step + ) # is there any reason why you'd want to specify the step? 
+ self.writer.WriteEvent(event) + self.writer.Flush() + self.step += 1 + + def close(self): + if self.writer: + self.writer.Close() + self.writer = None + + +def make_output_format(format, ev_dir, log_suffix=""): + os.makedirs(ev_dir, exist_ok=True) + if format == "stdout": + return HumanOutputFormat(sys.stdout) + elif format == "log": + return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) + elif format == "json": + return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) + elif format == "csv": + return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) + elif format == "tensorboard": + return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) + else: + raise ValueError("Unknown format specified: %s" % (format,)) + + +# ================================================================ +# API +# ================================================================ + + +def logkv(key, val): + """ + Log a value of some diagnostic + Call this once for each diagnostic quantity, each iteration + If called many times, last value will be used. + """ + get_current().logkv(key, val) + + +def logkv_mean(key, val): + """ + The same as logkv(), but if called many times, values averaged. + """ + get_current().logkv_mean(key, val) + + +def logkvs(d): + """ + Log a dictionary of key-value pairs + """ + for (k, v) in d.items(): + logkv(k, v) + + +def dumpkvs(): + """ + Write all of the diagnostics from the current iteration + """ + return get_current().dumpkvs() + + +def getkvs(): + return get_current().name2val + + +def log(*args, level=INFO): + """ + Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). + """ + get_current().log(*args, level=level) + + +def debug(*args): + log(*args, level=DEBUG) + + +def info(*args): + log(*args, level=INFO) + + +def warn(*args): + log(*args, level=WARN) + + +def error(*args): + log(*args, level=ERROR) + + +def set_level(level): + """ + Set logging threshold on current logger. + """ + get_current().set_level(level) + + +def set_comm(comm): + get_current().set_comm(comm) + + +def get_dir(): + """ + Get directory that log files are being written to. + will be None if there is no output directory (i.e., if you didn't call start) + """ + return get_current().get_dir() + + +record_tabular = logkv +dump_tabular = dumpkvs + + +@contextmanager +def profile_kv(scopename): + logkey = "wait_" + scopename + tstart = time.time() + try: + yield + finally: + get_current().name2val[logkey] += time.time() - tstart + + +def profile(n): + """ + Usage: + @profile("my_func") + def my_func(): code + """ + + def decorator_with_name(func): + def func_wrapper(*args, **kwargs): + with profile_kv(n): + return func(*args, **kwargs) + + return func_wrapper + + return decorator_with_name + + +# ================================================================ +# Backend +# ================================================================ + + +def get_current(): + if Logger.CURRENT is None: + _configure_default_logger() + + return Logger.CURRENT + + +class Logger(object): + DEFAULT = None # A logger with no output files. 
(See right below class definition) + # So that you can still log to the terminal without setting up any output files + CURRENT = None # Current logger being used by the free functions above + + def __init__(self, dir, output_formats, comm=None): + self.name2val = defaultdict(float) # values this iteration + self.name2cnt = defaultdict(int) + self.level = INFO + self.dir = dir + self.output_formats = output_formats + self.comm = comm + + # Logging API, forwarded + # ---------------------------------------- + def logkv(self, key, val): + self.name2val[key] = val + + def logkv_mean(self, key, val): + oldval, cnt = self.name2val[key], self.name2cnt[key] + self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) + self.name2cnt[key] = cnt + 1 + + def dumpkvs(self): + if self.comm is None: + d = self.name2val + else: + d = mpi_weighted_mean( + self.comm, + { + name: (val, self.name2cnt.get(name, 1)) + for (name, val) in self.name2val.items() + }, + ) + if self.comm.rank != 0: + d["dummy"] = 1 # so we don't get a warning about empty dict + out = d.copy() # Return the dict for unit testing purposes + for fmt in self.output_formats: + if isinstance(fmt, KVWriter): + fmt.writekvs(d) + self.name2val.clear() + self.name2cnt.clear() + return out + + def log(self, *args, level=INFO): + if self.level <= level: + self._do_log(args) + + # Configuration + # ---------------------------------------- + def set_level(self, level): + self.level = level + + def set_comm(self, comm): + self.comm = comm + + def get_dir(self): + return self.dir + + def close(self): + for fmt in self.output_formats: + fmt.close() + + # Misc + # ---------------------------------------- + def _do_log(self, args): + for fmt in self.output_formats: + if isinstance(fmt, SeqWriter): + fmt.writeseq(map(str, args)) + + +def get_rank_without_mpi_import(): + # check environment variables here instead of importing mpi4py + # to avoid calling MPI_Init() when this module is imported + for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: + if varname in os.environ: + return int(os.environ[varname]) + return 0 + + +def mpi_weighted_mean(comm, local_name2valcount): + """ + Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 + Perform a weighted average over dicts that are each on a different node + Input: local_name2valcount: dict mapping key -> (value, count) + Returns: key -> mean + """ + all_name2valcount = comm.gather(local_name2valcount) + if comm.rank == 0: + name2sum = defaultdict(float) + name2count = defaultdict(float) + for n2vc in all_name2valcount: + for (name, (val, count)) in n2vc.items(): + try: + val = float(val) + except ValueError: + if comm.rank == 0: + warnings.warn( + "WARNING: tried to compute mean on non-float {}={}".format( + name, val + ) + ) + else: + name2sum[name] += val * count + name2count[name] += count + return {name: name2sum[name] / name2count[name] for name in name2sum} + else: + return {} + + +def configure(dir=None, format_strs=None, comm=None, log_suffix=""): + """ + If comm is provided, average all numerical stats across that comm + """ + if dir is None: + dir = os.getenv("OPENAI_LOGDIR") + if dir is None: + dir = osp.join( + tempfile.gettempdir(), + datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), + ) + assert isinstance(dir, str) + dir = os.path.expanduser(dir) + os.makedirs(os.path.expanduser(dir), exist_ok=True) + + rank = get_rank_without_mpi_import() + if rank > 0: + log_suffix = log_suffix + "-rank%03i" % 
rank + + if format_strs is None: + if rank == 0: + format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") + else: + format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") + format_strs = filter(None, format_strs) + output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] + + Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) + if output_formats: + log("Logging to %s" % dir) + + +def _configure_default_logger(): + configure() + Logger.DEFAULT = Logger.CURRENT + + +def reset(): + if Logger.CURRENT is not Logger.DEFAULT: + Logger.CURRENT.close() + Logger.CURRENT = Logger.DEFAULT + log("Reset logger") + + +@contextmanager +def scoped_configure(dir=None, format_strs=None, comm=None): + prevlogger = Logger.CURRENT + configure(dir=dir, format_strs=format_strs, comm=comm) + try: + yield + finally: + Logger.CURRENT.close() + Logger.CURRENT = prevlogger + diff --git a/pixel_guide_diffusion/losses.py b/pixel_guide_diffusion/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..251e42e4f36a31bb5e1aeda874b3a45d722000a2 --- /dev/null +++ b/pixel_guide_diffusion/losses.py @@ -0,0 +1,77 @@ +""" +Helpers for various likelihood-based losses. These are ported from the original +Ho et al. diffusion models codebase: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py +""" + +import numpy as np + +import torch as th + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). 
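+
+    The +/- 1/255 offsets used below are half of one uint8 quantization bin:
+    after rescaling to [-1, 1], each of the 256 levels spans a width of 2/255.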
+ """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/pixel_guide_diffusion/nn.py b/pixel_guide_diffusion/nn.py new file mode 100644 index 0000000000000000000000000000000000000000..e36fb3ef7af8db46737b66baa5a2db7ea3a874ae --- /dev/null +++ b/pixel_guide_diffusion/nn.py @@ -0,0 +1,191 @@ +""" +Various utilities for neural networks. +""" + +import math + +import torch as th +import torch.nn as nn + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * th.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +class SpaceToDepth(nn.Module): + def __init__(self, factor): + super().__init__() + + self.factor = factor + + def forward(self, x): + if self.factor == 1: + return x + + batch, channel, height, width = x.shape + h_fold = height // self.factor + w_fold = width // self.factor + + return ( + x.view(batch, channel, h_fold, self.factor, w_fold, self.factor) + .permute(0, 1, 3, 5, 2, 4) + .reshape(batch, -1, h_fold, w_fold) + ) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. 
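+
+    Note: this returns GroupNorm with 32 groups (computed in float32 by
+    GroupNorm32), so `channels` is expected to be divisible by 32.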
+ """ + return GroupNorm32(32, channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = th.exp( + -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(th.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with th.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with th.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = th.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads diff --git a/pixel_guide_diffusion/resample.py b/pixel_guide_diffusion/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..c82eccdcd47c468d41e7cbe02de6a731f2c9bf81 --- /dev/null +++ b/pixel_guide_diffusion/resample.py @@ -0,0 +1,154 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. 
+ However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + + Sub-classes should override this method to update the reweighting + using losses from the model. + + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. 
+ """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/pixel_guide_diffusion/respace.py b/pixel_guide_diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..045d58df956e6ddb04216e972bffff47c59bf488 --- /dev/null +++ b/pixel_guide_diffusion/respace.py @@ -0,0 +1,122 @@ +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. 
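+
+    For example, space_timesteps(100, "ddim25") finds a stride of 4 and returns
+    the 25 steps {0, 4, 8, ..., 96}.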
+ """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.rescale_timesteps, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/pixel_guide_diffusion/script_util.py b/pixel_guide_diffusion/script_util.py new file mode 100644 index 0000000000000000000000000000000000000000..b85a60adb2a8a6cdf10aec3504ba2f30d7500d1b --- /dev/null +++ b/pixel_guide_diffusion/script_util.py @@ -0,0 +1,537 @@ +import argparse +import inspect + +from . 
import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps +from .unet import PixelGuideSuperResModel, PixelGuideModel, SuperResModel, UNetModel + +NUM_CLASSES = 1000 + + +def model_and_diffusion_defaults(): + """ + Defaults for image training. + """ + return dict( + image_size=64, + num_channels=128, + num_res_blocks=2, + num_heads=4, + num_heads_upsample=-1, + attention_resolutions="16,8", + dropout=0.0, + learn_sigma=False, + sigma_small=False, + class_cond=False, + diffusion_steps=1000, + noise_schedule="linear", + timestep_respacing="", + use_kl=False, + predict_xstart=False, + rescale_timesteps=True, + rescale_learned_sigmas=True, + use_checkpoint=False, + use_scale_shift_norm=True, + use_attention=True, + ) + + +def create_model_and_diffusion( + image_size, + class_cond, + learn_sigma, + sigma_small, + num_channels, + num_res_blocks, + num_heads, + num_heads_upsample, + attention_resolutions, + dropout, + diffusion_steps, + noise_schedule, + timestep_respacing, + use_kl, + predict_xstart, + rescale_timesteps, + rescale_learned_sigmas, + use_checkpoint, + use_scale_shift_norm, + use_attention, +): + model = create_model( + image_size, + num_channels, + num_res_blocks, + learn_sigma=learn_sigma, + class_cond=class_cond, + use_checkpoint=use_checkpoint, + attention_resolutions=attention_resolutions, + num_heads=num_heads, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + use_attention=use_attention, + dropout=dropout, + ) + diffusion = create_gaussian_diffusion( + steps=diffusion_steps, + learn_sigma=learn_sigma, + sigma_small=sigma_small, + noise_schedule=noise_schedule, + use_kl=use_kl, + predict_xstart=predict_xstart, + rescale_timesteps=rescale_timesteps, + rescale_learned_sigmas=rescale_learned_sigmas, + timestep_respacing=timestep_respacing, + ) + return model, diffusion + + +def create_model( + image_size, + num_channels, + num_res_blocks, + learn_sigma, + class_cond, + use_checkpoint, + attention_resolutions, + num_heads, + num_heads_upsample, + use_scale_shift_norm, + use_attention, + dropout, +): + if image_size == 256: + channel_mult = (1, 1, 2, 2, 4, 4) + elif image_size == 128: + channel_mult = (1, 2, 2, 3, 4) + elif image_size == 64: + channel_mult = (1, 2, 3, 4) + elif image_size == 32: + channel_mult = (1, 2, 2, 2) + else: + raise ValueError(f"unsupported image size: {image_size}") + + attention_ds = [] + for res in attention_resolutions.split(","): + attention_ds.append(image_size // int(res)) + + return UNetModel( + in_channels=3, + model_channels=num_channels, + out_channels=(3 if not learn_sigma else 6), + num_res_blocks=num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=dropout, + channel_mult=channel_mult, + num_classes=(NUM_CLASSES if class_cond else None), + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + use_attention=use_attention + ) + + +def sr_model_and_diffusion_defaults(): + res = model_and_diffusion_defaults() + res["large_size"] = 256 + res["small_size"] = 64 + arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0] + for k in res.copy().keys(): + if k not in arg_names: + del res[k] + return res + + +def sr_create_model_and_diffusion( + large_size, + small_size, + class_cond, + learn_sigma, + num_channels, + num_res_blocks, + num_heads, + num_heads_upsample, + attention_resolutions, + dropout, + diffusion_steps, + noise_schedule, + timestep_respacing, + use_kl, + 
predict_xstart, + rescale_timesteps, + rescale_learned_sigmas, + use_checkpoint, + use_scale_shift_norm, + use_attention +): + model = sr_create_model( + large_size, + small_size, + num_channels, + num_res_blocks, + learn_sigma=learn_sigma, + class_cond=class_cond, + use_checkpoint=use_checkpoint, + attention_resolutions=attention_resolutions, + num_heads=num_heads, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + use_attention=use_attention, + dropout=dropout, + ) + diffusion = create_gaussian_diffusion( + steps=diffusion_steps, + learn_sigma=learn_sigma, + noise_schedule=noise_schedule, + use_kl=use_kl, + predict_xstart=predict_xstart, + rescale_timesteps=rescale_timesteps, + rescale_learned_sigmas=rescale_learned_sigmas, + timestep_respacing=timestep_respacing, + ) + return model, diffusion + + +def sr_create_model( + large_size, + small_size, + num_channels, + num_res_blocks, + learn_sigma, + class_cond, + use_checkpoint, + attention_resolutions, + num_heads, + num_heads_upsample, + use_scale_shift_norm, + use_attention, + dropout, +): + _ = small_size # hack to prevent unused variable + + if large_size == 256: + channel_mult = (1, 1, 2, 2, 4, 4) + elif large_size == 128: + channel_mult = (1, 2, 2, 3, 4) + elif large_size == 64: + channel_mult = (1, 2, 3, 4) + else: + raise ValueError(f"unsupported large size: {large_size}") + + attention_ds = [] + for res in attention_resolutions.split(","): + attention_ds.append(large_size // int(res)) + + return SuperResModel( + in_channels=3, + model_channels=num_channels, + out_channels=(3 if not learn_sigma else 6), + num_res_blocks=num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=dropout, + channel_mult=channel_mult, + num_classes=(NUM_CLASSES if class_cond else None), + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + use_attention=use_attention + ) + + + +def pg_model_and_diffusion_defaults(): + res = model_and_diffusion_defaults() + res["image_size"] = 32 + res["guide_size"] = 256 + arg_names = inspect.getfullargspec(pg_create_model_and_diffusion)[0] + for k in res.copy().keys(): + if k not in arg_names: + del res[k] + return res + + +def pg_create_model_and_diffusion( + image_size, + guide_size, + class_cond, + learn_sigma, + num_channels, + num_res_blocks, + num_heads, + num_heads_upsample, + attention_resolutions, + dropout, + diffusion_steps, + noise_schedule, + timestep_respacing, + use_kl, + predict_xstart, + rescale_timesteps, + rescale_learned_sigmas, + use_checkpoint, + use_scale_shift_norm, + use_attention +): + model = pg_create_model( + image_size, + guide_size, + num_channels, + num_res_blocks, + learn_sigma=learn_sigma, + class_cond=class_cond, + use_checkpoint=use_checkpoint, + attention_resolutions=attention_resolutions, + num_heads=num_heads, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + dropout=dropout, + use_attention=use_attention + ) + diffusion = create_gaussian_diffusion( + steps=diffusion_steps, + learn_sigma=learn_sigma, + noise_schedule=noise_schedule, + use_kl=use_kl, + predict_xstart=predict_xstart, + rescale_timesteps=rescale_timesteps, + rescale_learned_sigmas=rescale_learned_sigmas, + timestep_respacing=timestep_respacing, + ) + return model, diffusion + + +def pg_create_model( + image_size, + guide_size, + num_channels, + num_res_blocks, + learn_sigma, + class_cond, + use_checkpoint, + attention_resolutions, + 
num_heads, + num_heads_upsample, + use_scale_shift_norm, + use_attention, + dropout, +): + + if image_size == 256: + channel_mult = (1, 1, 2, 2, 4, 4) + elif image_size == 128: + channel_mult = (1, 2, 2, 3, 4) + elif image_size == 64: + channel_mult = (1, 2, 3, 4) + elif image_size == 32: + channel_mult = (1, 2, 2, 2) + else: + raise ValueError(f"unsupported image size: {image_size}") + + attention_ds = [] + for res in attention_resolutions.split(","): + attention_ds.append(image_size // int(res)) + + guide_fold = guide_size // image_size + + return PixelGuideModel( + in_channels=3, + guide_channels=1, + model_channels=num_channels, + out_channels=(3 if not learn_sigma else 6), + num_res_blocks=num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=dropout, + channel_mult=channel_mult, + num_classes=(NUM_CLASSES if class_cond else None), + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + use_attention=use_attention, + guide_fold=guide_fold + ) + + +def pgsr_model_and_diffusion_defaults(): + res = model_and_diffusion_defaults() + res["large_size"] = 256 + res["small_size"] = 64 + res["guide_size"] = 256 + arg_names = inspect.getfullargspec(pgsr_create_model_and_diffusion)[0] + for k in res.copy().keys(): + if k not in arg_names: + del res[k] + return res + + +def pgsr_create_model_and_diffusion( + large_size, + small_size, + guide_size, + class_cond, + learn_sigma, + num_channels, + num_res_blocks, + num_heads, + num_heads_upsample, + attention_resolutions, + dropout, + diffusion_steps, + noise_schedule, + timestep_respacing, + use_kl, + predict_xstart, + rescale_timesteps, + rescale_learned_sigmas, + use_checkpoint, + use_scale_shift_norm, + use_attention, +): + model = pgsr_create_model( + large_size, + small_size, + guide_size, + num_channels, + num_res_blocks, + learn_sigma=learn_sigma, + class_cond=class_cond, + use_checkpoint=use_checkpoint, + attention_resolutions=attention_resolutions, + num_heads=num_heads, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + use_attention=use_attention, + dropout=dropout, + ) + diffusion = create_gaussian_diffusion( + steps=diffusion_steps, + learn_sigma=learn_sigma, + noise_schedule=noise_schedule, + use_kl=use_kl, + predict_xstart=predict_xstart, + rescale_timesteps=rescale_timesteps, + rescale_learned_sigmas=rescale_learned_sigmas, + timestep_respacing=timestep_respacing, + ) + return model, diffusion + + +def pgsr_create_model( + large_size, + small_size, + guide_size, + num_channels, + num_res_blocks, + learn_sigma, + class_cond, + use_checkpoint, + attention_resolutions, + num_heads, + num_heads_upsample, + use_scale_shift_norm, + use_attention, + dropout, +): + _ = small_size # hack to prevent unused variable + + if large_size == 256: + channel_mult = (1, 2, 2, 3, 4) + elif large_size == 128: + channel_mult = (1, 2, 2, 2) + else: + raise ValueError(f"unsupported image size: {large_size}") + + attention_ds = [] + for res in attention_resolutions.split(","): + attention_ds.append(large_size // int(res)) + + guide_fold = guide_size // large_size + + return PixelGuideSuperResModel( + in_channels=3, + guide_channels=1, + model_channels=num_channels, + out_channels=(3 if not learn_sigma else 6), + num_res_blocks=num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=dropout, + channel_mult=channel_mult, + num_classes=(NUM_CLASSES if class_cond else None), + use_checkpoint=use_checkpoint, + 
num_heads=num_heads, + num_heads_upsample=num_heads_upsample, + use_scale_shift_norm=use_scale_shift_norm, + use_attention=use_attention, + guide_fold=guide_fold + ) + + +def create_gaussian_diffusion( + *, + steps=1000, + learn_sigma=False, + sigma_small=False, + noise_schedule="linear", + use_kl=False, + predict_xstart=False, + rescale_timesteps=False, + rescale_learned_sigmas=False, + timestep_respacing="", +): + betas = gd.get_named_beta_schedule(noise_schedule, steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if not timestep_respacing: + timestep_respacing = [steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + rescale_timesteps=rescale_timesteps, + ) + + +def add_dict_to_argparser(parser, default_dict): + for k, v in default_dict.items(): + v_type = type(v) + if v is None: + v_type = str + elif isinstance(v, bool): + v_type = str2bool + parser.add_argument(f"--{k}", default=v, type=v_type) + + +def args_to_dict(args, keys): + return {k: getattr(args, k) for k in keys} + + +def str2bool(v): + """ + https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse + """ + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("boolean value expected") diff --git a/pixel_guide_diffusion/train_util.py b/pixel_guide_diffusion/train_util.py new file mode 100644 index 0000000000000000000000000000000000000000..1867604145736352dc51ab05b6caae8b541a6ebb --- /dev/null +++ b/pixel_guide_diffusion/train_util.py @@ -0,0 +1,356 @@ +import copy +import functools +import os + +import blobfile as bf +import numpy as np +import torch as th +import torch.distributed as dist +from torch.nn.parallel.distributed import DistributedDataParallel as DDP +from torch.optim import AdamW + +from . import dist_util, logger +from .fp16_util import ( + make_master_params, + master_params_to_model_params, + model_grads_to_master_grads, + unflatten_master_params, + zero_grad, +) +from .nn import update_ema +from .resample import LossAwareSampler, UniformSampler + +# For ImageNet experiments, this was a good default value. +# We found that the lg_loss_scale quickly climbed to +# 20-21 within the first ~1K steps of training. 
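+# During fp16 training the loss is scaled by 2 ** lg_loss_scale before backward()
+# and the flattened master gradient is divided by the same factor afterwards
+# (see forward_backward / optimize_fp16 below); the scale drops by 1 whenever
+# non-finite gradients appear and otherwise grows by fp16_scale_growth per step.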
+INITIAL_LOG_LOSS_SCALE = 20.0 + + +class TrainLoop: + def __init__( + self, + *, + model, + diffusion, + data, + batch_size, + microbatch, + lr, + ema_rate, + log_interval, + save_interval, + resume_checkpoint, + use_fp16=False, + fp16_scale_growth=1e-3, + schedule_sampler=None, + weight_decay=0.0, + lr_anneal_steps=0, + ): + self.model = model + self.diffusion = diffusion + self.data = data + self.batch_size = batch_size + self.microbatch = microbatch if microbatch > 0 else batch_size + self.lr = lr + self.ema_rate = ( + [ema_rate] + if isinstance(ema_rate, float) + else [float(x) for x in ema_rate.split(",")] + ) + self.log_interval = log_interval + self.save_interval = save_interval + self.resume_checkpoint = resume_checkpoint + self.use_fp16 = use_fp16 + self.fp16_scale_growth = fp16_scale_growth + self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) + self.weight_decay = weight_decay + self.lr_anneal_steps = lr_anneal_steps + + self.step = 0 + self.resume_step = 0 + self.global_batch = self.batch_size * dist.get_world_size() + + self.model_params = list(self.model.parameters()) + self.master_params = self.model_params + self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE + self.sync_cuda = th.cuda.is_available() + + self._load_and_sync_parameters() + if self.use_fp16: + self._setup_fp16() + + self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay) + if self.resume_step: + self._load_optimizer_state() + # Model was resumed, either due to a restart or a checkpoint + # being specified at the command line. + self.ema_params = [ + self._load_ema_parameters(rate) for rate in self.ema_rate + ] + else: + self.ema_params = [ + copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate)) + ] + + if th.cuda.is_available(): + self.use_ddp = True + self.ddp_model = DDP( + self.model, + device_ids=[dist_util.dev()], + output_device=dist_util.dev(), + broadcast_buffers=False, + bucket_cap_mb=128, + find_unused_parameters=False, + ) + else: + if dist.get_world_size() > 1: + logger.warn( + "Distributed training requires CUDA. " + "Gradients will not be synchronized properly!" 
+ ) + self.use_ddp = False + self.ddp_model = self.model + + def _load_and_sync_parameters(self): + resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint + + if resume_checkpoint: + self.resume_step = parse_resume_step_from_filename(resume_checkpoint) + if dist.get_rank() == 0: + logger.log(f"loading model from checkpoint: {resume_checkpoint}...") + self.model.load_state_dict( + dist_util.load_state_dict( + resume_checkpoint, map_location=dist_util.dev() + ) + ) + + dist_util.sync_params(self.model.parameters()) + + def _load_ema_parameters(self, rate): + ema_params = copy.deepcopy(self.master_params) + + main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint + ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) + if ema_checkpoint: + if dist.get_rank() == 0: + logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") + state_dict = dist_util.load_state_dict( + ema_checkpoint, map_location=dist_util.dev() + ) + ema_params = self._state_dict_to_master_params(state_dict) + + dist_util.sync_params(ema_params) + return ema_params + + def _load_optimizer_state(self): + main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint + opt_checkpoint = bf.join( + bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" + ) + if bf.exists(opt_checkpoint): + logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") + state_dict = dist_util.load_state_dict( + opt_checkpoint, map_location=dist_util.dev() + ) + self.opt.load_state_dict(state_dict) + + def _setup_fp16(self): + self.master_params = make_master_params(self.model_params) + self.model.convert_to_fp16() + + def run_loop(self): + while ( + not self.lr_anneal_steps + or self.step + self.resume_step < self.lr_anneal_steps + ): + batch, cond = next(self.data) + self.run_step(batch, cond) + if self.step % self.log_interval == 0: + logger.dumpkvs() + if self.step % self.save_interval == 0: + self.save() + # Run for a finite amount of time in integration tests. + if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: + return + self.step += 1 + # Save the last checkpoint if it wasn't already saved. 
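+        # (self.step has already been incremented past the final iteration, so
+        # the last completed step is self.step - 1.)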
+ if (self.step - 1) % self.save_interval != 0: + self.save() + + def run_step(self, batch, cond): + self.forward_backward(batch, cond) + if self.use_fp16: + self.optimize_fp16() + else: + self.optimize_normal() + self.log_step() + + def forward_backward(self, batch, cond): + zero_grad(self.model_params) + for i in range(0, batch.shape[0], self.microbatch): + micro = batch[i : i + self.microbatch].to(dist_util.dev()) + micro_cond = { + k: v[i : i + self.microbatch].to(dist_util.dev()) + for k, v in cond.items() + } + last_batch = (i + self.microbatch) >= batch.shape[0] + t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) + + compute_losses = functools.partial( + self.diffusion.training_losses, + self.ddp_model, + micro, + t, + model_kwargs=micro_cond, + ) + + if last_batch or not self.use_ddp: + losses = compute_losses() + else: + with self.ddp_model.no_sync(): + losses = compute_losses() + + if isinstance(self.schedule_sampler, LossAwareSampler): + self.schedule_sampler.update_with_local_losses( + t, losses["loss"].detach() + ) + + loss = (losses["loss"] * weights).mean() + log_loss_dict( + self.diffusion, t, {k: v * weights for k, v in losses.items()} + ) + if self.use_fp16: + loss_scale = 2 ** self.lg_loss_scale + (loss * loss_scale).backward() + else: + loss.backward() + + def optimize_fp16(self): + if any(not th.isfinite(p.grad).all() for p in self.model_params): + self.lg_loss_scale -= 1 + logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") + return + + model_grads_to_master_grads(self.model_params, self.master_params) + self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale)) + self._log_grad_norm() + self._anneal_lr() + self.opt.step() + for rate, params in zip(self.ema_rate, self.ema_params): + update_ema(params, self.master_params, rate=rate) + master_params_to_model_params(self.model_params, self.master_params) + self.lg_loss_scale += self.fp16_scale_growth + + def optimize_normal(self): + self._log_grad_norm() + self._anneal_lr() + self.opt.step() + for rate, params in zip(self.ema_rate, self.ema_params): + update_ema(params, self.master_params, rate=rate) + + def _log_grad_norm(self): + sqsum = 0.0 + for p in self.master_params: + sqsum += (p.grad ** 2).sum().item() + logger.logkv_mean("grad_norm", np.sqrt(sqsum)) + + def _anneal_lr(self): + if not self.lr_anneal_steps: + return + frac_done = (self.step + self.resume_step) / self.lr_anneal_steps + lr = self.lr * (1 - frac_done) + for param_group in self.opt.param_groups: + param_group["lr"] = lr + + def log_step(self): + logger.logkv("step", self.step + self.resume_step) + logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) + if self.use_fp16: + logger.logkv("lg_loss_scale", self.lg_loss_scale) + + def save(self): + def save_checkpoint(rate, params): + state_dict = self._master_params_to_state_dict(params) + if dist.get_rank() == 0: + logger.log(f"saving model {rate}...") + if not rate: + filename = f"model{(self.step+self.resume_step):06d}.pt" + else: + filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" + with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f: + th.save(state_dict, f) + + save_checkpoint(0, self.master_params) + for rate, params in zip(self.ema_rate, self.ema_params): + save_checkpoint(rate, params) + + if dist.get_rank() == 0: + with bf.BlobFile( + bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), + "wb", + ) as f: + th.save(self.opt.state_dict(), f) + + dist.barrier() + + def 
_master_params_to_state_dict(self, master_params): + if self.use_fp16: + master_params = unflatten_master_params( + self.model.parameters(), master_params + ) + state_dict = self.model.state_dict() + for i, (name, _value) in enumerate(self.model.named_parameters()): + assert name in state_dict + state_dict[name] = master_params[i] + return state_dict + + def _state_dict_to_master_params(self, state_dict): + params = [state_dict[name] for name, _ in self.model.named_parameters()] + if self.use_fp16: + return make_master_params(params) + else: + return params + + +def parse_resume_step_from_filename(filename): + """ + Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the + checkpoint's number of steps. + """ + split = filename.split("model") + if len(split) < 2: + return 0 + split1 = split[-1].split(".")[0] + try: + return int(split1) + except ValueError: + return 0 + + +def get_blob_logdir(): + return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir()) + + +def find_resume_checkpoint(): + # On your infrastructure, you may want to override this to automatically + # discover the latest checkpoint on your blob storage, etc. + return None + + +def find_ema_checkpoint(main_checkpoint, step, rate): + if main_checkpoint is None: + return None + filename = f"ema_{rate}_{(step):06d}.pt" + path = bf.join(bf.dirname(main_checkpoint), filename) + if bf.exists(path): + return path + return None + + +def log_loss_dict(diffusion, ts, losses): + for key, values in losses.items(): + logger.logkv_mean(key, values.mean().item()) + # Log the quantiles (four quartiles, in particular). + for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): + quartile = int(4 * sub_t / diffusion.num_timesteps) + logger.logkv_mean(f"{key}_q{quartile}", sub_loss) diff --git a/pixel_guide_diffusion/unet.py b/pixel_guide_diffusion/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..4286fdf522502a3f76aef8dcd8c1ba507fc927c6 --- /dev/null +++ b/pixel_guide_diffusion/unet.py @@ -0,0 +1,594 @@ +from abc import abstractmethod + +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .fp16_util import convert_module_to_f16, convert_module_to_f32 +from .nn import ( + SiLU, + SpaceToDepth, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, + checkpoint, +) + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
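+    Upsampling uses nearest-neighbor interpolation with a factor of 2; the
+    optional convolution is a 3x3 conv that preserves the channel count.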
+ """ + + def __init__(self, channels, use_conv, dims=2): + super().__init__() + self.channels = channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, channels, channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2): + super().__init__() + self.channels = channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1) + else: + self.op = avg_pool_nd(stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + self.emb_layers = nn.Sequential( + SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__(self, channels, num_heads=1, use_checkpoint=False): + super().__init__() + self.channels = channels + self.num_heads = num_heads + self.use_checkpoint = use_checkpoint + + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + self.attention = QKVAttention() + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2]) + h = self.attention(qkv) + h = h.reshape(b, -1, h.shape[-1]) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention. + """ + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs. + :return: an [N x C x T] tensor after attention. + """ + ch = qkv.shape[1] // 3 + q, k, v = th.split(qkv, ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + return th.einsum("bts,bcs->bct", weight, v) + + @staticmethod + def count_flops(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + + Meant to be used like: + + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. 
+ :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + num_heads=1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + use_attention=True + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.num_heads = num_heads + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, use_checkpoint=use_checkpoint, num_heads=num_heads + ) if use_attention else nn.Sequential() + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + self.input_blocks.append( + TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims)) + ) + input_block_chans.append(ch) + ds *= 2 + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads) if use_attention else nn.Sequential(), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + layers = [ + ResBlock( + ch + input_block_chans.pop(), + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + ) if use_attention else nn.Sequential() + ) + 
if level and i == num_res_blocks: + layers.append(Upsample(ch, conv_resample, dims=dims)) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + + self.out = nn.Sequential( + normalization(ch), + SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + @property + def inner_dtype(self): + """ + Get the dtype used by the torso of the model. + """ + return next(self.input_blocks.parameters()).dtype + + def forward(self, x, timesteps, y=None): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.inner_dtype) + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + cat_in = th.cat([h, hs.pop()], dim=1) + h = module(cat_in, emb) + h = h.type(x.dtype) + return self.out(h) + + def get_feature_vectors(self, x, timesteps, y=None): + """ + Apply the model and return all of the intermediate tensors. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param y: an [N] Tensor of labels, if class-conditional. + :return: a dict with the following keys: + - 'down': a list of hidden state tensors from downsampling. + - 'middle': the tensor of the output of the lowest-resolution + block in the model. + - 'up': a list of hidden state tensors from upsampling. + """ + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + result = dict(down=[], up=[]) + h = x.type(self.inner_dtype) + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + result["down"].append(h.type(x.dtype)) + h = self.middle_block(h, emb) + result["middle"] = h.type(x.dtype) + for module in self.output_blocks: + cat_in = th.cat([h, hs.pop()], dim=1) + h = module(cat_in, emb) + result["up"].append(h.type(x.dtype)) + return result + + +class SuperResModel(UNetModel): + """ + A UNetModel that performs super-resolution. + + Expects an extra kwarg `low_res` to condition on a low-resolution image. 
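+    The low-resolution image is bilinearly upsampled to the input resolution
+    and concatenated with the input along the channel axis (hence the
+    `in_channels * 2` in the constructor).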
+    """
+
+    def __init__(self, in_channels, *args, **kwargs):
+        super().__init__(in_channels * 2, *args, **kwargs)
+
+    def forward(self, x, timesteps, low_res=None, **kwargs):
+        _, _, new_height, new_width = x.shape
+        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+        x = th.cat([x, upsampled], dim=1)
+        return super().forward(x, timesteps, **kwargs)
+
+    def get_feature_vectors(self, x, timesteps, low_res=None, **kwargs):
+        # x is NCHW, so unpack the spatial dims from the last two axes.
+        _, _, new_height, new_width = x.shape
+        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+        x = th.cat([x, upsampled], dim=1)
+        return super().get_feature_vectors(x, timesteps, **kwargs)
+
+
+class PixelGuideModel(UNetModel):
+    """
+    A UNetModel that needs a guide tensor with the same height and width as the input tensor.
+
+    Expects an extra kwarg `guide` as a condition for the model.
+    """
+
+    def __init__(self, in_channels, guide_channels, *args, guide_fold=1, **kwargs):
+        super().__init__(in_channels + guide_channels * guide_fold**2, *args, **kwargs)
+
+        self.guide_folder = SpaceToDepth(guide_fold)
+
+    def forward(self, x, timesteps, guide=None, **kwargs):
+        guide = self.guide_folder(guide)
+        x = th.cat([x, guide], dim=1)
+        return super().forward(x, timesteps, **kwargs)
+
+    def get_feature_vectors(self, x, timesteps, guide=None, **kwargs):
+        guide = self.guide_folder(guide)
+        x = th.cat([x, guide], dim=1)
+        return super().get_feature_vectors(x, timesteps, **kwargs)
+
+
+class PixelGuideSuperResModel(PixelGuideModel):
+    """
+    A PixelGuideModel that performs super-resolution.
+
+    Expects an extra kwarg `low_res` to condition on a low-resolution image.
+    """
+
+    def __init__(self, in_channels, *args, **kwargs):
+        super().__init__(in_channels * 2, *args, **kwargs)
+
+    def forward(self, x, timesteps, low_res=None, **kwargs):
+        _, _, new_height, new_width = x.shape
+        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+        x = th.cat([x, upsampled], dim=1)
+        return super().forward(x, timesteps, **kwargs)
+
+    def get_feature_vectors(self, x, timesteps, low_res=None, **kwargs):
+        # x is NCHW, so unpack the spatial dims from the last two axes.
+        _, _, new_height, new_width = x.shape
+        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+        x = th.cat([x, upsampled], dim=1)
+        return super().get_feature_vectors(x, timesteps, **kwargs)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4d6e7c08e7a4395ff65d5cf0b35673e952fb3146
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+blobfile
+mpi4py
+gradio==3.0.5
+urllib3==1.24.3
+torch
+torchvision
diff --git a/scripts/cascaded_pixel_guide_sample.py b/scripts/cascaded_pixel_guide_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..c765b3332531488b0312326b2d0282cec3c00f57
--- /dev/null
+++ b/scripts/cascaded_pixel_guide_sample.py
@@ -0,0 +1,148 @@
+"""
+Generate a large batch of samples from a super resolution model, given a batch
+of samples from a regular model from image_sample.py.
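+In this cascaded variant, the base pixel-guide model is sampled first at
+--image_size, and its output is fed as the `low_res` conditioning to the
+guided super-resolution model at --large_size.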
+""" + +import argparse +import os + +import blobfile as bf +import numpy as np +import torch as th +import torch.distributed as dist + +from torchvision import utils +from pixel_guide_diffusion import dist_util, logger +from pixel_guide_diffusion.image_datasets import load_data +from pixel_guide_diffusion.script_util import ( + pg_model_and_diffusion_defaults, + pg_create_model_and_diffusion, + pgsr_model_and_diffusion_defaults, + pgsr_create_model_and_diffusion, + args_to_dict, + add_dict_to_argparser, +) + + +def main(): + args = create_argparser().parse_args() + + dist_util.setup_dist() + logger.configure() + + logger.log("creating model...") + model, diffusion = pg_create_model_and_diffusion( + **args_to_dict(args, pg_model_and_diffusion_defaults().keys()) + ) + model.load_state_dict( + dist_util.load_state_dict(args.model_path, map_location="cpu") + ) + model.to(dist_util.dev()) + model.eval() + + logger.log("creating model2...") + args.num_channels = args.num_channels2 + args.use_attention = args.use_attention2 + model2, diffusion2 = pgsr_create_model_and_diffusion( + **args_to_dict(args, pgsr_model_and_diffusion_defaults().keys()) + ) + model2.load_state_dict( + dist_util.load_state_dict(args.model_path2, map_location="cpu") + ) + model2.to(dist_util.dev()) + model2.eval() + + logger.log("creating data loader...") + data = load_data( + data_dir=args.data_dir, + batch_size=args.batch_size, + image_size=args.large_size, + class_cond=args.class_cond, + guide_dir=args.guide_dir, + guide_size=args.guide_size, + deterministic=True, + ) + + if args.seed > -1: + th.manual_seed(args.seed) + + logger.log("creating samples...") + os.makedirs('sample', exist_ok=True) + i = 0 + while i * args.batch_size < args.num_samples: + if dist.get_rank() == 0: + target, model_kwargs = next(data) + target = target.to(dist_util.dev()) + model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} + + with th.no_grad(): + sample_fn = ( + diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop + ) + sample = sample_fn( + model, + (args.batch_size, 3, args.image_size, args.image_size), + clip_denoised=args.clip_denoised, + model_kwargs=model_kwargs, + ) + + model_kwargs["low_res"] = sample + sample_fn2 = ( + diffusion2.p_sample_loop if not args.use_ddim else diffusion2.ddim_sample_loop + ) + sample2 = sample_fn2( + model2, + (args.batch_size, 3, args.large_size, args.large_size), + clip_denoised=args.clip_denoised, + model_kwargs=model_kwargs, + ) + + guide = model_kwargs["guide"] + h, w = guide.shape[2:] + guide = guide.clamp(-1,1).repeat(1,3,1,1) + sample = th.nn.functional.interpolate(sample.clamp(-1,1), size=(h, w)) + sample2 = th.nn.functional.interpolate(sample2.clamp(-1,1), size=(h, w)) + target = th.nn.functional.interpolate(target.clamp(-1,1), size=(h, w)) + + # images = th.cat([guide, sample, sample2, target], 0) + images = th.cat([guide, sample2, target], 0) + utils.save_image( + images, + f"sample/{str(i).zfill(6)}.png", + nrow=args.batch_size, + normalize=True, + range=(-1, 1), + ) + + i += 1 + logger.log(f"created {i * args.batch_size} samples") + + logger.log("sampling complete") + + +def create_argparser(): + defaults = dict( + data_dir="", + guide_dir="", + clip_denoised=True, + num_samples=100, + batch_size=4, + use_ddim=False, + base_samples="", + model_path="", + seed=-1, + ) + defaults.update(pg_model_and_diffusion_defaults()) + defaults.update(pgsr_model_and_diffusion_defaults()) + defaults.update(dict( + num_channels2=128, + use_attention2=True, + 
model_path2="", + )) + parser = argparse.ArgumentParser() + add_dict_to_argparser(parser, defaults) + return parser + + +if __name__ == "__main__": + main() diff --git a/scripts/image_nll.py b/scripts/image_nll.py new file mode 100644 index 0000000000000000000000000000000000000000..2b72bfd3810d63270a873f7889dddfd2512387b3 --- /dev/null +++ b/scripts/image_nll.py @@ -0,0 +1,96 @@ +""" +Approximate the bits/dimension for an image model. +""" + +import argparse +import os + +import numpy as np +import torch.distributed as dist + +from pixel_guide_diffusion import dist_util, logger +from pixel_guide_diffusion.image_datasets import load_data +from pixel_guide_diffusion.script_util import ( + model_and_diffusion_defaults, + create_model_and_diffusion, + add_dict_to_argparser, + args_to_dict, +) + + +def main(): + args = create_argparser().parse_args() + + dist_util.setup_dist() + logger.configure() + + logger.log("creating model and diffusion...") + model, diffusion = create_model_and_diffusion( + **args_to_dict(args, model_and_diffusion_defaults().keys()) + ) + model.load_state_dict( + dist_util.load_state_dict(args.model_path, map_location="cpu") + ) + model.to(dist_util.dev()) + model.eval() + + logger.log("creating data loader...") + data = load_data( + data_dir=args.data_dir, + batch_size=args.batch_size, + image_size=args.image_size, + class_cond=args.class_cond, + deterministic=True, + ) + + logger.log("evaluating...") + run_bpd_evaluation(model, diffusion, data, args.num_samples, args.clip_denoised) + + +def run_bpd_evaluation(model, diffusion, data, num_samples, clip_denoised): + all_bpd = [] + all_metrics = {"vb": [], "mse": [], "xstart_mse": []} + num_complete = 0 + while num_complete < num_samples: + batch, model_kwargs = next(data) + batch = batch.to(dist_util.dev()) + model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} + minibatch_metrics = diffusion.calc_bpd_loop( + model, batch, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + + for key, term_list in all_metrics.items(): + terms = minibatch_metrics[key].mean(dim=0) / dist.get_world_size() + dist.all_reduce(terms) + term_list.append(terms.detach().cpu().numpy()) + + total_bpd = minibatch_metrics["total_bpd"] + total_bpd = total_bpd.mean() / dist.get_world_size() + dist.all_reduce(total_bpd) + all_bpd.append(total_bpd.item()) + num_complete += dist.get_world_size() * batch.shape[0] + + logger.log(f"done {num_complete} samples: bpd={np.mean(all_bpd)}") + + if dist.get_rank() == 0: + for name, terms in all_metrics.items(): + out_path = os.path.join(logger.get_dir(), f"{name}_terms.npz") + logger.log(f"saving {name} terms to {out_path}") + np.savez(out_path, np.mean(np.stack(terms), axis=0)) + + dist.barrier() + logger.log("evaluation complete") + + +def create_argparser(): + defaults = dict( + data_dir="", clip_denoised=True, num_samples=1000, batch_size=1, model_path="" + ) + defaults.update(model_and_diffusion_defaults()) + parser = argparse.ArgumentParser() + add_dict_to_argparser(parser, defaults) + return parser + + +if __name__ == "__main__": + main() diff --git a/scripts/image_sample.py b/scripts/image_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..289e06f074436bae8d9daf970315150e20f5a4d6 --- /dev/null +++ b/scripts/image_sample.py @@ -0,0 +1,106 @@ +""" +Generate a large batch of image samples from a model and save them as a large +numpy array. This can be used to produce samples for FID evaluation. 
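+Samples are gathered across all distributed ranks and written to a single
+samples_*.npz file in the logging directory.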
+""" + +import argparse +import os + +import numpy as np +import torch as th +import torch.distributed as dist + +from pixel_guide_diffusion import dist_util, logger +from pixel_guide_diffusion.script_util import ( + NUM_CLASSES, + model_and_diffusion_defaults, + create_model_and_diffusion, + add_dict_to_argparser, + args_to_dict, +) + + +def main(): + args = create_argparser().parse_args() + + dist_util.setup_dist() + logger.configure() + + logger.log("creating model and diffusion...") + model, diffusion = create_model_and_diffusion( + **args_to_dict(args, model_and_diffusion_defaults().keys()) + ) + model.load_state_dict( + dist_util.load_state_dict(args.model_path, map_location="cpu") + ) + model.to(dist_util.dev()) + model.eval() + + logger.log("sampling...") + all_images = [] + all_labels = [] + while len(all_images) * args.batch_size < args.num_samples: + model_kwargs = {} + if args.class_cond: + classes = th.randint( + low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() + ) + model_kwargs["y"] = classes + sample_fn = ( + diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop + ) + sample = sample_fn( + model, + (args.batch_size, 3, args.image_size, args.image_size), + clip_denoised=args.clip_denoised, + model_kwargs=model_kwargs, + ) + sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) + sample = sample.permute(0, 2, 3, 1) + sample = sample.contiguous() + + gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] + dist.all_gather(gathered_samples, sample) # gather not supported with NCCL + all_images.extend([sample.cpu().numpy() for sample in gathered_samples]) + if args.class_cond: + gathered_labels = [ + th.zeros_like(classes) for _ in range(dist.get_world_size()) + ] + dist.all_gather(gathered_labels, classes) + all_labels.extend([labels.cpu().numpy() for labels in gathered_labels]) + logger.log(f"created {len(all_images) * args.batch_size} samples") + + arr = np.concatenate(all_images, axis=0) + arr = arr[: args.num_samples] + if args.class_cond: + label_arr = np.concatenate(all_labels, axis=0) + label_arr = label_arr[: args.num_samples] + if dist.get_rank() == 0: + shape_str = "x".join([str(x) for x in arr.shape]) + out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") + logger.log(f"saving to {out_path}") + if args.class_cond: + np.savez(out_path, arr, label_arr) + else: + np.savez(out_path, arr) + + dist.barrier() + logger.log("sampling complete") + + +def create_argparser(): + defaults = dict( + clip_denoised=True, + num_samples=10000, + batch_size=16, + use_ddim=False, + model_path="", + ) + defaults.update(model_and_diffusion_defaults()) + parser = argparse.ArgumentParser() + add_dict_to_argparser(parser, defaults) + return parser + + +if __name__ == "__main__": + main() diff --git a/scripts/image_train.py b/scripts/image_train.py new file mode 100644 index 0000000000000000000000000000000000000000..eccdb3980699fd513bb1d01e89954fc2000d14da --- /dev/null +++ b/scripts/image_train.py @@ -0,0 +1,83 @@ +""" +Train a diffusion model on images. 
+""" + +import argparse + +from pixel_guide_diffusion import dist_util, logger +from pixel_guide_diffusion.image_datasets import load_data +from pixel_guide_diffusion.resample import create_named_schedule_sampler +from pixel_guide_diffusion.script_util import ( + model_and_diffusion_defaults, + create_model_and_diffusion, + args_to_dict, + add_dict_to_argparser, +) +from pixel_guide_diffusion.train_util import TrainLoop + + +def main(): + args = create_argparser().parse_args() + + dist_util.setup_dist() + logger.configure() + + logger.log("creating model and diffusion...") + model, diffusion = create_model_and_diffusion( + **args_to_dict(args, model_and_diffusion_defaults().keys()) + ) + model.to(dist_util.dev()) + schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) + + logger.log("creating data loader...") + data = load_data( + data_dir=args.data_dir, + batch_size=args.batch_size, + image_size=args.image_size, + class_cond=args.class_cond, + ) + + logger.log("training...") + TrainLoop( + model=model, + diffusion=diffusion, + data=data, + batch_size=args.batch_size, + microbatch=args.microbatch, + lr=args.lr, + ema_rate=args.ema_rate, + log_interval=args.log_interval, + save_interval=args.save_interval, + resume_checkpoint=args.resume_checkpoint, + use_fp16=args.use_fp16, + fp16_scale_growth=args.fp16_scale_growth, + schedule_sampler=schedule_sampler, + weight_decay=args.weight_decay, + lr_anneal_steps=args.lr_anneal_steps, + ).run_loop() + + +def create_argparser(): + defaults = dict( + data_dir="", + schedule_sampler="uniform", + lr=1e-4, + weight_decay=0.0, + lr_anneal_steps=0, + batch_size=1, + microbatch=-1, # -1 disables microbatches + ema_rate="0.9999", # comma-separated list of EMA values + log_interval=10, + save_interval=10000, + resume_checkpoint="", + use_fp16=False, + fp16_scale_growth=1e-3, + ) + defaults.update(model_and_diffusion_defaults()) + parser = argparse.ArgumentParser() + add_dict_to_argparser(parser, defaults) + return parser + + +if __name__ == "__main__": + main() diff --git a/scripts/pixel_guide_sample.py b/scripts/pixel_guide_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..dea18e9479f64014711a01b9d1a2dd58f7e5c985 --- /dev/null +++ b/scripts/pixel_guide_sample.py @@ -0,0 +1,111 @@ +""" +Generate a large batch of samples from a super resolution model, given a batch +of samples from a regular model from image_sample.py. 
+""" + +import argparse +import os + +import blobfile as bf +import numpy as np +import torch as th +import torch.distributed as dist + +from torchvision import utils +from pixel_guide_diffusion import dist_util, logger +from pixel_guide_diffusion.image_datasets import load_data +from pixel_guide_diffusion.script_util import ( + pg_model_and_diffusion_defaults, + pg_create_model_and_diffusion, + args_to_dict, + add_dict_to_argparser, +) + + +def main(): + args = create_argparser().parse_args() + + dist_util.setup_dist() + logger.configure() + + logger.log("creating model...") + model, diffusion = pg_create_model_and_diffusion( + **args_to_dict(args, pg_model_and_diffusion_defaults().keys()) + ) + model.load_state_dict( + dist_util.load_state_dict(args.model_path, map_location="cpu") + ) + model.to(dist_util.dev()) + model.eval() + + logger.log("creating data loader...") + data = load_data( + data_dir=args.data_dir, + batch_size=args.batch_size, + image_size=args.image_size, + class_cond=args.class_cond, + guide_dir=args.guide_dir, + guide_size=args.guide_size, + deterministic=True, + ) + + logger.log("creating samples...") + os.makedirs('sample', exist_ok=True) + i = 0 + while i * args.batch_size < args.num_samples: + if dist.get_rank() == 0: + target, model_kwargs = next(data) + target = target.to(dist_util.dev()) + model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} + + with th.no_grad(): + sample_fn = ( + diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop + ) + sample = sample_fn( + model, + (args.batch_size, 3, args.image_size, args.image_size), + clip_denoised=args.clip_denoised, + model_kwargs=model_kwargs, + ) + + guide = model_kwargs["guide"] + h, w = guide.shape[2:] + guide = guide.clamp(-1,1).repeat(1,3,1,1) + sample = th.nn.functional.interpolate(sample.clamp(-1,1), size=(h, w)) + target = th.nn.functional.interpolate(target.clamp(-1,1), size=(h, w)) + + images = th.cat([guide, sample, target], 0) + utils.save_image( + images, + f"sample/{str(i).zfill(6)}.png", + nrow=args.batch_size, + normalize=True, + range=(-1, 1), + ) + + i += 1 + logger.log(f"created {i * args.batch_size} samples") + + logger.log("sampling complete") + + +def create_argparser(): + defaults = dict( + data_dir="", + guide_dir="", + clip_denoised=True, + num_samples=100, + batch_size=4, + use_ddim=False, + base_samples="", + model_path="", + ) + defaults.update(pg_model_and_diffusion_defaults()) + parser = argparse.ArgumentParser() + add_dict_to_argparser(parser, defaults) + return parser + + +if __name__ == "__main__": + main() diff --git a/scripts/pixel_guide_super_res_sample.py b/scripts/pixel_guide_super_res_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..3e4d9b3fdd1e2e7f543dd0508ae9709e07c82f02 --- /dev/null +++ b/scripts/pixel_guide_super_res_sample.py @@ -0,0 +1,133 @@ +""" +Generate a large batch of samples from a super resolution model, given a batch +of samples from a regular model from image_sample.py. 
+""" + +import argparse +import os + +import blobfile as bf +import numpy as np +import torch as th +import torch.distributed as dist + +from torchvision import utils +from pixel_guide_diffusion import dist_util, logger +from pixel_guide_diffusion.image_datasets import load_data +from pixel_guide_diffusion.script_util import ( + pgsr_model_and_diffusion_defaults, + pgsr_create_model_and_diffusion, + args_to_dict, + add_dict_to_argparser, +) + + +def main(): + args = create_argparser().parse_args() + + dist_util.setup_dist() + logger.configure() + + logger.log("creating model...") + model, diffusion = pgsr_create_model_and_diffusion( + **args_to_dict(args, pgsr_model_and_diffusion_defaults().keys()) + ) + model.load_state_dict( + dist_util.load_state_dict(args.model_path, map_location="cpu") + ) + model.to(dist_util.dev()) + model.eval() + + logger.log("creating data loader...") + data = load_superres_data( + args.data_dir, + args.batch_size, + large_size=args.large_size, + small_size=args.small_size, + class_cond=args.class_cond, + guide_dir=args.guide_dir, + guide_size=args.guide_size, + crop_size=args.crop_size, + deterministic=True, + ) + + logger.log("creating samples...") + os.makedirs('sample', exist_ok=True) + i = 0 + while i * args.batch_size < args.num_samples: + if dist.get_rank() == 0: + target, model_kwargs = next(data) + target = target.to(dist_util.dev()) + model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} + model_kwargs["low_res"] = th.nn.functional.interpolate(target, args.small_size, mode="area").detach() + + with th.no_grad(): + sample_fn = ( + diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop + ) + sample = sample_fn( + model, + (args.batch_size, 3, args.crop_size, args.crop_size), + clip_denoised=args.clip_denoised, + model_kwargs=model_kwargs, + ) + + guide = model_kwargs["guide"] + low_res = model_kwargs["low_res"] + h, w = guide.shape[2:] + guide = guide.clamp(-1,1).repeat(1,3,1,1) + low_res = th.nn.functional.interpolate(low_res.clamp(-1,1), size=(h, w)) + sample = th.nn.functional.interpolate(sample.clamp(-1,1), size=(h, w)) + target = th.nn.functional.interpolate(target.clamp(-1,1), size=(h, w)) + + images = th.cat([guide, low_res, sample, target], 0) + utils.save_image( + images, + f"sample/{str(i).zfill(6)}.png", + nrow=args.batch_size, + normalize=True, + range=(-1, 1), + ) + + i += 1 + logger.log(f"created {i * args.batch_size} samples") + + logger.log("sampling complete") + + +def load_superres_data(data_dir, batch_size, large_size, small_size, class_cond=False, guide_dir='', guide_size=0, crop_size=0, deterministic=False): + data = load_data( + data_dir=data_dir, + batch_size=batch_size, + image_size=large_size, + class_cond=class_cond, + guide_dir=guide_dir, + guide_size=guide_size, + crop_size=crop_size, + deterministic=deterministic, + ) + for large_batch, model_kwargs in data: + model_kwargs["low_res"] = th.nn.functional.interpolate(large_batch, scale_factor=small_size/large_size, mode="area").detach() + yield large_batch, model_kwargs + + +def create_argparser(): + defaults = dict( + data_dir="", + guide_dir="", + crop_size=128, + clip_denoised=True, + num_samples=100, + batch_size=4, + use_ddim=False, + base_samples="", + model_path="", + ) + defaults.update(pgsr_model_and_diffusion_defaults()) + parser = argparse.ArgumentParser() + add_dict_to_argparser(parser, defaults) + return parser + + +if __name__ == "__main__": + main() diff --git a/scripts/pixel_guide_super_res_train.py 
b/scripts/pixel_guide_super_res_train.py new file mode 100644 index 0000000000000000000000000000000000000000..a095f385e9e1ab159807616769e8339ac329ec58 --- /dev/null +++ b/scripts/pixel_guide_super_res_train.py @@ -0,0 +1,108 @@ +""" +Train a super-resolution model. +""" + +import argparse + +import torch.nn.functional as F + +from pixel_guide_diffusion import dist_util, logger +from pixel_guide_diffusion.image_datasets import load_data +from pixel_guide_diffusion.resample import create_named_schedule_sampler +from pixel_guide_diffusion.script_util import ( + pgsr_model_and_diffusion_defaults, + pgsr_create_model_and_diffusion, + args_to_dict, + add_dict_to_argparser, +) +from pixel_guide_diffusion.train_util import TrainLoop + + +def main(): + args = create_argparser().parse_args() + + dist_util.setup_dist() + logger.configure() + + logger.log("creating model...") + model, diffusion = pgsr_create_model_and_diffusion( + **args_to_dict(args, pgsr_model_and_diffusion_defaults().keys()) + ) + model.to(dist_util.dev()) + schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) + + logger.log("creating data loader...") + data = load_superres_data( + args.data_dir, + args.batch_size, + large_size=args.large_size, + small_size=args.small_size, + class_cond=args.class_cond, + guide_dir=args.guide_dir, + guide_size=args.guide_size, + crop_size=args.crop_size, + deterministic=True, + ) + + logger.log("training...") + TrainLoop( + model=model, + diffusion=diffusion, + data=data, + batch_size=args.batch_size, + microbatch=args.microbatch, + lr=args.lr, + ema_rate=args.ema_rate, + log_interval=args.log_interval, + save_interval=args.save_interval, + resume_checkpoint=args.resume_checkpoint, + use_fp16=args.use_fp16, + fp16_scale_growth=args.fp16_scale_growth, + schedule_sampler=schedule_sampler, + weight_decay=args.weight_decay, + lr_anneal_steps=args.lr_anneal_steps, + ).run_loop() + + +def load_superres_data(data_dir, batch_size, large_size, small_size, class_cond=False, guide_dir='', guide_size=0, crop_size=0, deterministic=False): + data = load_data( + data_dir=data_dir, + batch_size=batch_size, + image_size=large_size, + class_cond=class_cond, + guide_dir=guide_dir, + guide_size=guide_size, + crop_size=crop_size, + deterministic=deterministic, + ) + for large_batch, model_kwargs in data: + model_kwargs["low_res"] = F.interpolate(large_batch, scale_factor=small_size/large_size, mode="area") + yield large_batch, model_kwargs + + +def create_argparser(): + defaults = dict( + data_dir="", + guide_dir="", + crop_size=32, + schedule_sampler="uniform", + lr=1e-4, + weight_decay=0.0, + lr_anneal_steps=0, + batch_size=1, + microbatch=-1, + ema_rate="0.9999", + log_interval=10, + save_interval=10000, + resume_checkpoint="", + use_fp16=False, + fp16_scale_growth=1e-3, + ) + defaults.update(pgsr_model_and_diffusion_defaults()) + parser = argparse.ArgumentParser() + add_dict_to_argparser(parser, defaults) + return parser + + +if __name__ == "__main__": + main() diff --git a/scripts/pixel_guide_train.py b/scripts/pixel_guide_train.py new file mode 100644 index 0000000000000000000000000000000000000000..179252211c16408144bf06feb438be599bd81801 --- /dev/null +++ b/scripts/pixel_guide_train.py @@ -0,0 +1,89 @@ +""" +Train a super-resolution model. 
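+(In this repository the script is used to train the sketch-guided base model
+at --image_size rather than a super-resolution stage; see train_danbooru.sh.)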
+""" + +import argparse + +import torch.nn.functional as F + +from pixel_guide_diffusion import dist_util, logger +from pixel_guide_diffusion.image_datasets import load_data +from pixel_guide_diffusion.resample import create_named_schedule_sampler +from pixel_guide_diffusion.script_util import ( + pg_model_and_diffusion_defaults, + pg_create_model_and_diffusion, + args_to_dict, + add_dict_to_argparser, +) +from pixel_guide_diffusion.train_util import TrainLoop + + +def main(): + args = create_argparser().parse_args() + + dist_util.setup_dist() + logger.configure() + + logger.log("creating model...") + model, diffusion = pg_create_model_and_diffusion( + **args_to_dict(args, pg_model_and_diffusion_defaults().keys()) + ) + model.to(dist_util.dev()) + schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) + + logger.log("creating data loader...") + data = load_data( + data_dir=args.data_dir, + batch_size=args.batch_size, + image_size=args.image_size, + class_cond=args.class_cond, + guide_dir=args.guide_dir, + guide_size=args.guide_size, + deterministic=True, + ) + + logger.log("training...") + TrainLoop( + model=model, + diffusion=diffusion, + data=data, + batch_size=args.batch_size, + microbatch=args.microbatch, + lr=args.lr, + ema_rate=args.ema_rate, + log_interval=args.log_interval, + save_interval=args.save_interval, + resume_checkpoint=args.resume_checkpoint, + use_fp16=args.use_fp16, + fp16_scale_growth=args.fp16_scale_growth, + schedule_sampler=schedule_sampler, + weight_decay=args.weight_decay, + lr_anneal_steps=args.lr_anneal_steps, + ).run_loop() + + +def create_argparser(): + defaults = dict( + data_dir="", + guide_dir="", + schedule_sampler="uniform", + lr=1e-4, + weight_decay=0.0, + lr_anneal_steps=0, + batch_size=1, + microbatch=-1, + ema_rate="0.9999", + log_interval=10, + save_interval=10000, + resume_checkpoint="", + use_fp16=False, + fp16_scale_growth=1e-3, + ) + defaults.update(pg_model_and_diffusion_defaults()) + parser = argparse.ArgumentParser() + add_dict_to_argparser(parser, defaults) + return parser + + +if __name__ == "__main__": + main() diff --git a/scripts/super_res_sample.py b/scripts/super_res_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..d7e4e3374073945a4f34d92c0caab164c45eac3a --- /dev/null +++ b/scripts/super_res_sample.py @@ -0,0 +1,117 @@ +""" +Generate a large batch of samples from a super resolution model, given a batch +of samples from a regular model from image_sample.py. 
+""" + +import argparse +import os + +import blobfile as bf +import numpy as np +import torch as th +import torch.distributed as dist + +from pixel_guide_diffusion import dist_util, logger +from pixel_guide_diffusion.script_util import ( + sr_model_and_diffusion_defaults, + sr_create_model_and_diffusion, + args_to_dict, + add_dict_to_argparser, +) + + +def main(): + args = create_argparser().parse_args() + + dist_util.setup_dist() + logger.configure() + + logger.log("creating model...") + model, diffusion = sr_create_model_and_diffusion( + **args_to_dict(args, sr_model_and_diffusion_defaults().keys()) + ) + model.load_state_dict( + dist_util.load_state_dict(args.model_path, map_location="cpu") + ) + model.to(dist_util.dev()) + model.eval() + + logger.log("loading data...") + data = load_data_for_worker(args.base_samples, args.batch_size, args.class_cond) + + logger.log("creating samples...") + all_images = [] + while len(all_images) * args.batch_size < args.num_samples: + model_kwargs = next(data) + model_kwargs = {k: v.to(dist_util.dev()) for k, v in model_kwargs.items()} + sample = diffusion.p_sample_loop( + model, + (args.batch_size, 3, args.large_size, args.large_size), + clip_denoised=args.clip_denoised, + model_kwargs=model_kwargs, + ) + sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) + sample = sample.permute(0, 2, 3, 1) + sample = sample.contiguous() + + all_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] + dist.all_gather(all_samples, sample) # gather not supported with NCCL + for sample in all_samples: + all_images.append(sample.cpu().numpy()) + logger.log(f"created {len(all_images) * args.batch_size} samples") + + arr = np.concatenate(all_images, axis=0) + arr = arr[: args.num_samples] + if dist.get_rank() == 0: + shape_str = "x".join([str(x) for x in arr.shape]) + out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz") + logger.log(f"saving to {out_path}") + np.savez(out_path, arr) + + dist.barrier() + logger.log("sampling complete") + + +def load_data_for_worker(base_samples, batch_size, class_cond): + with bf.BlobFile(base_samples, "rb") as f: + obj = np.load(f) + image_arr = obj["arr_0"] + if class_cond: + label_arr = obj["arr_1"] + rank = dist.get_rank() + num_ranks = dist.get_world_size() + buffer = [] + label_buffer = [] + while True: + for i in range(rank, len(image_arr), num_ranks): + buffer.append(image_arr[i]) + if class_cond: + label_buffer.append(label_arr[i]) + if len(buffer) == batch_size: + batch = th.from_numpy(np.stack(buffer)).float() + batch = batch / 127.5 - 1.0 + batch = batch.permute(0, 3, 1, 2) + res = dict(low_res=batch) + if class_cond: + res["y"] = th.from_numpy(np.stack(label_buffer)) + yield res + buffer, label_buffer = [], [] + + +def create_argparser(): + defaults = dict( + clip_denoised=True, + num_samples=10000, + batch_size=16, + use_ddim=False, + base_samples="", + model_path="", + ) + defaults.update(sr_model_and_diffusion_defaults()) + parser = argparse.ArgumentParser() + add_dict_to_argparser(parser, defaults) + return parser + + +if __name__ == "__main__": + main() diff --git a/scripts/super_res_train.py b/scripts/super_res_train.py new file mode 100644 index 0000000000000000000000000000000000000000..251c40e8df2cef848b8ac43386e8bc1c5ac49d54 --- /dev/null +++ b/scripts/super_res_train.py @@ -0,0 +1,98 @@ +""" +Train a super-resolution model. 
+""" + +import argparse + +import torch.nn.functional as F + +from pixel_guide_diffusion import dist_util, logger +from pixel_guide_diffusion.image_datasets import load_data +from pixel_guide_diffusion.resample import create_named_schedule_sampler +from pixel_guide_diffusion.script_util import ( + sr_model_and_diffusion_defaults, + sr_create_model_and_diffusion, + args_to_dict, + add_dict_to_argparser, +) +from pixel_guide_diffusion.train_util import TrainLoop + + +def main(): + args = create_argparser().parse_args() + + dist_util.setup_dist() + logger.configure() + + logger.log("creating model...") + model, diffusion = sr_create_model_and_diffusion( + **args_to_dict(args, sr_model_and_diffusion_defaults().keys()) + ) + model.to(dist_util.dev()) + schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) + + logger.log("creating data loader...") + data = load_superres_data( + args.data_dir, + args.batch_size, + large_size=args.large_size, + small_size=args.small_size, + class_cond=args.class_cond, + ) + + logger.log("training...") + TrainLoop( + model=model, + diffusion=diffusion, + data=data, + batch_size=args.batch_size, + microbatch=args.microbatch, + lr=args.lr, + ema_rate=args.ema_rate, + log_interval=args.log_interval, + save_interval=args.save_interval, + resume_checkpoint=args.resume_checkpoint, + use_fp16=args.use_fp16, + fp16_scale_growth=args.fp16_scale_growth, + schedule_sampler=schedule_sampler, + weight_decay=args.weight_decay, + lr_anneal_steps=args.lr_anneal_steps, + ).run_loop() + + +def load_superres_data(data_dir, batch_size, large_size, small_size, class_cond=False): + data = load_data( + data_dir=data_dir, + batch_size=batch_size, + image_size=large_size, + class_cond=class_cond, + ) + for large_batch, model_kwargs in data: + model_kwargs["low_res"] = F.interpolate(large_batch, small_size, mode="area") + yield large_batch, model_kwargs + + +def create_argparser(): + defaults = dict( + data_dir="", + schedule_sampler="uniform", + lr=1e-4, + weight_decay=0.0, + lr_anneal_steps=0, + batch_size=1, + microbatch=-1, + ema_rate="0.9999", + log_interval=10, + save_interval=10000, + resume_checkpoint="", + use_fp16=False, + fp16_scale_growth=1e-3, + ) + defaults.update(sr_model_and_diffusion_defaults()) + parser = argparse.ArgumentParser() + add_dict_to_argparser(parser, defaults) + return parser + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..03d79fa50cd2da2c94b45f351b9a966d321dad76 --- /dev/null +++ b/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup + +setup( + name="pixel-guide-diffusion", + py_modules=["pixel_guide_diffusion"], + install_requires=["blobfile>=1.0.5", "torch", "tqdm"], +) diff --git a/test_danbooru.sh b/test_danbooru.sh new file mode 100644 index 0000000000000000000000000000000000000000..0512be74d28ea244b3d57559b29c9d43baf22382 --- /dev/null +++ b/test_danbooru.sh @@ -0,0 +1,6 @@ + +MODEL_FLAGS="--image_size 32 --guide_size 128 --num_channels 128 --num_res_blocks 3 --learn_sigma True --dropout 0.0" +DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine" +TEST_FLAGS="--batch_size 4" + +OPENAI_LOGDIR="./danbooru2017_guided_test_log" python scripts/pixel_guide_sample.py --data_dir data/danbooru2017/anime --guide_dir data/danbooru2017/anime_sketch --timestep_respacing ddim25 --use_ddim True --model_path danbooru2017_guided_log/ema_0.9999_360000.pt $MODEL_FLAGS $DIFFUSION_FLAGS $TEST_FLAGS diff --git a/test_danbooru_cascade.sh 
b/test_danbooru_cascade.sh new file mode 100644 index 0000000000000000000000000000000000000000..39ec5efd62103685f5df061cbcc92d47f6d431a8 --- /dev/null +++ b/test_danbooru_cascade.sh @@ -0,0 +1,6 @@ + +MODEL_FLAGS="--image_size 32 --small_size 32 --large_size 128 --guide_size 128 --num_channels 128 --num_channels2 64 --num_res_blocks 3 --learn_sigma True --dropout 0.0 --use_attention2 False" +DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine" +TEST_FLAGS="--batch_size 4 --seed 233" + +OPENAI_LOGDIR="./danbooru2017_guided_cascaded_test_log" python scripts/cascaded_pixel_guide_sample.py --data_dir data/danbooru2017/anime --guide_dir data/danbooru2017/anime_sketch --timestep_respacing ddim25 --use_ddim True --model_path danbooru2017_guided_log/ema_0.9999_360000.pt --model_path2 danbooru2017_guided_sr_log/ema_0.9999_360000.pt $MODEL_FLAGS $DIFFUSION_FLAGS $TEST_FLAGS diff --git a/test_danbooru_sr.sh b/test_danbooru_sr.sh new file mode 100644 index 0000000000000000000000000000000000000000..145e3c0f2d003e278205d76916a0cdb4473b6221 --- /dev/null +++ b/test_danbooru_sr.sh @@ -0,0 +1,6 @@ + +MODEL_FLAGS="--large_size 128 --small_size 32 --guide_size 128 --num_channels 64 --num_res_blocks 3 --use_attention False --learn_sigma True --dropout 0.0" +DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine" +TEST_FLAGS="--crop_size 128 --batch_size 4" + +OPENAI_LOGDIR="./danbooru2017_guided_sr_test_log" python scripts/pixel_guide_super_res_sample.py --data_dir data/danbooru2017/anime --guide_dir data/danbooru2017/anime_sketch --timestep_respacing ddim25 --use_ddim True --model_path danbooru2017_guided_sr_log/ema_0.9999_360000.pt $MODEL_FLAGS $DIFFUSION_FLAGS $TEST_FLAGS diff --git a/train_danbooru.sh b/train_danbooru.sh new file mode 100644 index 0000000000000000000000000000000000000000..10560117162d7604d296356f5feba0ccabcd2f77 --- /dev/null +++ b/train_danbooru.sh @@ -0,0 +1,6 @@ + +MODEL_FLAGS="--image_size 32 --guide_size 128 --num_channels 128 --num_res_blocks 3 --learn_sigma True --dropout 0.0" +DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine" +TRAIN_FLAGS="--use_fp16 True --lr 1e-4 --batch_size 128 --schedule_sampler loss-second-moment" + +OPENAI_LOGDIR="./danbooru2017_guided_log" python scripts/pixel_guide_train.py --data_dir data/danbooru2017/anime --guide_dir data/danbooru2017/anime_sketch $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS diff --git a/train_danbooru_sr.sh b/train_danbooru_sr.sh new file mode 100644 index 0000000000000000000000000000000000000000..c6be273321500edf2ca352a5a5c06ac215c8f8c5 --- /dev/null +++ b/train_danbooru_sr.sh @@ -0,0 +1,6 @@ + +MODEL_FLAGS="--large_size 128 --small_size 32 --guide_size 128 --num_channels 64 --num_res_blocks 3 --use_attention False --learn_sigma True --dropout 0.0" +DIFFUSION_FLAGS="--diffusion_steps 4000 --noise_schedule cosine" +TRAIN_FLAGS="--crop_size 32 --use_fp16 True --lr 1e-4 --batch_size 128 --schedule_sampler loss-second-moment" + +OPENAI_LOGDIR="./danbooru2017_guided_sr_log" python scripts/pixel_guide_super_res_train.py --data_dir data/danbooru2017/anime --guide_dir data/danbooru2017/anime_sketch $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
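+# Note: --crop_size 32 presumably trains the guided SR stage on 32px crops of
+# the 128px pairs to save memory; the matching test script (test_danbooru_sr.sh)
+# samples with --crop_size 128.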