File size: 5,141 Bytes
ca389f6
cde81bb
cc017e3
d61c863
cde81bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0a51b5
cde81bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
075cc29
bd66b34
 
075cc29
 
 
 
cde81bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import sys
os.system('pip install gradio==2.3.0a0')
import gradio as gr
os.system('git clone https://github.com/openai/CLIP')
os.system('git clone https://github.com/openai/guided-diffusion')
os.system('pip install -e ./CLIP')
os.system('pip install -e ./guided-diffusion')
os.system('pip install kornia')
os.system("curl -OL 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'")
# Imports
import math
import sys
#from IPython import display
from kornia import augmentation, filters
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm
sys.path.append('./CLIP')
sys.path.append('./guided-diffusion')
import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
# Model settings
model_config = model_and_diffusion_defaults()
model_config.update({
    'attention_resolutions': '32, 16, 8',
    'class_cond': False,
    'diffusion_steps': 1000,
    'rescale_timesteps': False,
    'timestep_respacing': '500',
    'image_size': 256,
    'learn_sigma': True,
    'noise_schedule': 'linear',
    'num_channels': 256,
    'num_head_channels': 64,
    'num_res_blocks': 2,
    'resblock_updown': True,
    'use_fp16': True,
    'use_scale_shift_norm': True,
})
# Load models and define necessary functions
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
model, diffusion = create_model_and_diffusion(**model_config)
model.load_state_dict(torch.load('256x256_diffusion_uncond.pt', map_location='cpu'))
model.eval().requires_grad_(False).to(device)
if model_config['use_fp16']:
    model.convert_to_fp16()
clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
clip_size = clip_model.visual.input_resolution
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
    
def inference(text):
    prompt = text
    batch_size = 1
    clip_guidance_scale = 2750
    seed = 0
    
    if seed is not None:
        torch.manual_seed(seed)
    
    text_embed = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()
    
    translate_by = 8 / clip_size
    if translate_by:
        aug = augmentation.RandomAffine(0, (translate_by, translate_by),
                                        padding_mode='border', p=1)
    else:
        aug = nn.Identity()
    
    cur_t = diffusion.num_timesteps - 1
    
    def cond_fn(x, t, y=None):
        with torch.enable_grad():
            x_in = x.detach().requires_grad_()
            sigma = min(24, diffusion.sqrt_recipm1_alphas_cumprod[cur_t] / 4)
            kernel_size = max(math.ceil((sigma * 6 + 1) / 2) * 2 - 1, 3)
            x_blur = filters.gaussian_blur2d(x_in, (kernel_size, kernel_size), (sigma, sigma))
            clip_in = F.interpolate(aug(x_blur.add(1).div(2)), (clip_size, clip_size),
                                    mode='bilinear', align_corners=False)
            image_embed = clip_model.encode_image(normalize(clip_in)).float()
            losses = spherical_dist_loss(image_embed, text_embed)
            grad = -torch.autograd.grad(losses.sum(), x_in)[0]
            return grad * clip_guidance_scale
    
    samples = diffusion.p_sample_loop_progressive(
        model,
        (batch_size, 3, model_config['image_size'], model_config['image_size']),
        clip_denoised=True,
        model_kwargs={},
        cond_fn=cond_fn,
        progress=True,
    )
    
    for i, sample in enumerate(samples):
        cur_t -= 1
        if i % 100 == 0 or cur_t == -1:
            print()
            for j, image in enumerate(sample['pred_xstart']):
                filename = f'progress_{j:05}.png'
                TF.to_pil_image(image.add(1).div(2).clamp(0, 1)).save(filename)
                tqdm.write(f'Step {i}, output {j}:')
                #display.display(display.Image(filename))
    return 'progress_00000.png'
    
title = "CLIP Guided Diffusion"
description = "Gradio demo for CLIP Guided Diffusion. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'>By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses OpenAI's 256x256 unconditional ImageNet diffusion model (https://github.com/openai/guided-diffusion) together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images. | <a href='https://colab.research.google.com/drive/1ED6_MYVXTApBHzQObUPaaMolgf9hZOOF' target='_blank'>Colab</a></p>"

iface = gr.Interface(inference, inputs="text", outputs="image", title=title, description=description, article=article, examples=[["coral reef city by artistation artists"]],
    enable_queue=True)
iface.launch()