import torch
import torch.nn as nn
import torchvision.models as models


class InceptionEncoder(nn.Module):
    """Maps an image to a fixed-size embedding using a pretrained Inception v3."""

    def __init__(self, embed_size, train_CNN=False):
        super(InceptionEncoder, self).__init__()
        self.train_CNN = train_CNN
        # Note: torchvision >= 0.13 prefers weights=Inception_V3_Weights.DEFAULT
        # over the deprecated pretrained=True argument.
        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
        # Replace the ImageNet classification head with a projection to embed_size.
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        self.dropout = nn.Dropout(0.5)
        # Freeze the pretrained backbone unless fine-tuning is requested;
        # the newly added fc projection is always trainable.
        for name, param in self.inception.named_parameters():
            if "fc.weight" in name or "fc.bias" in name:
                param.requires_grad = True
            else:
                param.requires_grad = self.train_CNN

    def forward(self, images):
        features = self.inception(images)            # (batch, embed_size)
        norm_features = self.bn(features)
        return self.dropout(self.relu(norm_features))


class LstmDecoder(nn.Module):
    """Produces caption logits from an image embedding and caption token indices."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, device='cpu'):
        super(LstmDecoder, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.device = device
        self.embed = nn.Embedding(vocab_size, embed_size)
        # Sequence-first layout: LSTM inputs are (seq_len, batch, embed_size).
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=self.num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, encoder_out, captions):
        # Zero initial hidden and cell states, one per layer and batch element.
        batch_size = encoder_out.shape[0]
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=self.device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=self.device)
        embeddings = self.dropout(self.embed(captions))
        # Prepend the image embedding as the first "token" of every sequence.
        embeddings = torch.cat((encoder_out.unsqueeze(0), embeddings), dim=0)
        hiddens, (hn, cn) = self.lstm(embeddings, (h0, c0))
        outputs = self.linear(hiddens)               # (seq_len + 1, batch, vocab_size)
        return outputs


class SeqToSeq(nn.Module):
    """Full captioning model: Inception v3 encoder followed by an LSTM decoder."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, device='cpu'):
        super(SeqToSeq, self).__init__()
        self.encoder = InceptionEncoder(embed_size)
        self.decoder = LstmDecoder(embed_size, hidden_size, vocab_size, num_layers, device)

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length=50):
        """Greedily decode a caption for a single image.

        `vocabulary` is expected to map string indices to tokens, e.g.
        vocabulary["2"] == "<EOS>". Call model.eval() first so that BatchNorm
        and Dropout behave correctly for a batch of one.
        """
        result_caption = []

        with torch.no_grad():
            # (1, embed_size) -> (1, 1, embed_size): sequence-first, batch of one.
            x = self.encoder(image).unsqueeze(0)
            states = None

            for _ in range(max_length):
                hiddens, states = self.decoder.lstm(x, states)
                output = self.decoder.linear(hiddens.squeeze(0))   # (1, vocab_size)
                predicted = output.argmax(1)
                result_caption.append(predicted.item())
                # Feed the predicted token back in as the next LSTM input.
                x = self.decoder.embed(predicted).unsqueeze(0)

                if vocabulary[str(predicted.item())] == "<EOS>":
                    break

        return [vocabulary[str(idx)] for idx in result_caption]
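

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): the hyperparameters
# and the toy index-to-token vocabulary below are made-up placeholders, not
# values from the project. It builds the model, runs a training-style forward
# pass on random tensors, then greedy captioning on a single random image.
# Loading pretrained Inception weights downloads them on the first run.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embed_size, hidden_size, num_layers = 256, 512, 1
    vocabulary = {"0": "<PAD>", "1": "<SOS>", "2": "<EOS>", "3": "<UNK>"}  # toy vocab
    vocab_size = len(vocabulary)

    model = SeqToSeq(embed_size, hidden_size, vocab_size, num_layers, device='cpu')

    # Inception v3 expects 299x299 RGB inputs; use a random batch of 4 images.
    images = torch.randn(4, 3, 299, 299)
    # Captions are (seq_len, batch) token indices, matching the sequence-first LSTM.
    captions = torch.randint(0, vocab_size, (20, 4))

    model.train()
    outputs = model(images, captions)
    print(outputs.shape)  # (seq_len + 1, batch, vocab_size)

    model.eval()
    print(model.caption_image(torch.randn(1, 3, 299, 299), vocabulary, max_length=10))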