from torch import nn
from transformers import PretrainedConfig, PreTrainedModel, BertModel

from .configuration_bert import SimBertConfig

class SimBertModel(PreTrainedModel):
    """SimBert model: a BERT cross-encoder with a binary classification head
    that scores the similarity of a packed sentence pair.
    """

    config_class = SimBertConfig

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        # BERT encoder with the pooling layer enabled so pooler_output is available.
        self.bert = BertModel(config=config, add_pooling_layer=True)
        # Two-way classification head over the pooled [CLS] representation.
        self.fc = nn.Linear(config.hidden_size, 2)
        # MSE between the positive-class probability and a float label in [0, 1].
        # self.loss_fct = nn.CrossEntropyLoss()
        self.loss_fct = nn.MSELoss()
        self.softmax = nn.Softmax(dim=1)

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        labels=None
    ):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        pooled_output = outputs.pooler_output
        logits = self.fc(pooled_output)
        # Probability of the positive ("similar") class, shape (batch_size,).
        logits = self.softmax(logits)[:, 1]
        if labels is not None:
            # Labels are expected as floats in [0, 1] for the MSE objective.
            loss = self.loss_fct(logits.view(-1), labels.view(-1))
            return loss, logits
        return None, logits
    
class CosSimBertModel(PreTrainedModel):
    """CosSimBert model: a bi-encoder that embeds each half of a packed
    sentence pair separately with BERT and scores the pair by cosine similarity.
    """

    config_class = SimBertConfig

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        # Shared BERT encoder; pooler_output is used as the sentence embedding.
        self.bert = BertModel(config=config, add_pooling_layer=True)
        # MSE between the cosine similarity and a float label.
        self.loss_fct = nn.MSELoss()

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        labels=None
    ):
        # The two sentences are packed into one fixed-length sequence:
        # the first half is sentence A, the second half is sentence B.
        seq_length = input_ids.size(-1)
        a = {
            "input_ids": input_ids[:, :seq_length // 2],
            "token_type_ids": token_type_ids[:, :seq_length // 2],
            "attention_mask": attention_mask[:, :seq_length // 2]
        }
        b = {
            "input_ids": input_ids[:, seq_length // 2:],
            "token_type_ids": token_type_ids[:, seq_length // 2:],
            "attention_mask": attention_mask[:, seq_length // 2:]
        }
        # Encode each half independently with the shared encoder.
        outputs_a = self.bert(**a)
        outputs_b = self.bert(**b)
        pooled_a_output = outputs_a.pooler_output
        pooled_b_output = outputs_b.pooler_output
        # Cosine similarity of the pooled embeddings, shape (batch_size,).
        logits = nn.functional.cosine_similarity(pooled_a_output, pooled_b_output)
        if labels is not None:
            loss = self.loss_fct(logits.view(-1), labels.view(-1))
            return loss, logits
        return None, logits
    
    def encode(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
    ):
        # Return the pooled [CLS] embedding of a single (unpacked) input,
        # e.g. for offline indexing or retrieval.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        pooled_output = outputs.pooler_output
        return pooled_output