import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffab.modules.common.layers import clampped_one_hot
from diffab.modules.common.so3 import ApproxAngularDistribution, random_normal_so3, so3vec_to_rotation, rotation_to_so3vec


class VarianceSchedule(nn.Module):
    """Cosine-shaped diffusion schedule stored as module buffers.

    With ``f(t) = cos^2((pi / 2) * (t / T + s) / (1 + s))`` and ``T = num_steps``,
    the following buffers are registered, each of length ``num_steps + 1`` and
    indexed by the diffusion step ``t``:

    * ``alpha_bars[t] = f(t) / f(0)``
    * ``betas[t] = 1 - alpha_bars[t] / alpha_bars[t-1]`` clamped to at most
      0.999, with ``betas[0] = 0``
    * ``alphas[t] = 1 - betas[t]``
    * ``sigmas[t] = sqrt((1 - alpha_bars[t-1]) / (1 - alpha_bars[t]) * betas[t])``,
      with ``sigmas[0] = 0``

    Args:
        num_steps: Number of diffusion steps ``T``.
        s: Offset used inside the cosine.
    """

    def __init__(self, num_steps=100, s=0.01):
        super().__init__()
        T = num_steps
        t = torch.arange(0, num_steps + 1, dtype=torch.float)
        f_t = torch.cos((np.pi / 2) * ((t / T) + s) / (1 + s)) ** 2
        alpha_bars = f_t / f_t[0]

        betas = 1 - (alpha_bars[1:] / alpha_bars[:-1])
        betas = torch.cat([torch.zeros([1]), betas], dim=0)
        betas = betas.clamp_max(0.999)

        # Entry 0 stays zero; entries 1..T are filled in one vectorized step.
        sigmas = torch.zeros_like(betas)
        sigmas[1:] = ((1 - alpha_bars[:-1]) / (1 - alpha_bars[1:])) * betas[1:]
        sigmas = torch.sqrt(sigmas)

        self.register_buffer('betas', betas)
        self.register_buffer('alpha_bars', alpha_bars)
        self.register_buffer('alphas', 1 - betas)
        self.register_buffer('sigmas', sigmas)


class PositionTransition(nn.Module):
    """Gaussian noising / denoising of 3D positions driven by a VarianceSchedule.

    Args:
        num_steps: Number of diffusion steps.
        var_sched_opt: Optional keyword arguments forwarded to
            ``VarianceSchedule``; ``None`` means no extra options.
    """

    def __init__(self, num_steps, var_sched_opt=None):
        super().__init__()
        self.var_sched = VarianceSchedule(num_steps, **(var_sched_opt or {}))

    def add_noise(self, p_0, mask_generate, t):
        """Perturb positions to step ``t``.

        Args:
            p_0: (N, L, 3).
            mask_generate: (N, L). Used as the ``torch.where`` condition;
                entries where it is false keep their value from ``p_0``.
            t: (N,).
        Returns:
            p_noisy: Perturbed positions, same shape as ``p_0``.
            e_rand: The standard-normal noise that was drawn, same shape as ``p_0``.
        """
        alpha_bar = self.var_sched.alpha_bars[t]

        c0 = torch.sqrt(alpha_bar).view(-1, 1, 1)
        c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1)

        e_rand = torch.randn_like(p_0)
        p_noisy = c0 * p_0 + c1 * e_rand
        p_noisy = torch.where(mask_generate[..., None].expand_as(p_0), p_noisy, p_0)

        return p_noisy, e_rand

    def denoise(self, p_t, eps_p, mask_generate, t):
        """Take one reverse step from ``p_t`` given the noise estimate ``eps_p``.

        Args:
            p_t: Positions at step ``t``; broadcast against (N, 1, 1) coefficients.
            eps_p: Noise estimate, combined element-wise with ``p_t``.
            mask_generate: (N, L). Entries where it is false keep their value
                from ``p_t``.
            t: (N,).
        Returns:
            p_next: Positions after the reverse step, same shape as ``p_t``.
        """
        # IMPORTANT:
        #   Clamping alpha fixes the instability issue at the first step (t=T).
        #   It seems like a problem with the ``improved ddpm''.
        alpha = self.var_sched.alphas[t].clamp_min(
            self.var_sched.alphas[-2]
        )
        alpha_bar = self.var_sched.alpha_bars[t]
        sigma = self.var_sched.sigmas[t].view(-1, 1, 1)

        c0 = (1.0 / torch.sqrt(alpha + 1e-8)).view(-1, 1, 1)
        c1 = ((1 - alpha) / torch.sqrt(1 - alpha_bar + 1e-8)).view(-1, 1, 1)

        # Fresh noise is injected only while t > 1; otherwise z is zero.
        z = torch.where(
            (t > 1)[:, None, None].expand_as(p_t),
            torch.randn_like(p_t),
            torch.zeros_like(p_t),
        )

        p_next = c0 * (p_t - c1 * eps_p) + sigma * z
        p_next = torch.where(mask_generate[..., None].expand_as(p_t), p_next, p_t)
        return p_next


class RotationTransition(nn.Module):
    """Noising / denoising of rotations expressed as so(3) vectors.

    Two ``ApproxAngularDistribution`` objects are built: one from
    ``sqrt(1 - alpha_bars)`` for the forward (perturb) direction and one from
    ``sigmas`` for the inverse (generate) direction.

    Args:
        num_steps: Number of diffusion steps.
        var_sched_opt: Optional keyword arguments for ``VarianceSchedule``.
        angular_distrib_fwd_opt: Optional keyword arguments for the forward
            ``ApproxAngularDistribution``.
        angular_distrib_inv_opt: Optional keyword arguments for the inverse
            ``ApproxAngularDistribution``.
    """

    def __init__(self, num_steps, var_sched_opt=None, angular_distrib_fwd_opt=None, angular_distrib_inv_opt=None):
        super().__init__()
        self.var_sched = VarianceSchedule(num_steps, **(var_sched_opt or {}))

        # Forward (perturb)
        c1 = torch.sqrt(1 - self.var_sched.alpha_bars)  # (num_steps + 1,).
        self.angular_distrib_fwd = ApproxAngularDistribution(c1.tolist(), **(angular_distrib_fwd_opt or {}))

        # Inverse (generate)
        sigma = self.var_sched.sigmas
        self.angular_distrib_inv = ApproxAngularDistribution(sigma.tolist(), **(angular_distrib_inv_opt or {}))

        # Empty buffer whose only use is to report the module's current device.
        self.register_buffer('_dummy', torch.empty([0, ]))

    def add_noise(self, v_0, mask_generate, t):
        """Perturb rotations to step ``t``.

        Args:
            v_0: (N, L, 3).
            mask_generate: (N, L). Entries where it is false keep their value
                from ``v_0``.
            t: (N,).
        Returns:
            v_noisy: Perturbed so(3) vectors, same shape as ``v_0``.
            e_scaled: The noise so(3) vectors that were drawn, (N, L, 3).
        """
        N, L = mask_generate.size()
        alpha_bar = self.var_sched.alpha_bars[t]
        c0 = torch.sqrt(alpha_bar).view(-1, 1, 1)

        # Noise rotation
        e_scaled = random_normal_so3(t[:, None].expand(N, L), self.angular_distrib_fwd, device=self._dummy.device)  # (N, L, 3)
        E_scaled = so3vec_to_rotation(e_scaled)  # (N, L, 3, 3)

        # Scaled true rotation
        R0_scaled = so3vec_to_rotation(c0 * v_0)  # (N, L, 3, 3)

        R_noisy = E_scaled @ R0_scaled
        v_noisy = rotation_to_so3vec(R_noisy)
        v_noisy = torch.where(mask_generate[..., None].expand_as(v_0), v_noisy, v_0)

        return v_noisy, e_scaled

    def denoise(self, v_t, v_next, mask_generate, t):
        """Compose the proposed next rotation with sampled noise.

        Args:
            v_t: so(3) vectors at step ``t``; returned unchanged where
                ``mask_generate`` is false.
            v_next: Proposed so(3) vectors for the next step.
            mask_generate: (N, L).
            t: (N,).
        Returns:
            v_next: so(3) vectors after the reverse step.
        """
        N, L = mask_generate.size()
        e = random_normal_so3(t[:, None].expand(N, L), self.angular_distrib_inv, device=self._dummy.device)  # (N, L, 3)
        e = torch.where(
            (t > 1)[:, None, None].expand(N, L, 3),
            e,
            torch.zeros_like(e)  # Simply denoise and don't add noise at the last step
        )
        E = so3vec_to_rotation(e)

        R_next = E @ so3vec_to_rotation(v_next)
        v_next = rotation_to_so3vec(R_next)
        v_next = torch.where(mask_generate[..., None].expand_as(v_next), v_next, v_t)

        return v_next


class AminoacidCategoricalTransition(nn.Module):
    """Categorical diffusion over ``num_classes`` categories with uniform noise.

    Args:
        num_steps: Number of diffusion steps.
        num_classes: Number of categories ``K``.
        var_sched_opt: Optional keyword arguments for ``VarianceSchedule``.
    """

    def __init__(self, num_steps, num_classes=20, var_sched_opt=None):
        super().__init__()
        self.num_classes = num_classes
        self.var_sched = VarianceSchedule(num_steps, **(var_sched_opt or {}))

    @staticmethod
    def _sample(c):
        """Draw one category per position from (unnormalized) probabilities.

        Args:
            c: (N, L, K).
        Returns:
            x: (N, L).
        """
        N, L, K = c.size()
        # reshape (rather than view) also accepts non-contiguous inputs.
        # The small constant keeps every weight strictly positive.
        c = c.reshape(N * L, K) + 1e-8
        x = torch.multinomial(c, 1).view(N, L)
        return x

    def add_noise(self, x_0, mask_generate, t):
        """Perturb categories to step ``t``.

        Args:
            x_0: (N, L)
            mask_generate: (N, L). Entries where it is false keep the one-hot
                distribution of ``x_0``.
            t: (N,).
        Returns:
            c_t: Probability, (N, L, K).
            x_t: Sample, LongTensor, (N, L).
        """
        N, L = x_0.size()
        K = self.num_classes
        c_0 = clampped_one_hot(x_0, num_classes=K).float()  # (N, L, K).
        alpha_bar = self.var_sched.alpha_bars[t][:, None, None]  # (N, 1, 1)
        c_noisy = (alpha_bar * c_0) + ((1 - alpha_bar) / K)
        c_t = torch.where(mask_generate[..., None].expand(N, L, K), c_noisy, c_0)
        x_t = self._sample(c_t)
        return c_t, x_t

    def posterior(self, x_t, x_0, t):
        """Compute the posterior q(x_{t-1} | x_t, x_0).

        The unnormalized posterior is the element-wise product of the one-step
        likelihood ``alphas[t] * c_t + (1 - alphas[t]) / K`` and the prior at
        the previous step
        ``alpha_bars[t-1] * c_0 + (1 - alpha_bars[t-1]) / K``.

        NOTE: earlier revisions used ``alpha_bars[t]`` for both factors, which
        is not the posterior of the forward process defined by ``add_noise``.
        Models trained against the old formula should be re-validated.

        Args:
            x_t: Category LongTensor (N, L) or Probability FloatTensor (N, L, K).
            x_0: Category LongTensor (N, L) or Probability FloatTensor (N, L, K).
            t: (N,).
        Returns:
            theta: Posterior probability at (t-1)-th step, (N, L, K).
        """
        K = self.num_classes

        if x_t.dim() == 3:
            c_t = x_t  # When x_t is probability distribution.
        else:
            c_t = clampped_one_hot(x_t, num_classes=K).float()  # (N, L, K)

        if x_0.dim() == 3:
            c_0 = x_0  # When x_0 is probability distribution.
        else:
            c_0 = clampped_one_hot(x_0, num_classes=K).float()  # (N, L, K)

        # Clamp so t == 0 reads alpha_bars[0] instead of wrapping to the last entry.
        t_prev = (t - 1).clamp_min(0)
        alpha = self.var_sched.alphas[t][:, None, None]  # (N, 1, 1)
        alpha_bar_prev = self.var_sched.alpha_bars[t_prev][:, None, None]  # (N, 1, 1)

        theta = ((alpha * c_t) + (1 - alpha) / K) * ((alpha_bar_prev * c_0) + (1 - alpha_bar_prev) / K)  # (N, L, K)
        theta = theta / (theta.sum(dim=-1, keepdim=True) + 1e-8)
        return theta

    def denoise(self, x_t, c_0_pred, mask_generate, t):
        """Sample the categories of the previous step from the posterior.

        Args:
            x_t: (N, L).
            c_0_pred: Normalized probability predicted by networks, (N, L, K).
            mask_generate: (N, L). Entries where it is false keep the one-hot
                distribution of ``x_t``.
            t: (N,).
        Returns:
            post: Posterior probability at (t-1)-th step, (N, L, K).
            x_next: Sample at (t-1)-th step, LongTensor, (N, L).
        """
        c_t = clampped_one_hot(x_t, num_classes=self.num_classes).float()  # (N, L, K)
        post = self.posterior(c_t, c_0_pred, t=t)  # (N, L, K)
        post = torch.where(mask_generate[..., None].expand(post.size()), post, c_t)
        x_next = self._sample(post)
        return post, x_next