import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.models.bert.modeling_bert import BertEncoder, BertModel

from configs import BertVAEConfig


class BertVAE(PreTrainedModel):
    config_class = BertVAEConfig

    def __init__(self, config):
        super().__init__(config)
        # Trainable encoder stack of the VAE (a fresh, randomly initialized BertEncoder).
        self.encoder = BertEncoder(config)
        # Frozen pretrained BERT backbone that provides the embedding space.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Project encoder hidden states to the mean and log-variance of the latent.
        self.fc_mu = nn.Linear(config.hidden_size, config.hidden_size)
        self.fc_var = nn.Linear(config.hidden_size, config.hidden_size)
        # Classification heads over the [CLS] position of encoder/decoder outputs.
        self.enc_cls = nn.Linear(config.hidden_size, config.position_num)
        self.dec_cls = nn.Linear(config.hidden_size, config.position_num)
        # Trainable decoder stack that maps latents back to the embedding space.
        self.decoder = BertEncoder(config)

        # Freeze the BERT backbone so only the VAE components receive gradients.
        for p in self.bert.parameters():
            p.requires_grad = False
    

    def encode(self, input_ids, **kwargs):
        '''
        Encode input_ids into the parameters of the latent distribution.
            input_ids: (batch_size, seq_len)
        Returns (mu, log_var), each of shape (batch_size, seq_len, hidden_size).
        '''
        x = self.bert(input_ids).last_hidden_state
        outputs = self.encoder(x, **kwargs)
        hidden_state = outputs.last_hidden_state
        mu = self.fc_mu(hidden_state)
        log_var = self.fc_var(hidden_state)
        return mu, log_var
    

    def encoder_cls(self, input_ids, **kwargs):
        '''
        Classify from the encoder's [CLS] representation.
            input_ids: (batch_size, seq_len)
        '''
        x = self.bert(input_ids).last_hidden_state
        outputs = self.encoder(x, **kwargs)
        hidden_state = outputs.last_hidden_state
        return self.enc_cls(hidden_state[:, 0, :])
    

    def decoder_cls(self, z, **kwargs):
        '''
        Classify from the decoder's [CLS] representation.
            z: latent tensor of shape (batch_size, seq_len, hidden_size)
        '''
        outputs = self.decoder(z, **kwargs)
        hidden_state = outputs.last_hidden_state
        return self.dec_cls(hidden_state[:, 0, :])


    def reparameterize(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # keeping the sampling step differentiable w.r.t. mu and log_var.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    

    def decode(self, z, **kwargs):
        '''
            z: latent vector of shape (batch_size, seq_len, dim)
        '''
        outputs = self.decoder(z, **kwargs)
        return outputs.last_hidden_state


    def forward(self, input_ids, position=None, **kwargs):
        # `position` is accepted for interface compatibility but unused here.
        mu, log_var = self.encode(input_ids, **kwargs)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, **kwargs), mu, log_var
    

    def _elbo(self, x, x_hat, mu, log_var):
        '''
        Compute the negative ELBO from a reconstruction and the latent parameters.
            x: target tensor of shape (batch_size, seq_len, dim)
            x_hat: reconstruction tensor of shape (batch_size, seq_len, dim)
            mu: mean tensor of shape (batch_size, seq_len, dim)
            log_var: log variance tensor of shape (batch_size, seq_len, dim)
        '''
        recon_loss = nn.functional.mse_loss(x_hat, x, reduction='mean')
        # KL divergence between N(mu, sigma^2) and N(0, I), summed over the latent
        # dimension and averaged over batch and sequence positions.
        kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1))
        # Down-weight the KL term relative to reconstruction (fixed beta = 0.1).
        return recon_loss + 0.1 * kl_loss


    def elbo(self, input_ids, **kwargs):
        '''
        Run a full encode/decode pass and return the negative ELBO loss.
            input_ids: (batch_size, seq_len)
        '''
        # Frozen BERT embeddings serve as the reconstruction target.
        x = self.bert(input_ids, **kwargs).last_hidden_state
        outputs = self.encoder(x, **kwargs)
        hidden_state = outputs.last_hidden_state
        mu = self.fc_mu(hidden_state)
        log_var = self.fc_var(hidden_state)
        z = self.reparameterize(mu, log_var)
        outputs = self.decoder(z, **kwargs)
        x_hat = outputs.last_hidden_state
        return self._elbo(x, x_hat, mu, log_var)
    

    def reconstruct(self, input_ids, **kwargs):
        '''
        Encode and decode input_ids, returning only the reconstruction.
            input_ids: (batch_size, seq_len)
        '''
        return self.forward(input_ids, **kwargs)[0]


    
    def sample(self, num_samples, device, **kwargs):
        '''
        Sample latent vectors from the standard normal prior and decode them.
            num_samples: number of sequences to generate
            device: device on which to allocate the sampled latents
        '''
        z = torch.randn(
            num_samples, self.config.max_position_embeddings, self.config.hidden_size
        ).to(device)
        return self.decode(z, **kwargs)
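

# --- Minimal usage sketch (illustrative, not part of the original module) ---
# Assumes BertVAEConfig carries BERT-style fields (hidden_size=768 to match the
# frozen 'bert-base-uncased' backbone, max_position_embeddings, ...) plus the
# `position_num` label count used by the classification heads; the exact
# constructor arguments may differ in `configs.py`.
if __name__ == "__main__":
    from transformers import BertTokenizer

    config = BertVAEConfig(position_num=10)  # hypothetical argument, see note above
    model = BertVAE(config)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    batch = tokenizer(["a short example sentence"], return_tensors="pt")

    # Training-style step: elbo() returns the negative ELBO, i.e. the loss to minimize.
    loss = model.elbo(batch["input_ids"])
    loss.backward()

    # Reconstruction in BERT embedding space and unconditional sampling from the prior.
    x_hat = model.reconstruct(batch["input_ids"])
    samples = model.sample(num_samples=2, device=x_hat.device)
    print(loss.item(), x_hat.shape, samples.shape)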