import argparse, os, sys, glob
import torch
import PIL
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def torch_gc():
    # release cached GPU memory between runs
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    model.cuda()
    model.half()
    model.eval()
    return model


def load_img(image, W, H):
    # resize a PIL image and normalize it to a (1, C, H, W) tensor in [-1, 1]
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    image = image.resize((int(W), int(H)), resample=PIL.Image.LANCZOS)
    print(f"resized input image to size ({W}, {H})")
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2. * image - 1.


class AppModel:
    def __init__(self):
        # load config and checkpoint, then build a PLMS sampler for txt2img
        # and a DDIM sampler for img2img
        self.config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
        self.model = load_model_from_config(self.config, "models/ldm/stable-diffusion-v1/model.ckpt")
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.device = device
        self.model = self.model.to(device)
        self.sampler = PLMSSampler(self.model)
        self.img_sampler = DDIMSampler(self.model)
        self.C = 4  # latent channels
        self.f = 8  # downsampling factor

    def run_with_prompt(self, seed, prompt, n_samples, W, H, scale, ddim_steps, strength=0., init_img=None):
        torch_gc()
        seed_everything(seed)
        ddim_eta = 0.0
        assert prompt is not None
        print(f"Prompt: {prompt}")
        batch_size = n_samples
        data = [batch_size * [prompt]]
        start_code = None
        n_rows = int(n_samples ** 0.5)
        precision_scope = autocast

        if init_img is None:
            # text-to-image: sample latents from noise with the PLMS sampler
            with torch.no_grad():
                with precision_scope(device_type='cuda', dtype=torch.float16):
                    with self.model.ema_scope():
                        all_samples = list()
                        for prompts in tqdm(data, desc="data"):
                            torch_gc()
                            uc = None
                            if scale != 1.0:
                                uc = self.model.get_learned_conditioning(batch_size * [""])
                            if isinstance(prompts, tuple):
                                prompts = list(prompts)
                            c = self.model.get_learned_conditioning(prompts)
                            shape = [self.C, H // self.f, W // self.f]
                            samples_ddim, _ = self.sampler.sample(S=ddim_steps,
                                                                  conditioning=c,
                                                                  batch_size=n_samples,
                                                                  shape=shape,
                                                                  verbose=False,
                                                                  unconditional_guidance_scale=scale,
                                                                  unconditional_conditioning=uc,
                                                                  eta=ddim_eta,
                                                                  x_T=start_code)
                            x_samples_ddim = self.model.decode_first_stage(samples_ddim)
                            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                            for x_sample in x_samples_ddim:
                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                image = Image.fromarray(x_sample.astype(np.uint8))
                                all_samples.append(image)
                            # additionally, grid image
                            grid = torch.stack([x_samples_ddim], 0)
                            grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                            grid = make_grid(grid, nrow=n_rows)
                            # to image
                            grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                            grid = grid.astype(np.uint8)
            torch_gc()
            return grid, all_samples
        else:
            # image-to-image: encode the init image, noise it to t_enc,
            # then denoise with the DDIM sampler
            init_image = load_img(init_img, W, H).to(self.device)
            init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
            torch_gc()
            with precision_scope(device_type='cuda', dtype=torch.float16):
                init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image))  # move to latent space
            torch_gc()
            sampler = self.img_sampler
            sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)

            assert 0. <= strength < 1., 'can only work with strength in [0.0, 1.0)'
            t_enc = int(strength * ddim_steps)
            print(f"target t_enc is {t_enc} steps")

            with torch.no_grad():
                with precision_scope(device_type='cuda', dtype=torch.float16):
                    with self.model.ema_scope():
                        all_samples = list()
                        for prompts in tqdm(data, desc="data"):
                            uc = None
                            if scale != 1.0:
                                uc = self.model.get_learned_conditioning(batch_size * [""])
                            if isinstance(prompts, tuple):
                                prompts = list(prompts)
                            c = self.model.get_learned_conditioning(prompts)
                            # encode (scaled latent)
                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(self.device))
                            # decode it
                            samples = sampler.decode(z_enc, c, t_enc,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=uc,)
                            x_samples = self.model.decode_first_stage(samples)
                            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                            for x_sample in x_samples:
                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                image = Image.fromarray(x_sample.astype(np.uint8))
                                all_samples.append(image)
                            # additionally, save as grid
                            grid = torch.stack([x_samples], 0)
                            grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                            grid = make_grid(grid, nrow=n_rows)
                            # to image
                            grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                            grid = grid.astype(np.uint8)
            torch_gc()
            return grid, all_samples
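

# Minimal usage sketch for the __main__ entry point below: the seed, prompt,
# image size, guidance scale, and output filename are illustrative assumptions,
# not values taken from this script or its config.
def main():
    app = AppModel()
    grid, images = app.run_with_prompt(
        seed=42,                                                # assumed seed
        prompt="a photograph of an astronaut riding a horse",   # assumed prompt
        n_samples=1,
        W=512,
        H=512,
        scale=7.5,                                              # assumed classifier-free guidance scale
        ddim_steps=50,
    )
    # `images` is a list of PIL images, `grid` a uint8 HWC array of the grid
    images[0].save("sample.png")


if __name__ == "__main__":
    main()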