import os

import numpy as np
import torch
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid
from IPython import display

#
# Callback functions
#
class SamplerCallback(object):
    # Creates the callback function to be passed into the samplers for each step
    # The callback function is applied to the image at each step
    def __init__(self, args, root, mask=None, init_latent=None, sigmas=None, sampler=None, verbose=False):
        self.model = root.model
        self.device = root.device
        self.sampler_name = args.sampler
        self.dynamic_threshold = args.dynamic_threshold
        self.static_threshold = args.static_threshold
        self.mask = mask
        self.init_latent = init_latent
        self.sigmas = sigmas
        self.sampler = sampler
        self.verbose = verbose

        self.batch_size = args.n_samples
        self.save_sample_per_step = args.save_sample_per_step
        self.show_sample_per_step = args.show_sample_per_step
        self.paths_to_image_steps = [
            os.path.join(args.outdir, f"{args.timestring}_{index:02}_{args.seed}")
            for index in range(args.n_samples)
        ]

        if self.save_sample_per_step:
            for path in self.paths_to_image_steps:
                os.makedirs(path, exist_ok=True)

        self.step_index = 0

        self.noise = None
        if init_latent is not None:
            self.noise = torch.randn_like(init_latent, device=self.device)

        self.mask_schedule = None
        if sigmas is not None and len(sigmas) > 0:
            self.mask_schedule, _ = torch.sort(sigmas / torch.max(sigmas))
        elif sigmas is not None:
            self.mask = None  # no mask needed if no steps (usually happens because strength==1.0)

        if self.sampler_name in ["plms", "ddim"]:
            if mask is not None:
                assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable"
            # Callback function formatted for compvis latent diffusion samplers
            self.callback = self.img_callback_
        else:
            # Default callback function uses k-diffusion sampler variables
            self.callback = self.k_callback_

        self.verbose_print = print if verbose else lambda *args, **kwargs: None

    def display_images(self, images):
        images = images.double().cpu().add(1).div(2).clamp(0, 1)
        images = torch.tensor(np.array(images))
        grid = make_grid(images, 4).cpu()
        display.clear_output(wait=True)
        display.display(TF.to_pil_image(grid))
        return

    def view_sample_step(self, latents, path_name_modifier=''):
        if self.save_sample_per_step:
            samples = self.model.decode_first_stage(latents)
            fname = f'{path_name_modifier}_{self.step_index:05}.png'
            for i, sample in enumerate(samples):
                sample = sample.double().cpu().add(1).div(2).clamp(0, 1)
                sample = torch.tensor(np.array(sample))
                grid = make_grid(sample, 4).cpu()
                TF.to_pil_image(grid).save(os.path.join(self.paths_to_image_steps[i], fname))
        if self.show_sample_per_step:
            samples = self.model.linear_decode(latents)
            print(path_name_modifier)
            self.display_images(samples)
        return

    def dynamic_thresholding_(self, img, threshold):
        # Dynamic thresholding from the Imagen paper (May 2022): clamp each sample to the
        # `threshold`-th percentile of its absolute values, then rescale back toward [-1, 1]
        s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1, img.ndim)))
        s = np.max(np.append(s, 1.0))
        torch.clamp_(img, -1 * s, s)
        torch.FloatTensor.div_(img, s)

    # Callback for samplers in the k-diffusion repo, called thus:
    #   callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
    def k_callback_(self, args_dict):
        self.step_index = args_dict['i']
        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(args_dict['x'], self.dynamic_threshold)
        if self.static_threshold is not None:
            torch.clamp_(args_dict['x'], -1 * self.static_threshold, self.static_threshold)
        if self.mask is not None:
            # Re-noise the init latent to the current sigma and paste it back into the masked region
            init_noise = self.init_latent + self.noise * args_dict['sigma']
            is_masked = torch.logical_and(self.mask >= self.mask_schedule[args_dict['i']], self.mask != 0)
            new_img = init_noise * torch.where(is_masked, 1, 0) + args_dict['x'] * torch.where(is_masked, 0, 1)
            args_dict['x'].copy_(new_img)

        self.view_sample_step(args_dict['denoised'], "x0_pred")
        self.view_sample_step(args_dict['x'], "x")

    # Callback for CompVis samplers
    # Function that is called on the image (img) and step (i) at each step
    def img_callback_(self, img, pred_x0, i):
        self.step_index = i
        # Thresholding functions
        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(img, self.dynamic_threshold)
        if self.static_threshold is not None:
            torch.clamp_(img, -1 * self.static_threshold, self.static_threshold)
        if self.mask is not None:
            # CompVis samplers step from high to low noise, so invert the step index
            # before re-noising the init latent to the matching timestep
            i_inv = len(self.sigmas) - i - 1
            init_noise = self.sampler.stochastic_encode(
                self.init_latent,
                torch.tensor([i_inv] * self.batch_size).to(self.device),
                noise=self.noise,
            )
            is_masked = torch.logical_and(self.mask >= self.mask_schedule[i], self.mask != 0)
            new_img = init_noise * torch.where(is_masked, 1, 0) + img * torch.where(is_masked, 0, 1)
            img.copy_(new_img)

        self.view_sample_step(pred_x0, "x0_pred")
        self.view_sample_step(img, "x")
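
# Minimal usage sketch (not part of this module). It assumes a Deforum-style `args`/`root`
# pair plus the `model_wrap`, `x`, `sigmas`, `mask`, `init_latent`, and `sampler` objects
# produced by the surrounding pipeline; since none of those are defined here, the sketch is
# kept as a comment rather than executable code:
#
#   from k_diffusion import sampling
#
#   cb = SamplerCallback(args, root, mask=mask, init_latent=init_latent,
#                        sigmas=sigmas, sampler=sampler, verbose=False)
#   samples = sampling.sample_euler_ancestral(model_wrap, x, sigmas, callback=cb.callback)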