# Introduction

This notebook trains MinImagen, a smaller adaptation of the original Imagen text-to-image diffusion architecture introduced by Google.

# Setup

In [None]:
#install the minimagen package
!pip install minimagen

Collecting minimagen
  Downloading minimagen-0.0.9-py3-none-any.whl (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.0/43.0 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiohttp==3.8.1 (from minimagen)
  Downloading aiohttp-3.8.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m21.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiosignal==1.2.0 (from minimagen)
  Downloading aiosignal-1.2.0-py3-none-any.whl (8.2 kB)
Collecting attrs==21.4.0 (from minimagen)
  Downloading attrs-21.4.0-py2.py3-none-any.whl (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting certifi==2022.6.15 (from minimagen)
  Downloading certifi-2022.6.15-py3-none-any.whl (160 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m160

In [None]:
#utility imports
import os
from datetime import datetime

#pytorch related imports
import torch  # binds the top-level `torch` name used by torch.device / torch.arange / torch.utils.data.DataLoader
import torch.utils.data as data_utils
from torch import optim

#minimagen related imports
from minimagen.Imagen import Imagen
from minimagen.Unet import Unet, Base, Super, BaseTest, SuperTest
from minimagen.generate import load_minimagen, load_params
from minimagen.t5 import get_encoded_dim
from minimagen.training import get_minimagen_parser, ConceptualCaptions, get_minimagen_dl_opts, \
    create_directory, get_model_size, save_training_info, get_default_args, MinimagenTrain, \
    load_testing_parameters

ModuleNotFoundError: ignored

In [None]:
# Get device: Connect to GPU runtime for better performance
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG so weight initialisation and any torch-based shuffling/splitting are repeatable
SEED = 42
torch.manual_seed(SEED)

# MinImagen's command-line parser. It is not used to parse argv here: inside a notebook, sys.argv
# belongs to the Jupyter kernel, so hyperparameters are assigned explicitly instead.
parser = get_minimagen_parser()


class args_cls:
    """Plain attribute container standing in for the namespace a command-line parser would return.

    Instances start empty; attributes are filled in later by merging a hyperparameters dict into
    the instance's ``__dict__``.
    """


# Get an instance of the args_cls
args = args_cls()

In [None]:
# Create a timestamped directory (./training_YYYYmmdd_HHMMSS) for this training run
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
dir_path = "./training_" + timestamp
training_dir = create_directory(dir_path)

In [None]:
# A dictionary of hyperparameters for this run
hyperparameters = dict(
    PARAMETERS=None,
    NUM_WORKERS=0,
    BATCH_SIZE=20,
    MAX_NUM_WORDS=32,
    IMG_SIDE_LEN=128,
    EPOCHS=10,
    T5_NAME='t5_small',
    TRAIN_VALID_FRAC=0.5,
    # Point at the directory created for THIS run rather than a hard-coded path left over from an
    # earlier session; made absolute to match the form of the path it replaces.
    TRAINING_DIRECTORY=os.path.abspath(dir_path),
    TIMESTEPS=25,
    OPTIM_LR=0.0001,
    ACCUM_ITER=1,
    CHCKPT_NUM=500,
    VALID_NUM=None,
    RESTART_DIRECTORY=None,
    TESTING=False,
    timestamp=None,
)
# Copy every hyperparameter onto `args` as an attribute (args.BATCH_SIZE, args.EPOCHS, ...)
vars(args).update(hyperparameters)

# Data

In [None]:
# Load the Conceptual Captions train/validation splits.
train_dataset, valid_dataset = ConceptualCaptions(args, smalldata=False)

# Keep only the first N_SUBSET_SAMPLES examples of each split to keep this run small.
# Clamp to each split's length so Subset is never handed an out-of-range index.
N_SUBSET_SAMPLES = 1000
train_dataset = data_utils.Subset(train_dataset, torch.arange(min(N_SUBSET_SAMPLES, len(train_dataset))))
valid_dataset = data_utils.Subset(valid_dataset, torch.arange(min(N_SUBSET_SAMPLES, len(valid_dataset))))

# Create dataloaders: MinImagen's defaults for this device, overridden with the run's batch size/workers
dl_opts = {**get_minimagen_dl_opts(device), 'batch_size': args.BATCH_SIZE, 'num_workers': args.NUM_WORKERS}
train_dataloader = torch.utils.data.DataLoader(train_dataset, **dl_opts)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **dl_opts)



  0%|          | 0/2 [00:00<?, ?it/s]

# UNet

In [None]:
# Build one U-Net per preset (BaseTest, then SuperTest) from each preset's default arguments,
# and move it to `device` (the GPU when available)
unets_params = [get_default_args(preset) for preset in (BaseTest, SuperTest)]
unets = list(Unet(**params).to(device) for params in unets_params)

In [None]:
# MinImagen settings. image_sizes holds half of IMG_SIDE_LEN followed by the full IMG_SIDE_LEN,
# presumably one output side length per U-Net in `unets` (confirm against minimagen.Imagen).
imagen_params = {
    "image_sizes": (args.IMG_SIDE_LEN // 2, args.IMG_SIDE_LEN),
    "timesteps": args.TIMESTEPS,
    "cond_drop_prob": 0.15,
    "text_encoder_name": args.T5_NAME,
}

# Assemble MinImagen from the U-Nets and move it to `device`
imagen = Imagen(unets=unets, **imagen_params).to(device)

In [None]:
# Complete each config with the constructor defaults; explicitly given values take precedence
unets_params = [dict(get_default_args(Unet), **overrides) for overrides in unets_params]
imagen_params = dict(get_default_args(Imagen), **imagen_params)

# Record the run configuration together with the model's size (in MB) in the training directory
model_size_MB = get_model_size(imagen)
save_training_info(args, timestamp, unets_params, imagen_params, model_size_MB, training_dir)

# Training

In [None]:
# Adam over every parameter returned by imagen.parameters()
optimizer = optim.Adam(params=imagen.parameters(), lr=args.OPTIM_LR)

# Train the MinImagen instance.
# NOTE(review): `timeout` looks like a per-batch time limit in seconds — confirm in minimagen.training.
MinimagenTrain(
    timestamp, args, unets, imagen,
    train_dataloader, valid_dataloader,
    training_dir, optimizer,
    timeout=30,
)


-------------------- EPOCH 1 --------------------

----------Training...----------


0it [00:00, ?it/s]


----------Validation...----------



  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:05<00:22,  5.54s/it][A
 40%|████      | 2/5 [00:31<00:53, 17.84s/it][A
 60%|██████    | 3/5 [00:45<00:31, 15.65s/it][A
 80%|████████  | 4/5 [00:51<00:11, 11.95s/it][A
100%|██████████| 5/5 [01:06<00:00, 13.20s/it]
1it [01:14, 74.38s/it]

Unet 0 avg validation loss:  tensor(1.1316, device='cuda:0')
Unet 1 avg validation loss:  tensor(1.0483, device='cuda:0')


5it [02:27, 29.42s/it]



-------------------- EPOCH 2 --------------------

----------Training...----------


0it [00:00, ?it/s]


----------Validation...----------



  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:11<00:46, 11.63s/it][A
 40%|████      | 2/5 [00:26<00:40, 13.34s/it][A
 60%|██████    | 3/5 [00:34<00:21, 10.96s/it][A
 80%|████████  | 4/5 [00:41<00:09,  9.59s/it][A
100%|██████████| 5/5 [02:56<00:00, 35.30s/it]
1it [03:10, 190.22s/it]

Unet 0 avg validation loss:  tensor(1.0965, device='cuda:0')
Unet 1 avg validation loss:  tensor(1.0373, device='cuda:0')


5it [04:12, 50.50s/it]



-------------------- EPOCH 3 --------------------

----------Training...----------


0it [00:00, ?it/s]


----------Validation...----------



  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:07<00:31,  7.80s/it][A
 40%|████      | 2/5 [00:13<00:19,  6.65s/it][A
 60%|██████    | 3/5 [00:27<00:20, 10.07s/it][A
 80%|████████  | 4/5 [00:32<00:07,  7.78s/it][A
100%|██████████| 5/5 [02:54<00:00, 34.83s/it]
1it [03:54, 234.78s/it]

Unet 0 avg validation loss:  tensor(1.0735, device='cuda:0')
Unet 1 avg validation loss:  tensor(1.0289, device='cuda:0')


5it [04:21, 52.38s/it]



-------------------- EPOCH 4 --------------------

----------Training...----------


0it [00:00, ?it/s]


----------Validation...----------



  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:15<01:02, 15.63s/it][A
 40%|████      | 2/5 [00:23<00:32, 10.99s/it][A
 60%|██████    | 3/5 [00:29<00:17,  8.55s/it][A
 80%|████████  | 4/5 [02:48<01:00, 60.17s/it][A
100%|██████████| 5/5 [02:53<00:00, 34.62s/it]
1it [02:57, 177.30s/it]

Unet 0 avg validation loss:  tensor(1.0502, device='cuda:0')
Unet 1 avg validation loss:  tensor(1.0210, device='cuda:0')


5it [04:07, 49.47s/it]



-------------------- EPOCH 5 --------------------

----------Training...----------


0it [00:00, ?it/s]


----------Validation...----------



  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:07<00:28,  7.14s/it][A
 40%|████      | 2/5 [00:31<00:52, 17.42s/it][A
 60%|██████    | 3/5 [00:37<00:23, 11.92s/it][A
 80%|████████  | 4/5 [00:43<00:09,  9.72s/it][A
100%|██████████| 5/5 [00:51<00:00, 10.29s/it]
1it [01:16, 76.25s/it]

Unet 0 avg validation loss:  tensor(1.0274, device='cuda:0')
Unet 1 avg validation loss:  tensor(1.0135, device='cuda:0')


5it [02:32, 30.58s/it]



-------------------- EPOCH 6 --------------------

----------Training...----------


0it [00:00, ?it/s]


----------Validation...----------



  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:34<02:17, 34.45s/it][A
 40%|████      | 2/5 [00:45<01:01, 20.50s/it][A
 60%|██████    | 3/5 [00:51<00:28, 14.07s/it][A
 80%|████████  | 4/5 [00:56<00:10, 10.56s/it][A
100%|██████████| 5/5 [01:11<00:00, 14.37s/it]
1it [01:17, 77.59s/it]

Unet 0 avg validation loss:  tensor(1.0088, device='cuda:0')
Unet 1 avg validation loss:  tensor(1.0083, device='cuda:0')


5it [02:25, 29.04s/it]



-------------------- EPOCH 7 --------------------

----------Training...----------


0it [00:00, ?it/s]


----------Validation...----------



  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:05<00:21,  5.37s/it][A
 40%|████      | 2/5 [00:18<00:29,  9.75s/it][A
 60%|██████    | 3/5 [00:34<00:25, 12.60s/it][A
 80%|████████  | 4/5 [00:49<00:13, 13.81s/it][A
100%|██████████| 5/5 [00:53<00:00, 10.75s/it]
1it [01:00, 60.32s/it]

Unet 0 avg validation loss:  tensor(0.9863, device='cuda:0')
Unet 1 avg validation loss:  tensor(1.0049, device='cuda:0')


5it [02:07, 25.52s/it]



-------------------- EPOCH 8 --------------------

----------Training...----------


0it [00:00, ?it/s]


----------Validation...----------



  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:10<00:41, 10.40s/it][A
 40%|████      | 2/5 [00:15<00:22,  7.41s/it][A
 60%|██████    | 3/5 [00:30<00:21, 10.65s/it][A
 80%|████████  | 4/5 [00:41<00:10, 11.00s/it][A
100%|██████████| 5/5 [00:51<00:00, 10.24s/it]
1it [01:04, 64.48s/it]

Unet 0 avg validation loss:  tensor(0.9715, device='cuda:0')
Unet 1 avg validation loss:  tensor(1.0007, device='cuda:0')


5it [02:05, 25.04s/it]



-------------------- EPOCH 9 --------------------

----------Training...----------


0it [00:00, ?it/s]


----------Validation...----------



  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:10<00:41, 10.36s/it][A
 40%|████      | 2/5 [00:16<00:23,  7.89s/it][A
 60%|██████    | 3/5 [00:24<00:15,  7.73s/it][A
 80%|████████  | 4/5 [00:35<00:09,  9.07s/it][A
100%|██████████| 5/5 [02:51<00:00, 34.28s/it]
1it [03:30, 210.85s/it]

Unet 0 avg validation loss:  tensor(0.9587, device='cuda:0')
Unet 1 avg validation loss:  tensor(0.9981, device='cuda:0')


5it [04:11, 50.39s/it]



-------------------- EPOCH 10 --------------------

----------Training...----------


0it [00:00, ?it/s]


----------Validation...----------



  0%|          | 0/5 [00:00<?, ?it/s][A
 20%|██        | 1/5 [00:23<01:32, 23.20s/it][A
 40%|████      | 2/5 [00:33<00:46, 15.36s/it][A
 60%|██████    | 3/5 [00:37<00:20, 10.14s/it][A
 80%|████████  | 4/5 [00:43<00:08,  8.75s/it][A
100%|██████████| 5/5 [00:50<00:00, 10.04s/it]
1it [01:37, 97.50s/it]

Unet 0 avg validation loss:  tensor(0.9483, device='cuda:0')
Unet 1 avg validation loss:  tensor(0.9955, device='cuda:0')


5it [02:03, 24.66s/it]


# Inference

In [None]:
from argparse import ArgumentParser
from minimagen.generate import load_minimagen, sample_and_save


In [None]:
# Text prompt(s) to generate images for — append more strings to sample several captions
captions = ["happy"]

In [None]:
# Show the hyperparameters this run actually used (displaying the bare class object told us nothing)
vars(args)

__main__.args_cls

In [None]:
# Use `sample_and_save` to generate and save the images, loading the model from THIS run's
# training directory rather than a hard-coded path left over from a different session
sample_and_save(captions, training_directory=os.path.abspath(dir_path))

0it [00:00, ?it/s]
sampling loop time step:   0%|          | 0/25 [00:00<?, ?it/s][A
sampling loop time step:  12%|█▏        | 3/25 [00:00<00:01, 21.00it/s][A
sampling loop time step:  24%|██▍       | 6/25 [00:00<00:00, 20.08it/s][A
sampling loop time step:  36%|███▌      | 9/25 [00:00<00:00, 20.17it/s][A
sampling loop time step:  48%|████▊     | 12/25 [00:00<00:00, 20.03it/s][A
sampling loop time step:  60%|██████    | 15/25 [00:00<00:00, 20.02it/s][A
sampling loop time step:  72%|███████▏  | 18/25 [00:00<00:00, 19.98it/s][A
sampling loop time step:  80%|████████  | 20/25 [00:01<00:00, 19.70it/s][A
sampling loop time step:  88%|████████▊ | 22/25 [00:01<00:00, 19.56it/s][A
sampling loop time step: 100%|██████████| 25/25 [00:01<00:00, 19.65it/s]
1it [00:01,  1.28s/it]
sampling loop time step:   0%|          | 0/25 [00:00<?, ?it/s][A
sampling loop time step:   8%|▊         | 2/25 [00:00<00:01, 11.65it/s][A
sampling loop time step:  16%|█▌        | 4/25 [00:00<00:01, 11.37it/s]