import logging
import os

import torch
from torch import nn, Tensor

from bubogpt.common.dist_utils import download_cached_file
from bubogpt.common.utils import is_url
from bubogpt.models.Qformer import BertConfig, BertLMHeadModel


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class BaseQFormer(nn.Module):
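    """Base wrapper around a Q-Former with shared freezing and checkpoint-loading logic."""
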
    def __init__(self, freeze_qformer=False):
        super().__init__()
        self.freeze_qformer = freeze_qformer
        self.Qformer = None

    def check_and_freeze(self):
        """Freeze the Q-Former and its query tokens when `freeze_qformer` is set."""
        assert self.Qformer is not None
        if self.freeze_qformer:
            for param in self.Qformer.parameters():
                param.requires_grad = False
            self.Qformer = self.Qformer.eval()
            self.Qformer.train = disabled_train
            # `query_tokens` is created by the subclass before this method is called.
            self.query_tokens.requires_grad = False
            logging.info("Freezing the Q-Former")

    def load_from_pretrained(self, url_or_filename):
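        """Load Q-Former weights from a URL or local checkpoint file (strict=False)."""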
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]

        msg = self.load_state_dict(state_dict, strict=False)

        logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)

        return msg


class SequenceGenericQFormer(BaseQFormer):
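    """Q-Former that compresses a sequence of encoder embeddings into a fixed set of
    learned query tokens via cross-attention."""
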
    def __init__(self,
                 num_query_token: int,
                 encoder_width: int = 768,
                 freeze_qformer: bool = False,
                 q_former_model: str = "",
                 cross_attention_freq: int = 2
                 ):
        super().__init__(freeze_qformer)
        self.Qformer, self.query_tokens = self.init_Qformer(num_query_token, encoder_width, cross_attention_freq)
        if q_former_model != "":
            self.load_Qformer(q_former_model)
        self.check_and_freeze()

    def set_Qformer(self):
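        # Drop the LM head, the word/position embeddings and the text-branch
        # feed-forward modules; only the query-token pathway of the Q-Former is kept.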
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

    def load_Qformer(self, q_former_model):
        # Same module surgery as set_Qformer, then load the pretrained weights.
        self.set_Qformer()
        self.load_from_pretrained(url_or_filename=q_former_model)

    @classmethod
    def init_Qformer(cls, num_query_token, encoder_width, cross_attention_freq=2):
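        # Configure a bert-base-uncased Q-Former with cross-attention to the feature
        # encoder and `num_query_token` learnable query embeddings. Weights are randomly
        # initialized here; pretrained Q-Former weights are loaded via `q_former_model`.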
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = encoder_width
        # insert a cross-attention layer every `cross_attention_freq` blocks
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    def forward(self, input_embeds: Tensor) -> Tensor:
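        """Cross-attend the learned query tokens to `input_embeds` of shape
        (batch, seq_len, encoder_width) and return query features of shape
        (batch, num_query_token, hidden_size)."""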
        input_atts = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
        query_tokens = self.query_tokens.expand(input_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=input_embeds,
            encoder_attention_mask=input_atts,
            return_dict=True,
        )
        return query_output.last_hidden_state
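

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the shapes below are
    # illustrative assumptions; any encoder producing (batch, seq_len, encoder_width)
    # embeddings can feed the Q-Former.
    qformer = SequenceGenericQFormer(num_query_token=32, encoder_width=768)
    dummy_feats = torch.randn(2, 100, 768)  # e.g. frame-level features from a modality encoder
    query_feats = qformer(dummy_feats)      # -> (2, 32, hidden_size)
    print(query_feats.shape)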