#!/usr/bin/env python3

"""CLIP guided sampling from k-diffusion models."""

import argparse
import math

import accelerate
import clip
from kornia import augmentation as KA
from resize_right import resize
import torch
from torch.nn import functional as F
from torchvision import transforms
from tqdm import trange, tqdm

import k_diffusion as K
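

# Spherical-distance loss between L2-normalized embeddings (proportional to the
# squared great-circle distance); minimizing it pulls the image embedding toward
# the text embedding during CLIP guidance.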
def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
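

# Wrap a denoiser so that the gradient returned by cond_fn is folded into its
# denoised prediction; scaling the gradient by sigma**2 makes this equivalent to
# shifting the model's score estimate by cond_grad.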
def make_cond_model_fn(model, cond_fn):
    def model_fn(x, sigma, **kwargs):
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            denoised = model(x, sigma, **kwargs)
            cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
            cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
        return cond_denoised
    return model_fn
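

# Clamp the denoised prediction to [-value, value] (static thresholding).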
def make_static_thresh_model_fn(model, value=1.):
    def model_fn(x, sigma, **kwargs):
        return model(x, sigma, **kwargs).clamp(-value, value)
    return model_fn


def main():
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('prompt', type=str,
                   help='the prompt to use')
    p.add_argument('--batch-size', type=int, default=16,
                   help='the batch size')
    p.add_argument('--checkpoint', type=str, required=True,
                   help='the checkpoint to use')
    p.add_argument('--clip-guidance-scale', '-cgs', type=float, default=500.,
                   help='the CLIP guidance scale')
    p.add_argument('--clip-model', type=str, default='ViT-B/16', choices=clip.available_models(),
                   help='the CLIP model to use')
    p.add_argument('--config', type=str, required=True,
                   help='the model config')
    p.add_argument('-n', type=int, default=64,
                   help='the number of images to sample')
    p.add_argument('--prefix', type=str, default='out',
                   help='the output prefix')
    p.add_argument('--steps', type=int, default=100,
                   help='the number of denoising steps')
    args = p.parse_args()

    config = K.config.load_config(open(args.config))
    model_config = config['model']
    # TODO: allow non-square input sizes
    assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
    size = model_config['input_size']

    accelerator = accelerate.Accelerator()
    device = accelerator.device
    print('Using device:', device, flush=True)

    inner_model = K.config.make_model(config).eval().requires_grad_(False).to(device)
    inner_model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')['model_ema'])
    accelerator.print('Parameters:', K.utils.n_params(inner_model))
    model = K.Denoiser(inner_model, sigma_data=model_config['sigma_data'])

    sigma_min = model_config['sigma_min']
    sigma_max = model_config['sigma_max']

    clip_model = clip.load(args.clip_model, device=device)[0].eval().requires_grad_(False)
    clip_normalize = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                                          std=(0.26862954, 0.26130258, 0.27577711))
    clip_size = (clip_model.visual.input_resolution, clip_model.visual.input_resolution)
    aug = KA.RandomAffine(0, (1/14, 1/14), p=1, padding_mode='border')

    def get_image_embed(x):
        if x.shape[2:4] != clip_size:
            x = resize(x, out_shape=clip_size, pad_mode='reflect')
        x = clip_normalize(x)
        x = clip_model.encode_image(x).float()
        return F.normalize(x)

    target_embed = F.normalize(clip_model.encode_text(clip.tokenize(args.prompt, truncate=True).to(device)).float())
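
    # CLIP guidance: map the current denoised estimate from [-1, 1] to [0, 1],
    # embed it with CLIP, and return the negative gradient of the scaled
    # spherical-distance loss to the text embedding with respect to x.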
    def cond_fn(x, t, denoised):
        image_embed = get_image_embed(aug(denoised.add(1).div(2)))
        loss = spherical_dist_loss(image_embed, target_embed).sum() * args.clip_guidance_scale
        grad = -torch.autograd.grad(loss, x)[0]
        return grad

    model_fn = make_cond_model_fn(model, cond_fn)
    model_fn = make_static_thresh_model_fn(model_fn)
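
    # Sampling itself runs under no_grad with the model in eval mode; the
    # gradients needed for guidance are re-enabled locally inside
    # make_cond_model_fn via torch.enable_grad().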
    @torch.no_grad()
    @K.utils.eval_mode(model)
    def run():
        if accelerator.is_local_main_process:
            tqdm.write('Sampling...')
        sigmas = K.sampling.get_sigmas_karras(args.steps, sigma_min, sigma_max, rho=7., device=device)

        def sample_fn(n):
            x = torch.randn([n, model_config['input_channels'], size[0], size[1]], device=device) * sigmas[0]
            x_0 = K.sampling.sample_dpmpp_2s_ancestral(model_fn, x, sigmas, eta=1., disable=not accelerator.is_local_main_process)
            return x_0

        x_0 = K.evaluation.compute_features(accelerator, sample_fn, lambda x: x, args.n, args.batch_size)
        if accelerator.is_main_process:
            for i, out in enumerate(x_0):
                filename = f'{args.prefix}_{i:05}.png'
                K.utils.to_pil_image(out).save(filename)

    try:
        run()
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()
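
# Example invocation (config and checkpoint paths below are hypothetical;
# substitute your own k-diffusion model config and checkpoint):
#   python sample_clip_guided.py "a watercolor painting of a lighthouse" \
#       --config configs/my_model_config.json --checkpoint my_checkpoint.pth \
#       --batch-size 8 -n 16 --steps 50
# The script can also be started with `accelerate launch` for multi-GPU sampling.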