import torch
import torch.nn as nn
from torch.nn import functional as F
import clip
from torchvision.transforms import Normalize
from torchvision.utils import make_grid
import numpy as np
from IPython import display
from sklearn.cluster import KMeans
import torchvision.transforms.functional as TF

###
# Loss functions
###

## CLIP -----------------------------------------

class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
        return torch.cat(cutouts)


def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


def make_clip_loss_fn(root, args):
    clip_size = root.clip_model.visual.input_resolution  # for OpenCLIP: clip_model.visual.image_size

    def parse_prompt(prompt):
        if prompt.startswith('http://') or prompt.startswith('https://'):
            vals = prompt.rsplit(':', 2)
            vals = [vals[0] + ':' + vals[1], *vals[2:]]
        else:
            vals = prompt.rsplit(':', 1)
        vals = vals + ['', '1'][len(vals):]
        return vals[0], float(vals[1])

    def parse_clip_prompts(clip_prompt):
        target_embeds, weights = [], []
        for prompt in clip_prompt:
            txt, weight = parse_prompt(prompt)
            target_embeds.append(root.clip_model.encode_text(clip.tokenize(txt).to(root.device)).float())
            weights.append(weight)
        target_embeds = torch.cat(target_embeds)
        weights = torch.tensor(weights, device=root.device)
        if weights.sum().abs() < 1e-3:
            raise RuntimeError('Clip prompt weights must not sum to 0.')
        weights /= weights.sum().abs()
        return target_embeds, weights

    normalize = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                          std=[0.26862954, 0.26130258, 0.27577711])
    make_cutouts = MakeCutouts(clip_size, args.cutn, args.cut_pow)
    target_embeds, weights = parse_clip_prompts(args.clip_prompt)

    def clip_loss_fn(x, sigma, **kwargs):
        nonlocal target_embeds, weights, make_cutouts, normalize
        clip_in = normalize(make_cutouts(x.add(1).div(2)))
        image_embeds = root.clip_model.encode_image(clip_in).float()
        dists = spherical_dist_loss(image_embeds[:, None], target_embeds[None])
        dists = dists.view([args.cutn, 1, -1])
        losses = dists.mul(weights).sum(2).mean(0)
        return losses.sum()

    return clip_loss_fn


def make_aesthetics_loss_fn(root, args):
    clip_size = root.clip_model.visual.input_resolution  # for OpenCLIP: clip_model.visual.image_size

    def aesthetics_cond_fn(x, sigma, **kwargs):
        clip_in = F.interpolate(x, (clip_size, clip_size))
        image_embeds = root.clip_model.encode_image(clip_in).float()
        losses = (10 - root.aesthetics_model(image_embeds)[0])
        return losses.sum()

    return aesthetics_cond_fn

## end CLIP -----------------------------------------


# Blue loss from @johnowhitaker's tutorial on Grokking Stable Diffusion
def blue_loss_fn(x, sigma, **kwargs):
    # How far the blue channel values are from 0.9
    error = torch.abs(x[:, -1, :, :] - 0.9).mean()
    return error


# MSE loss from init
def make_mse_loss(target):
    def mse_loss(x, sigma, **kwargs):
        return (x - target).square().mean()
    return mse_loss
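
# --- Illustrative usage (a sketch, not part of the original module) -----
# Shows how one of the loss factories above might be wired into a
# gradient-based conditioning function for a sampler. It assumes a
# `loss_fn` produced elsewhere, e.g. make_clip_loss_fn(root, args) or
# make_mse_loss(target); `make_example_cond_fn` is a hypothetical helper
# name used only for illustration.
def make_example_cond_fn(loss_fn, loss_scale=1.0):
    def example_cond_fn(x, sigma, **kwargs):
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            loss = loss_fn(x, sigma, **kwargs) * loss_scale
            # Steer the sample in the direction that decreases the loss
            grad = -torch.autograd.grad(loss, x)[0]
        return grad
    return example_cond_fn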
# Exposure loss relative to a target value
def exposure_loss(target):
    def exposure_loss_fn(x, sigma, **kwargs):
        error = torch.abs(x - target).mean()
        return error
    return exposure_loss_fn


def mean_loss_fn(x, sigma, **kwargs):
    error = torch.abs(x).mean()
    return error


def var_loss_fn(x, sigma, **kwargs):
    error = x.var()
    return error


def get_color_palette(root, n_colors, target, verbose=False):
    def display_color_palette(color_list):
        # Expand each color to a 64x64 patch of single-color pixels
        images = color_list.unsqueeze(2).repeat(1, 1, 64).unsqueeze(3).repeat(1, 1, 1, 64)
        images = images.double().cpu().add(1).div(2).clamp(0, 1)
        images = torch.tensor(np.array(images))
        grid = make_grid(images, 8).cpu()
        display.display(TF.to_pil_image(grid))
        return

    # Create the color palette by k-means clustering of the target image's pixels
    kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(torch.flatten(target[0], 1, 2).T.cpu().numpy())
    color_list = torch.Tensor(kmeans.cluster_centers_).to(root.device)
    if verbose:
        display_color_palette(color_list)
    # Get the ratio of each color class in the target image
    color_indexes, color_counts = np.unique(kmeans.labels_, return_counts=True)
    # color_list = color_list[color_indexes]
    return color_list, color_counts
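
# --- Illustrative usage (a sketch, not part of the original module) -----
# Shows one way get_color_palette might be called: load an RGB image,
# rescale it from [0, 1] to the [-1, 1] range the palette code expects,
# add a batch dimension, and extract the palette. `example_palette_from_image`
# and `image_path` are hypothetical names used only for illustration; `root`
# is assumed to provide a `.device` attribute as in the functions above.
def example_palette_from_image(root, image_path, n_colors=8):
    from PIL import Image
    image = Image.open(image_path).convert('RGB')
    target = TF.to_tensor(image).mul(2).sub(1).unsqueeze(0).to(root.device)
    return get_color_palette(root, n_colors, target)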
def make_rgb_color_match_loss(root, target, n_colors, ignore_sat_weight=None, img_shape=None, device='cuda:0'):
    """
    target (tensor): Image sample (values from -1 to 1) to extract the color palette from
    n_colors (int): Number of colors in the color palette
    ignore_sat_weight (None or number > 0): Scale to ignore color saturation in the color comparison
    img_shape (None or (int, int)): Shape (width, height) of the sample that the conditioning gradient
        is applied to; if None, the target color distribution is calculated during each gradient
        calculation rather than once at the beginning
    """
    assert n_colors > 0, "Must use at least one color with color match loss"

    def adjust_saturation(sample, saturation_factor):
        # As in torchvision.transforms.functional.adjust_saturation, but for tensors with values from -1 to 1
        return blend(sample, TF.rgb_to_grayscale(sample), saturation_factor)

    def blend(img1, img2, ratio):
        return (ratio * img1 + (1.0 - ratio) * img2).clamp(-1, 1).to(img1.dtype)

    def color_distance_distributions(n_colors, img_shape, color_list, color_counts, n_images=1):
        # Get the target color distance distributions
        # Ensure color counts total the amount of pixels in the image
        n_pixels = img_shape[0] * img_shape[1]
        color_counts = (color_counts * n_pixels / sum(color_counts)).astype(int)
        # Make color distances for each color, sorted by distance
        color_distributions = torch.zeros((n_colors, n_images, n_pixels), device=device)
        for i_image in range(n_images):
            for ic, color0 in enumerate(color_list):
                i_dist = 0
                for jc, color1 in enumerate(color_list):
                    color_dist = torch.linalg.norm(color0 - color1)
                    color_distributions[ic, i_image, i_dist:i_dist + color_counts[jc]] = color_dist
                    i_dist += color_counts[jc]
        color_distributions, _ = torch.sort(color_distributions, dim=2)
        return color_distributions

    color_list, color_counts = get_color_palette(root, n_colors, target)
    color_distributions = None
    if img_shape is not None:
        color_distributions = color_distance_distributions(n_colors, img_shape, color_list, color_counts)

    def rgb_color_ratio_loss(x, sigma, **kwargs):
        nonlocal color_distributions
        # Distance to a color won't be more than the max L1 distance (6) between -1 and 1 in 3 color dimensions
        all_color_norm_distances = torch.ones(len(color_list), x.shape[0], x.shape[2], x.shape[3]).to(device) * 6.0
        for ic, color in enumerate(color_list):
            # Make a tensor of entirely one color
            color = color[None, :, None].repeat(1, 1, x.shape[2]).unsqueeze(3).repeat(1, 1, 1, x.shape[3])
            # Get the color distances
            if ignore_sat_weight is None:
                # Simple color distance
                color_distances = torch.linalg.norm(x - color, dim=1)
            else:
                # Color distance as if the colors were saturated,
                # to make the comparison ignore shadows and highlights, for example
                color_distances = torch.linalg.norm(adjust_saturation(x, ignore_sat_weight) - color, dim=1)
            all_color_norm_distances[ic] = color_distances
        all_color_norm_distances = torch.flatten(all_color_norm_distances, start_dim=2)

        if color_distributions is None:
            color_distributions = color_distance_distributions(n_colors, (x.shape[2], x.shape[3]),
                                                               color_list, color_counts, n_images=x.shape[0])

        # Sort the color distances so we can compare them as if they were a cumulative distribution function
        all_color_norm_distances, _ = torch.sort(all_color_norm_distances, dim=2)

        color_norm_distribution_diff = all_color_norm_distances - color_distributions

        return color_norm_distribution_diff.square().mean()

    return rgb_color_ratio_loss


###
# Thresholding functions for grad
###

def threshold_by(threshold, threshold_type, clamp_schedule):

    def dynamic_thresholding(vals, sigma):
        # Dynamic thresholding from the Imagen paper (May 2022)
        s = np.percentile(np.abs(vals.cpu()), threshold, axis=tuple(range(1, vals.ndim)))
        s = np.max(np.append(s, 1.0))
        vals = torch.clamp(vals, -1 * s, s)
        vals = torch.FloatTensor.div(vals, s)
        return vals

    def static_thresholding(vals, sigma):
        vals = torch.clamp(vals, -1 * threshold, threshold)
        return vals

    def mean_thresholding(vals, sigma):
        # Thresholding as it appears in the JAX and Disco Diffusion notebooks
        magnitude = vals.square().mean(axis=(1, 2, 3), keepdims=True).sqrt()
        vals = vals * torch.where(magnitude > threshold, threshold / magnitude, 1.0)
        return vals

    def scheduling(vals, sigma):
        clamp_val = clamp_schedule[sigma.item()]
        magnitude = vals.square().mean().sqrt()
        vals = vals * magnitude.clamp(max=clamp_val) / magnitude
        # print(clamp_val)
        return vals

    if threshold_type == 'dynamic':
        return dynamic_thresholding
    elif threshold_type == 'static':
        return static_thresholding
    elif threshold_type == 'mean':
        return mean_thresholding
    elif threshold_type == 'schedule':
        return scheduling
    else:
        raise Exception(f"Thresholding type {threshold_type} not supported")
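
# --- Illustrative usage (a sketch, not part of the original module) -----
# A minimal, self-contained example of threshold_by in 'mean' mode: a random
# "gradient" is rescaled wherever its RMS magnitude exceeds the threshold.
# The tensor shape and threshold value are illustrative only.
def example_threshold_usage():
    grad = torch.randn(1, 3, 64, 64) * 5.0
    clamp_fn = threshold_by(threshold=0.05, threshold_type='mean', clamp_schedule=None)
    # sigma is unused by the 'mean' mode but kept for interface consistency
    return clamp_fn(grad, sigma=torch.tensor(1.0))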