import torch
import torch.nn.functional as F

try:
    from torchvision import transforms
except ImportError:
    # torchvision is only needed by get_img(); the tensor helpers below
    # depend on torch alone and stay importable without it.
    transforms = None


def calc_mean_std(feat, eps=1e-5):
    """Return the mean and standard deviation of ``feat`` per (dim0, dim1) slice.

    The statistics are reduced over every dimension after the first two, so
    the function works for (N, C, L), (N, C, H, W), (N, C, D, H, W), ...
    inputs alike.

    Args:
        feat: Tensor with at least two dimensions; the first two are kept
            (conventionally batch and channel -- confirm against callers).
        eps: Small value added to the variance to avoid divide-by-zero.

    Returns:
        Tuple ``(mean, std)``. Each has the leading two sizes of ``feat``
        followed by one singleton dimension per reduced dimension, so both
        broadcast against ``feat``.
    """
    size = feat.size()
    n, c = size[:2]
    # Rank-2 input has no trailing dims; keep the historical (N, C, 1, 1)
    # result shape for it.
    trailing = len(size) - 2 if len(size) > 2 else 2
    stat_shape = (n, c) + (1,) * trailing

    # reshape (not view) so non-contiguous inputs are accepted as well.
    flat = feat.reshape(n, c, -1)
    feat_std = (flat.var(dim=2) + eps).sqrt().view(stat_shape)
    feat_mean = flat.mean(dim=2).view(stat_shape)
    return feat_mean, feat_std


def get_img(img, resolution=512):
    """Turn an image into a normalized tensor with a leading batch dimension.

    The image is resized to ``resolution`` x ``resolution``, converted with
    ``transforms.ToTensor()`` and normalized per channel as
    ``(x - 0.5) / 0.5`` (three channels are assumed by the constants).

    Args:
        img: Input image accepted by the torchvision transforms used here
            (presumably a PIL image -- confirm against callers).
        resolution: Side length of the square output.

    Returns:
        The transformed tensor with an extra dimension inserted at index 0.

    Raises:
        ImportError: If torchvision is not installed.
    """
    if transforms is None:
        raise ImportError("get_img requires torchvision to be installed")

    norm_mean = [0.5, 0.5, 0.5]
    norm_std = [0.5, 0.5, 0.5]
    transform = transforms.Compose([
        transforms.Resize((resolution, resolution)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    return transform(img).unsqueeze(0)


@torch.no_grad()
def slerp(p0, p1, fract_mixing: float, adain=True):
    r"""Mix two tensors using spherical linear interpolation (slerp).

    Copied from lunarring/latentblending.
    The computation is always carried out in float64 for the sake of extra
    precision; the result is cast back to the floating-point dtype of ``p0``
    (float32 if ``p0`` is not a floating-point tensor).

    Args:
        p0: First tensor for interpolation.
        p1: Second tensor for interpolation (same shape as ``p0``).
        fract_mixing: Mixing coefficient in the interval [0, 1].
            0 returns p0, 1 returns p1, and values in between return a mix
            of both that preserves angular velocity.
        adain: If True, the interpolated tensor is instance-normalized and
            then re-scaled/shifted with the linear blend of the statistics
            that ``calc_mean_std`` reports for ``p0`` and ``p1``.

    Returns:
        The interpolated tensor, with the same shape as the inputs.

    Note:
        The angle is measured between the inputs treated as single flat
        vectors. A zero-norm input makes the cosine undefined and yields NaN.
    """
    out_dtype = p0.dtype if p0.is_floating_point() else torch.float32
    p0 = p0.double()
    p1 = p1.double()

    if adain:
        mean1, std1 = calc_mean_std(p0)
        mean2, std2 = calc_mean_std(p1)
        mean = mean1 * (1 - fract_mixing) + mean2 * fract_mixing
        std = std1 * (1 - fract_mixing) + std2 * fract_mixing

    norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
    epsilon = 1e-7
    dot = torch.sum(p0 * p1) / norm
    # Keep the cosine strictly inside (-1, 1) so sin(theta_0) below is
    # never exactly zero for (anti-)parallel inputs.
    dot = dot.clamp(-1 + epsilon, 1 - epsilon)

    theta_0 = torch.arccos(dot)
    sin_theta_0 = torch.sin(theta_0)
    theta_t = theta_0 * fract_mixing
    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    s1 = torch.sin(theta_t) / sin_theta_0
    interp = p0 * s0 + p1 * s1

    if adain:
        interp = F.instance_norm(interp) * std + mean

    return interp.to(out_dtype)


def do_replace_attn(key: str) -> bool:
    """Return True when ``key`` starts with ``'up'``.

    Presumably ``key`` is a module/attention-layer name and the result
    selects the layers to patch -- confirm against callers.
    """
    # A narrower filter that only matches 'up_blocks.2' / 'up_blocks.3'
    # is a known alternative to the prefix test below.
    return key.startswith('up')