File size: 12,199 Bytes
c8ddb9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 |
"""Module containing the loss functions for the GANs."""
from typing import Any, Dict
import torch
from torch import nn
# pylint: disable=too-many-arguments
# pylint: disable=too-many-locals
def generator_loss(
    logits: Dict[str, Dict[str, torch.Tensor]],
    local_fake_incept_feat: torch.Tensor,
    global_fake_incept_feat: torch.Tensor,
    real_labels: torch.Tensor,
    words_emb: torch.Tensor,
    sent_emb: torch.Tensor,
    match_labels: torch.Tensor,
    cap_lens: torch.Tensor,
    class_ids: torch.Tensor,
    real_vgg_feat: torch.Tensor,
    fake_vgg_feat: torch.Tensor,
    const_dict: Dict[str, float],
) -> Any:
    """Compute the total objective minimised by the generator.

    The result is the sum of three terms:
      * adversarial term: BCE-with-logits of the conditional and the
        unconditional "fake" logits against ``real_labels``;
      * DAMSM term: ``damsm_loss`` evaluated on the fake-image inception
        features and the text embeddings;
      * perceptual term: ``0.5 * MSE(real_vgg_feat, fake_vgg_feat)`` scaled by
        ``const_dict["lambda1"]``.

    Args:
        logits: Dictionary with fake/real and word-level/uncond/cond logits.
            Only ``logits["fake"]["cond"]`` and ``logits["fake"]["uncond"]``
            are read here.
        local_fake_incept_feat: The local inception features for the fake images.
        global_fake_incept_feat: The global inception features for the fake images.
        real_labels: Label for "real" image as predicted by discriminator,
            this is a tensor of ones. [shape: (batch_size, 1)].
        words_emb: The embeddings for all the words in the captions.
            shape: (batch_size, embedding_size, max_caption_length)
        sent_emb: The embeddings for the sentences.
            shape: (batch_size, embedding_size)
        match_labels: Tensor of shape: (batch_size, 1).
            This is of the form torch.tensor([0, 1, 2, ..., batch-1])
        cap_lens: The length of the 'actual' captions in the batch [without padding]
            shape: (batch_size, 1)
        class_ids: The class ids for the instance. shape: (batch_size, 1)
        real_vgg_feat: The vgg features for the real images.
            shape: (batch_size, 128, 128, 128)
        fake_vgg_feat: The vgg features for the fake images.
            shape: (batch_size, 128, 128, 128)
        const_dict: The dictionary containing the constants. ``"lambda1"`` is
            read here; the whole dictionary is forwarded to ``damsm_loss``.

    Returns:
        The summed generator loss.
    """
    perceptual_weight = const_dict["lambda1"]
    fake_logits = logits["fake"]

    # Adversarial part: the generator wants both heads of the discriminator
    # to label the fake images as real.
    adversarial = nn.BCEWithLogitsLoss()
    cond_term = adversarial(fake_logits["cond"], real_labels)
    uncond_term = adversarial(fake_logits["uncond"], real_labels)
    adversarial_term = cond_term + uncond_term

    # DAMSM loss from attnGAN (image-text matching on the fake images).
    matching_term = damsm_loss(
        local_fake_incept_feat,
        global_fake_incept_feat,
        words_emb,
        sent_emb,
        match_labels,
        cap_lens,
        class_ids,
        const_dict,
    )

    # Perceptual loss between the VGG features of real and fake images.
    perceptual_term = 0.5 * nn.MSELoss()(real_vgg_feat, fake_vgg_feat)

    return adversarial_term + matching_term + perceptual_weight * perceptual_term
def damsm_loss(
    local_incept_feat: torch.Tensor,
    global_incept_feat: torch.Tensor,
    words_emb: torch.Tensor,
    sent_emb: torch.Tensor,
    match_labels: torch.Tensor,
    cap_lens: torch.Tensor,
    class_ids: torch.Tensor,
    const_dict: Dict[str, float],
) -> Any:
    """Calculate the DAMSM loss from the attnGAN paper.

    Two (batch, batch) image/caption score matrices are built: a word-level
    one (attention between image regions and caption words) and a
    sentence-level one (cosine similarity between global image features and
    sentence embeddings). Each matrix and its transpose are fed to a
    cross-entropy loss against ``match_labels``; the four terms are summed and
    scaled by ``lambda3``.

    Args:
        local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)]
        global_incept_feat: The global inception features. [shape: (batch, D)]
        words_emb: The embeddings for all the words in the captions.
            shape: (batch, D, max_caption_length)
        sent_emb: The embeddings for the sentences. shape: (batch_size, D)
        match_labels: Tensor of shape: (batch_size, 1).
            This is of the form torch.tensor([0, 1, 2, ..., batch-1])
        cap_lens: The length of the 'actual' captions in the batch [without padding]
            shape: (batch_size, 1)
        class_ids: The class ids for the instance. shape: (batch, 1)
        const_dict: The dictionary containing the constants. The keys
            ``"gamma1"``, ``"gamma2"``, ``"gamma3"`` and ``"lambda3"`` are read.

    Returns:
        ``lambda3 * (L1_w + L2_w + L1_s + L2_s)``.
    """
    batch_size = match_labels.size(0)
    gamma1 = const_dict["gamma1"]
    gamma2 = const_dict["gamma2"]
    gamma3 = const_dict["gamma3"]
    lambda3 = const_dict["lambda3"]
    eps = 1e-8
    neg_inf = -float("inf")

    # Mask mis-match samples that come from the same class as the real sample.
    mask_rows = []
    score_columns = []
    for i in range(batch_size):
        mask = (class_ids == class_ids[i]).int()
        # This ensures that the "correct class" index is not included in the mask.
        mask[i] = 0
        mask_rows.append(mask.reshape(1, -1))  # shape: (1, batch)

        numb_words = int(cap_lens[i])
        # shape: (1, D, L), this picks the caption at ith batch index.
        query_words = words_emb[i, :, :numb_words].unsqueeze(0)
        # shape: (batch, D, L), the same caption is compared to every image.
        query_words = query_words.repeat(batch_size, 1, 1)

        # Region context vectors from the attnGAN paper. shape: (batch, D, L)
        c_i = compute_region_context_vector(local_incept_feat, query_words, gamma1)

        # Flatten words and contexts to (batch * L, D) for the row-wise cosine.
        query_words = query_words.transpose(1, 2).reshape(
            batch_size * numb_words, -1
        )
        c_i = c_i.transpose(1, 2).reshape(batch_size * numb_words, -1)

        # R(c_i, e_i) from the attnGAN paper.
        r_i = compute_relevance(c_i, query_words)
        r_i = r_i.view(batch_size, numb_words)  # shape: (batch, L)
        # log-sum-exp over the words: image-text matching score between every
        # image and caption i. shape: (batch, 1)
        r_i = torch.log(torch.exp(r_i * gamma2).sum(dim=1, keepdim=True))
        score_columns.append(r_i)

    masks = torch.cat(mask_rows, dim=0).bool()  # shape: (batch, batch)

    # Word-level scores; column i holds the scores of caption i.
    word_scores = gamma3 * torch.cat(score_columns, dim=1)
    # Mask out the scores for mis-matched samples (out-of-place, so autograd
    # tracks the operation instead of bypassing it through ``.data``).
    word_scores = word_scores.masked_fill(masks, neg_inf)
    cross_entropy = nn.CrossEntropyLoss()
    # L1_w and L2_w from attnGAN (scores and their transpose).
    l1_w = cross_entropy(word_scores, match_labels)
    l2_w = cross_entropy(word_scores.transpose(0, 1), match_labels)

    # Sentence-level cosine similarity. shape: (batch, batch)
    incept_feat_norm = torch.linalg.norm(global_incept_feat, dim=1)
    sent_emb_norm = torch.linalg.norm(sent_emb, dim=1)
    # Only the denominator is clamped: this guards against division by zero
    # while leaving negative similarities intact.
    norm_product = torch.outer(incept_feat_norm, sent_emb_norm).clamp(min=eps)
    sent_scores = gamma3 * ((global_incept_feat @ sent_emb.T) / norm_product)
    # Mask out the scores for mis-matched samples.
    sent_scores = sent_scores.masked_fill(masks, neg_inf)
    # L1_s and L2_s from attnGAN (scores and their transpose).
    l1_s = cross_entropy(sent_scores, match_labels)
    l2_s = cross_entropy(sent_scores.T, match_labels)

    return lambda3 * (l1_w + l2_w + l1_s + l2_s)
def compute_relevance(c_i: torch.Tensor, query_words: torch.Tensor) -> Any:
    """Compute the row-wise cosine similarity between context vectors and words.

    Args:
        c_i: The region context vector. shape: (batch * L, D)
        query_words: The query words. shape: (batch * L, D)

    Returns:
        The cosine similarity of every row pair, with size-1 dimensions
        squeezed away (shape (batch * L,) for more than one row). Values lie
        in [-1, 1]; a row with a zero-norm vector yields 0 instead of NaN.
    """
    numr = torch.sum(c_i * query_words, dim=1)  # shape: (batch * L,)
    norm_c = torch.linalg.norm(c_i, ord=2, dim=1)
    norm_q = torch.linalg.norm(query_words, ord=2, dim=1)
    # Clamp the denominator only: this avoids division by zero without
    # discarding negative similarities.
    denr = (norm_c * norm_q).clamp(min=1e-8)
    return (numr / denr).squeeze()
def compute_region_context_vector(
    local_incept_feat: torch.Tensor, query_words: torch.Tensor, gamma1: float
) -> Any:
    """Compute the region context vector (c_i) from attnGAN paper.

    Every word attends over the image regions: region/word dot products are
    first normalised over the words, then (scaled by ``gamma1``) normalised
    over the regions, and the resulting weights are used to average the
    region features.

    Args:
        local_incept_feat: The local inception features. [shape: (batch, D, 17, 17)]
        query_words: The embeddings for all the words in the captions.
            shape: (batch, D, L)
        gamma1: The gamma1 value from attnGAN paper.

    Returns:
        The context vectors, one per word. shape: (batch, D, L)
    """
    batch, num_words = query_words.size(0), query_words.size(2)
    num_regions = local_incept_feat.size(2) * local_incept_feat.size(3)

    # Flatten the spatial grid: (batch, D, H, W) -> (batch, D, N)
    regions = local_incept_feat.view(batch, -1, num_regions)

    # Region/word similarities, softmax-normalised over the words.
    attn = (regions.transpose(1, 2) @ query_words).view(
        batch * num_regions, num_words
    )  # shape: (batch * N, L)
    attn = torch.softmax(attn, dim=1).view(batch, num_regions, num_words)

    # Sharpen with gamma1 and normalise over the regions.
    attn = attn.transpose(1, 2).reshape(
        batch * num_words, num_regions
    )  # shape: (batch * L, N)
    attn = torch.softmax(gamma1 * attn, dim=1).view(batch, num_words, num_regions)

    # Weighted sum over the N regions, expressed as a matrix product:
    # (batch, D, N) @ (batch, N, L) -> (batch, D, L)
    return regions @ attn.transpose(1, 2)
def discriminator_loss(
    logits: Dict[str, Dict[str, torch.Tensor]],
    labels: Dict[str, Dict[str, torch.Tensor]],
) -> Any:
    """
    Calculate the discriminator objective.

    The objective is the sum of
    ``(real uncond + real cond) / 2``, ``(fake uncond + fake cond) / 3`` and
    the word-level loss on the real samples.

    :param dict[str, dict[str, torch.Tensor]] logits:
        Dictionary with fake/real and word-level/uncond/cond logits
        Example:
        logits = {
            "fake": {
                "word_level": torch.Tensor (BxL)
                "uncond": torch.Tensor (Bx1)
                "cond": torch.Tensor (Bx1)
            },
            "real": {
                "word_level": torch.Tensor (BxL)
                "uncond": torch.Tensor (Bx1)
                "cond": torch.Tensor (Bx1)
            },
        }
    :param dict[str, dict[str, torch.Tensor]] labels:
        Dictionary with fake/real and word-level/image labels
        Example:
        labels = {
            "fake": {
                "word_level": torch.Tensor (BxL)
                "image": torch.Tensor (Bx1)
            },
            "real": {
                "word_level": torch.Tensor (BxL)
                "image": torch.Tensor (Bx1)
            },
        }
    :return: Discriminator objective loss
    :rtype: Any
    """
    adversarial = nn.BCEWithLogitsLoss()

    real_logits, real_targets = logits["real"], labels["real"]
    # Word-level loss uses plain BCE, which expects probabilities in [0, 1];
    # presumably the word-level head is already sigmoid-activated -- confirm.
    word_term = nn.BCELoss()(real_logits["word_level"], real_targets["word_level"])
    # Unconditional and conditional adversarial losses on the real images.
    real_term = (
        adversarial(real_logits["uncond"], real_targets["image"])
        + adversarial(real_logits["cond"], real_targets["image"])
    ) / 2.0

    fake_logits, fake_targets = logits["fake"], labels["fake"]
    # NOTE(review): two fake terms are summed but divided by 3.0 -- confirm
    # that this weighting is intended.
    fake_term = (
        adversarial(fake_logits["uncond"], fake_targets["image"])
        + adversarial(fake_logits["cond"], fake_targets["image"])
    ) / 3.0

    return real_term + fake_term + word_term
def kl_loss(mu_tensor: torch.Tensor, logvar: torch.Tensor) -> Any:
    """
    Calculate the KL divergence between N(mu, sigma^2) and N(0, 1),
    averaged over all elements.

    :param torch.Tensor mu_tensor: Mean of latent distribution
    :param torch.Tensor logvar: Log variance (log(sigma^2)) of latent distribution
    :return: KL loss [mean of -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)]
    :rtype: Any
    """
    # ``logvar`` already is log(sigma^2), so it enters the closed form unscaled;
    # halving it would let the "divergence" become negative.
    return torch.mean(-0.5 * (1 + logvar - mu_tensor.pow(2) - torch.exp(logvar)))
|