File size: 10,096 Bytes
8c212a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import torch
from torch import nn
import numpy as np
class SupportSets(nn.Module):
    """Support sets of RBF dipoles, each support set defining one warping function.

    Two flavours are built by the constructor:
      * Corpus Support Sets (CSS), when ``prompt_features`` is given: one support set per row of ``prompt_features``,
        each holding a single dipole, with fixed (``requires_grad=False``) support vectors and log-gammas.
      * Latent Support Sets (LSS), otherwise: ``num_support_sets`` randomly initialised support sets of
        ``num_support_dipoles`` antipodal dipoles each, with trainable support vectors and log-gammas.

    ``forward`` returns the unit-norm gradient of the RBF function selected by the mask, evaluated at ``z``.
    """

    # Lower bound on ||grad f|| used when normalising in ``forward``: a vanishing gradient (e.g. exp() underflow far
    # from every support vector) then yields a zero direction instead of NaNs.
    _GRAD_NORM_EPS = 1e-12

    def __init__(self, prompt_features=None, num_support_sets=None, num_support_dipoles=None, support_vectors_dim=None,
                 lss_beta=0.5, css_beta=0.5, jung_radius=None):
        """SupportSets class constructor.

        Args:
            prompt_features (torch.Tensor) : CLIP text feature statistics of prompts from the given corpus, of shape
                                             (num_support_sets, 2, support_vectors_dim); if given, Corpus Support
                                             Sets (CSS) are built from it and num_support_sets, num_support_dipoles,
                                             support_vectors_dim, lss_beta and jung_radius are ignored
            num_support_sets (int)         : number of support sets (each one defining a warping function)
            num_support_dipoles (int)      : number of support dipoles per support set (per warping function)
            support_vectors_dim (int)      : dimensionality of support vectors (latent space dimensionality, z_dim)
            lss_beta (float)               : set beta parameter for initializing latent space RBFs' gamma parameters
                                             (recommended 0.25 < lss_beta < 1.0; must lie in (0, 1))
            css_beta (float)               : set beta parameter for fixing CLIP space RBFs' gamma parameters
                                             (recommended 0.25 <= css_beta < 1.0; must lie in (0, 1))
            jung_radius (float)            : radius of the minimum enclosing ball of a set of 10K latent codes

        Raises:
            ValueError: if a latent support set argument is missing (LSS); if prompt_features does not have shape
                        (K, 2, d) or contains a pair of identical feature vectors (CSS); or if the beta used by the
                        selected branch lies outside (0, 1).
        """
        super(SupportSets, self).__init__()
        self.prompt_features = prompt_features
        if self.prompt_features is not None:
            self._init_corpus_support_sets(css_beta)
        else:
            self._init_latent_support_sets(num_support_sets, num_support_dipoles, support_vectors_dim, lss_beta,
                                           jung_radius)

    @staticmethod
    def _check_beta(name, beta):
        """Raise ValueError unless 0 < beta < 1 (otherwise -log(beta) <= 0 and log(gamma) is undefined)."""
        if not 0.0 < beta < 1.0:
            raise ValueError(f"{name} must lie in (0, 1); got {beta}.")

    @staticmethod
    def _dipole_alphas(num_support_sets, num_support_dipoles):
        """Return the (K, 2N) RBF weights: the pair [1, -1] for every dipole of every support set."""
        return torch.tensor([1.0, -1.0]).repeat(num_support_sets, num_support_dipoles)

    def _init_corpus_support_sets(self, css_beta):
        """Build fixed Corpus Support Sets (CSS) from ``self.prompt_features``: one dipole per support set."""
        prompt_features = self.prompt_features
        # forward() assumes 2 * num_support_dipoles = 2 RBF centres per set and the gammas below use rows 0 and 1,
        # so anything other than (K, 2, d) cannot work.
        if prompt_features.dim() != 3 or prompt_features.shape[1] != 2:
            raise ValueError("prompt_features must have shape (num_support_sets, 2, support_vectors_dim); "
                             f"got {tuple(prompt_features.shape)}.")
        self._check_beta("css_beta", css_beta)

        self.num_support_sets = prompt_features.shape[0]
        self.num_support_dipoles = 1
        self.support_vectors_dim = prompt_features.shape[2]
        self.css_beta = css_beta

        # SUPPORT_SETS: (K, 2 * d) -- the two support vectors of each set's dipole, flattened row-wise; not trained.
        self.SUPPORT_SETS = nn.Parameter(data=prompt_features.reshape(self.num_support_sets, -1).detach().clone(),
                                         requires_grad=False)

        # ALPHAS: (K, 2) -- RBF weights [1, -1]. A non-persistent buffer follows .to()/.cuda()/.double() together
        # with the parameters, while staying out of the state_dict (so its keys are unchanged).
        self.register_buffer("ALPHAS", self._dipole_alphas(self.num_support_sets, self.num_support_dipoles),
                             persistent=False)

        # LOGGAMMA: (K, 1) -- gamma_k = -log(beta) / ||v_k1 - v_k0||^2, so that each RBF equals beta at the other
        # pole of its dipole; not trained.
        dipole_sq_dist = torch.norm(prompt_features[:, 1] - prompt_features[:, 0], dim=1) ** 2
        if (dipole_sq_dist == 0).any():
            raise ValueError("Each prompt pair must have two distinct feature vectors (zero-length dipole found).")
        self.LOGGAMMA = nn.Parameter(data=torch.ones(self.num_support_sets, 1), requires_grad=False)
        self.LOGGAMMA.data.copy_(torch.log(float(-np.log(self.css_beta)) / dipole_sq_dist).reshape(-1, 1))

    def _init_latent_support_sets(self, num_support_sets, num_support_dipoles, support_vectors_dim, lss_beta,
                                  jung_radius):
        """Build trainable Latent Support Sets (LSS) of antipodal dipoles with radii spread around the Jung radius."""
        if num_support_sets is None:
            raise ValueError("Number of latent support sets not defined.")
        if num_support_dipoles is None:
            raise ValueError("Number of latent support dipoles not defined.")
        if support_vectors_dim is None:
            raise ValueError("Latent support vector dimensionality not defined.")
        if jung_radius is None:
            raise ValueError("Jung radius not given.")
        self._check_beta("lss_beta", lss_beta)

        self.num_support_sets = num_support_sets
        self.num_support_dipoles = num_support_dipoles
        self.support_vectors_dim = support_vectors_dim
        self.jung_radius = jung_radius
        self.lss_beta = lss_beta

        # Radii: K equispaced values r_min + k * (r_max - r_min) / K, k = 0..K-1. Built from an integer range so that
        # exactly K radii are produced (torch.arange with a float step may return K + 1 values due to rounding).
        self.r_min = 0.90 * self.jung_radius
        self.r_max = 1.25 * self.jung_radius
        radius_step = (self.r_max - self.r_min) / self.num_support_sets
        self.radii = (self.r_min + radius_step * torch.arange(self.num_support_sets, dtype=torch.float64)).to(
            torch.get_default_dtype())

        # SUPPORT_SETS: (K, 2 * N * d) -- set k holds N antipodal dipoles (v, -v), each support vector rescaled to
        # norm radii[k], flattened row-wise. One randn(1, d) draw per dipole, in this order, so that a given RNG
        # seed always reproduces the same initialisation.
        support_sets = torch.zeros(self.num_support_sets, 2 * self.num_support_dipoles, self.support_vectors_dim)
        for k in range(self.num_support_sets):
            sv_set = []
            for _ in range(self.num_support_dipoles):
                sv = torch.randn(1, self.support_vectors_dim)
                sv_set.extend([sv, -sv])
            sv_set = torch.cat(sv_set)
            support_sets[k] = self.radii[k] * sv_set / torch.norm(sv_set, dim=1, keepdim=True)
        self.SUPPORT_SETS = nn.Parameter(data=support_sets.reshape(self.num_support_sets, -1).clone())

        # ALPHAS: (K, 2N) -- RBF weights [1, -1] per dipole; non-persistent buffer (see the CSS branch for why).
        self.register_buffer("ALPHAS", self._dipole_alphas(self.num_support_sets, self.num_support_dipoles),
                             persistent=False)

        # LOGGAMMA: (K, 1) -- trainable; initialised to log(-log(beta) / (2 * r_k)^2), so that each RBF equals beta at
        # distance 2 * r_k, i.e. at the antipodal support vector.
        self.LOGGAMMA = nn.Parameter(
            data=torch.log(float(-np.log(self.lss_beta)) / (2 * self.radii) ** 2).reshape(-1, 1))

    def forward(self, support_sets_mask, z):
        """Return the unit-norm gradient of the selected RBF warping function at ``z``.

        Args:
            support_sets_mask (torch.Tensor) : (B, K) matrix selecting/mixing support sets per batch element
                                               (presumably one-hot rows -- TODO confirm against callers)
            z (torch.Tensor)                 : (B, d) points at which the gradient is evaluated

        Returns:
            torch.Tensor : (B, d) gradient of f(z) = sum_i alpha_i * exp(-gamma * ||z - v_i||^2), normalised to unit
                           L2 norm per row (a vanishing gradient is returned as a zero vector rather than NaN).
        """
        num_rbfs = 2 * self.num_support_dipoles
        # (B, 2N, d): support vectors v_i of the selected support set(s)
        support_sets_batch = torch.matmul(support_sets_mask, self.SUPPORT_SETS).reshape(-1, num_rbfs,
                                                                                         self.support_vectors_dim)
        # (B, 2N, 1): RBF weights alpha_i
        alphas_batch = torch.matmul(support_sets_mask, self.ALPHAS).unsqueeze(dim=2)
        # (B, 1, 1): RBF gammas; exp() of the stored log-gamma is always positive
        gammas_batch = torch.exp(torch.matmul(support_sets_mask, self.LOGGAMMA).unsqueeze(dim=2))
        # (B, 2N, d): z - v_i, broadcasting z over the RBF dimension
        D = z.unsqueeze(dim=1) - support_sets_batch
        sq_dist = D.pow(2).sum(dim=2, keepdim=True)
        grad_f = -2 * (alphas_batch * gammas_batch * torch.exp(-gammas_batch * sq_dist) * D).sum(dim=1)
        return grad_f / torch.norm(grad_f, dim=1, keepdim=True).clamp_min(self._GRAD_NORM_EPS)
|