File size: 2,158 Bytes
c8ddb9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""Conditioning Augmentation Module"""

from typing import Any, Tuple

import torch
from torch import nn


class CondAugmentation(nn.Module):
    """Conditioning Augmentation Module.

    Projects a text embedding to the mean and log-variance of a diagonal
    Gaussian, then draws a conditioning vector from that Gaussian with the
    reparameterization trick (``mu + eps * exp(0.5 * logvar)``).
    """

    def __init__(self, D: int, conditioning_dim: int):
        """
        :param D: Dimension of the text embedding space [D from AttnGAN paper]
        :param conditioning_dim: Dimension of the conditioning space
        """
        super().__init__()
        self.cond_dim = conditioning_dim
        # Project to 4 * cond_dim features; the GLU below halves that to
        # 2 * cond_dim, which is split into mean and log-variance (cond_dim each).
        self.cond_augment = nn.Linear(D, conditioning_dim * 4, bias=True)
        # Gate along the last (feature) axis, the same axis nn.Linear writes to.
        # For 2-D (batch, D) input this is identical to dim=1, but it also
        # handles unbatched 1-D input and inputs with extra leading dimensions.
        self.glu = nn.GLU(dim=-1)

    def encode(
        self, text_embedding: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        This function encodes the text embedding into the conditioning space
        :param text_embedding: Text embedding whose last dimension has size D
        :return: Tuple ``(mu, logvar)``, each with the same leading dimensions
            as ``text_embedding`` and a last dimension of size ``cond_dim``
        """
        x_tensor = self.glu(self.cond_augment(text_embedding))
        # First half of the gated features is the mean, second half the log-variance.
        mu_tensor = x_tensor[..., : self.cond_dim]
        logvar = x_tensor[..., self.cond_dim :]
        return mu_tensor, logvar

    def sample(self, mu_tensor: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        This function samples from the Gaussian distribution
        :param mu_tensor: Mean of the Gaussian distribution
        :param logvar: Log variance of the Gaussian distribution
        :return: Sample ``mu_tensor + eps * exp(0.5 * logvar)`` with
            ``eps ~ N(0, I)``, shaped like ``logvar``
        """
        std = torch.exp(0.5 * logvar)
        # eps is pure noise and needs no gradient: gradients reach mu_tensor
        # and logvar through the differentiable ``eps * std`` expression
        # (reparameterization trick).
        eps = torch.randn_like(std)
        return mu_tensor + eps * std

    def forward(
        self, text_embedding: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        This function encodes the text embedding into the conditioning space,
        and samples from the Gaussian distribution.
        :param text_embedding: Text embedding whose last dimension has size D
        :return c_hat: Conditioning embedding (C^ from StackGAN++ paper)
        :return mu: Mean of the Gaussian distribution
        :return logvar: Log variance of the Gaussian distribution
        """
        mu_tensor, logvar = self.encode(text_embedding)
        c_hat = self.sample(mu_tensor, logvar)
        return c_hat, mu_tensor, logvar