import torch
from torch import nn
from transformers import CLIPModel
from transformers.models.clip.modeling_clip import _expand_mask
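# NOTE: `_expand_mask` is a private transformers helper; newer releases remove
# it from modeling_clip (its role is played by e.g. `_prepare_4d_attention_mask`),
# so this import assumes a transformers version that still exports it.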

from .utils import drop_sequence_mask


def position_embedding(positions, d_model):
    """Sinusoidal position embeddings, as in "Attention Is All You Need".

    Takes a tensor of positions and returns a (num_positions, d_model)
    tensor with sines in the even feature dims and cosines in the odd ones.
    """
    positions = positions.view(-1, 1)
    dim = torch.arange(d_model // 2, dtype=torch.float32, device=positions.device).view(1, -1)
    sin = torch.sin(positions / 10000 ** (2 * dim / d_model))
    cos = torch.cos(positions / 10000 ** (2 * dim / d_model))

    out = torch.zeros((positions.shape[0], d_model), device=positions.device)
    out[:, ::2] = sin   # even feature dims
    out[:, 1::2] = cos  # odd feature dims
    return out
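# For example, position_embedding(torch.arange(4.), 8) yields a (4, 8) tensor
# whose row p interleaves sin(p / 10000^(2i/8)) and cos(p / 10000^(2i/8)).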


def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
    """Build a (max_len, d_model) table of sinusoidal position encodings.

    If padding_idx is given, that row is zeroed so padding positions carry
    no positional signal.
    """
    pos = torch.arange(max_len, dtype=torch.float32)
    out = position_embedding(pos, d_model)
    if padding_idx is not None:
        out[padding_idx] = 0

    return out
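# A common way to consume the table (a sketch; not what this file does) is a
# frozen embedding layer:
#
#   pos_embed = nn.Embedding.from_pretrained(
#       sinusoid_encoding_table(max_len=64, d_model=512), freeze=True
#   )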


class KnwlModel(nn.Module):
    """Fuses retrieved knowledge (objects, attributes, actions) and query
    tokens into a single embedding sequence with an attention mask."""

    def __init__(self, d_knwl, d_out, pt=0.1):
        super().__init__()

        self.pt = pt  # drop probability passed to drop_sequence_mask

        self.fc_knwl = nn.Linear(d_knwl, d_out, bias=False)
        self.fc_query = nn.Linear(d_knwl, d_out)

        self.pos = nn.Embedding(9, d_out)  # position ids 0..8
        # Learned scale and shift that inject the scalar retrieval score.
        self.score1 = nn.Parameter(torch.randn(1, 1, d_out))
        self.score2 = nn.Parameter(torch.randn(1, 1, d_out))

        # Learned type embeddings: one per knowledge stream plus the query.
        self.obj = nn.Parameter(torch.randn(1, 1, d_out))
        self.attr = nn.Parameter(torch.randn(1, 1, d_out))
        self.act = nn.Parameter(torch.randn(1, 1, d_out))
        self.query = nn.Parameter(torch.randn(1, 1, d_out))

    @property
    def device(self):
        return self.score1.device
      
    def prepare_input(self, knowledge):
        # Knowledge tokens: projected embedding + position + score conditioning.
        e = self.fc_knwl(knowledge["embed"])
        p = self.pos(knowledge["pos"])
        s = knowledge["score"].unsqueeze(-1) * self.score1 + self.score2
        e_knwl = e + p + s
        # Randomly drop knowledge tokens while training; keep all of them at eval.
        m_knwl = drop_sequence_mask(
            *e_knwl.shape[:2], self.device, self.pt, self.training
        )

        # Query tokens: projected embedding + sequential positions, never dropped.
        e = self.fc_query(knowledge["query"])
        p = torch.arange(knowledge["query"].shape[1], device=self.device)
        p = self.pos(p[None, :])
        e_query = e + p
        m_query = torch.ones(
            e_query.shape[:2], dtype=torch.long, device=self.device
        )

        return e_knwl, m_knwl, e_query, m_query

    def forward(self, knowledge):
        e_obj, m_obj, e_query, m_query = self.prepare_input(knowledge["obj"])
        e_attr, m_attr, _, _ = self.prepare_input(knowledge["attr"])
        e_act, m_act, _, _ = self.prepare_input(knowledge["act"])

        # Tag each stream with its learned type embedding.
        e_obj = e_obj + self.obj
        e_attr = e_attr + self.attr
        e_act = e_act + self.act
        e_query = e_query + self.query

        # One sequence per sample: [query | obj | attr | act], plus its mask.
        embeds = torch.cat([e_query, e_obj, e_attr, e_act], dim=1)
        masks = torch.cat([m_query, m_obj, m_attr, m_act], dim=1)

        return embeds, masks


class KnwlEncoder(nn.Module):
    """Encodes the fused knowledge sequence with the vision transformer of a
    pretrained CLIP model, then projects the output to d_out."""

    def __init__(self, d_out, num_layers=None, grad_ckpt=True):
        super().__init__()

        self.model = CLIPModel.from_pretrained(
            "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
        ).vision_model
        self.model.encoder.gradient_checkpointing = grad_ckpt
        if num_layers is not None:
            # Keep only the last num_layers transformer blocks.
            self.model.encoder.layers = nn.ModuleList([
                self.model.encoder.layers[i] for i in range(-num_layers, 0)
            ])
        self.fc = nn.Linear(self.model.config.hidden_size, d_out, bias=False)

        self.d = self.model.config.hidden_size

    def forward(self, inputs_embeds, attention_mask):
        # Mirror CLIPVisionTransformer.forward, but on precomputed embeddings
        # rather than pixel values. (`pre_layrnorm` is the attribute's actual
        # spelling in transformers.)
        embed = self.model.pre_layrnorm(inputs_embeds)
        # Expand the (batch, seq) mask into the additive attention bias the
        # CLIP encoder expects.
        mask = _expand_mask(attention_mask, embed.dtype)
        embed = self.model.encoder(
            inputs_embeds=embed,
            attention_mask=mask,
            return_dict=True,
        )[0]

        embed = self.fc(self.model.post_layernorm(embed))

        return embed
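
# ----------------------------------------------------------------------------
# Minimal usage sketch (assumptions: the module is imported as part of its
# package so `.utils` resolves, and KnwlModel's d_out matches the CLIP vision
# hidden size, 768 for ViT-B/32; d_knwl=512 and the batch/sequence sizes are
# arbitrary). The dict keys mirror what prepare_input/forward read.
#
#   knwl = KnwlModel(d_knwl=512, d_out=768)
#   enc = KnwlEncoder(d_out=512)
#
#   def dummy_stream(batch=2, n=5, q=4, d_knwl=512):
#       return {
#           "embed": torch.randn(batch, n, d_knwl),  # retrieved knowledge
#           "pos": torch.randint(0, 9, (batch, n)),  # position ids in 0..8
#           "score": torch.rand(batch, n),           # retrieval scores
#           "query": torch.randn(batch, q, d_knwl),  # query tokens, q <= 9
#       }
#
#   knowledge = {k: dummy_stream() for k in ("obj", "attr", "act")}
#   embeds, masks = knwl(knowledge)    # (2, 4 + 3*5, 768) and (2, 4 + 3*5)
#   out = enc(embeds.half(), masks)    # (2, 19, 512); CLIP weights are fp16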