KenyaNonaka0210 committed on
Commit 778b7b6 · 1 Parent(s): dc75fed
Files changed (2)
  1. configuration_luxe.py +53 -0
  2. modeling_luxe.py +250 -0
configuration_luxe.py ADDED
@@ -0,0 +1,53 @@
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class LuxeConfig(PretrainedConfig):
+     model_type = "luxe"
+
+     def __init__(
+         self,
+         vocab_size=50267,
+         entity_vocab_size=500000,
+         num_category_entities=0,
+         hidden_size=768,
+         entity_emb_size=256,
+         num_hidden_layers=12,
+         num_attention_heads=12,
+         intermediate_size=3072,
+         hidden_act="gelu",
+         hidden_dropout_prob=0.1,
+         attention_probs_dropout_prob=0.1,
+         max_position_embeddings=512,
+         type_vocab_size=2,
+         initializer_range=0.02,
+         layer_norm_eps=1e-12,
+         use_entity_aware_attention=True,
+         classifier_dropout=None,
+         normalize_entity_embeddings=False,
+         entity_temperature=1.0,
+         pad_token_id=1,
+         bos_token_id=0,
+         eos_token_id=2,
+         **kwargs,
+     ):
+         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+         self.vocab_size = vocab_size
+         self.entity_vocab_size = entity_vocab_size
+         self.num_category_entities = num_category_entities
+         self.hidden_size = hidden_size
+         self.entity_emb_size = entity_emb_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.hidden_act = hidden_act
+         self.intermediate_size = intermediate_size
+         self.hidden_dropout_prob = hidden_dropout_prob
+         self.attention_probs_dropout_prob = attention_probs_dropout_prob
+         self.max_position_embeddings = max_position_embeddings
+         self.type_vocab_size = type_vocab_size
+         self.initializer_range = initializer_range
+         self.layer_norm_eps = layer_norm_eps
+         self.use_entity_aware_attention = use_entity_aware_attention
+         self.classifier_dropout = classifier_dropout
+         self.normalize_entity_embeddings = normalize_entity_embeddings
+         self.entity_temperature = entity_temperature
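A minimal instantiation sketch for the configuration above, not part of the commit. All values below are illustrative assumptions (the committed defaults are vocab_size=50267 and entity_vocab_size=500000); it only shows how the LUXE-specific options (num_category_entities, normalize_entity_embeddings, entity_temperature) would be set.

from configuration_luxe import LuxeConfig

# Illustrative values, not the defaults of any released checkpoint.
config = LuxeConfig(
    entity_vocab_size=10000,
    num_category_entities=200,         # last 200 entity ids act as topic categories
    normalize_entity_embeddings=True,  # entity embeddings re-created with max_norm=1.0
    entity_temperature=2.0,            # divides entity/topic logits before the loss
)
print(config.model_type)  # "luxe"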
modeling_luxe.py ADDED
@@ -0,0 +1,250 @@
+ from dataclasses import dataclass
+ from typing import Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.models.luke.modeling_luke import (
+     EntityPredictionHead,
+     LukeLMHead,
+     LukeModel,
+ )
+ from transformers.utils import ModelOutput
+
+ from configuration_luxe import LuxeConfig
+
+
+ @dataclass
+ class LuxeMaskedLMOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     mlm_loss: Optional[torch.FloatTensor] = None
+     mep_loss: Optional[torch.FloatTensor] = None
+     tep_loss: Optional[torch.FloatTensor] = None
+     tcp_loss: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+     entity_logits: Optional[torch.FloatTensor] = None
+     topic_entity_logits: torch.FloatTensor = None
+     topic_category_logits: Optional[torch.FloatTensor] = None
+     last_hidden_state: torch.FloatTensor = None
+     hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+     entity_last_hidden_state: torch.FloatTensor = None
+     entity_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+     attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+ class LuxePreTrainedModel(PreTrainedModel):
+     config_class = LuxeConfig
+     base_model_prefix = "luke"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"]
+
+     def _init_weights(self, module: nn.Module):
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             if module.embedding_dim == 1:  # embedding for bias parameters
+                 module.weight.data.zero_()
+             else:
+                 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+         elif isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+
+
+ class LuxeForMaskedLM(LuxePreTrainedModel):
+     _tied_weights_keys = [
+         "lm_head.decoder.weight",
+         "lm_head.decoder.bias",
+         "entity_predictions.decoder.weight",
+     ]
+
+     def __init__(self, config: LuxeConfig):
+         super().__init__(config)
+
+         self.luke = LukeModel(config)
+
+         if self.config.normalize_entity_embeddings:
+             self.luke.entity_embeddings.entity_embeddings = nn.Embedding(
+                 config.entity_vocab_size,
+                 config.entity_emb_size,
+                 padding_idx=0,
+                 max_norm=1.0,
+             )
+
+         self.lm_head = LukeLMHead(config)
+         self.entity_predictions = EntityPredictionHead(config)
+
+         self.loss_fn = nn.CrossEntropyLoss()
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def tie_weights(self):
+         super().tie_weights()
+         self._tie_or_clone_weights(
+             self.entity_predictions.decoder,
+             self.luke.entity_embeddings.entity_embeddings,
+         )
+
+     def get_output_embeddings(self) -> nn.Module:
+         return self.lm_head.decoder
+
+     def set_output_embeddings(self, new_embeddings: nn.Module):
+         self.lm_head.decoder = new_embeddings
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         entity_ids: Optional[torch.LongTensor] = None,
+         entity_attention_mask: Optional[torch.LongTensor] = None,
+         entity_token_type_ids: Optional[torch.LongTensor] = None,
+         entity_position_ids: Optional[torch.LongTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         entity_labels: Optional[torch.LongTensor] = None,
+         topic_entity_labels: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, LuxeMaskedLMOutput]:
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+
+         outputs = self.luke(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             entity_ids=entity_ids,
+             entity_attention_mask=entity_attention_mask,
+             entity_token_type_ids=entity_token_type_ids,
+             entity_position_ids=entity_position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=True,
+         )
+
+         loss = None
+
+         mlm_loss = None
+         logits = self.lm_head(outputs.last_hidden_state)
+         if labels is not None:
+             labels = labels.to(logits.device)
+             mlm_loss = self.loss_fn(
+                 logits.view(-1, self.config.vocab_size), labels.view(-1)
+             )
+             if loss is None:
+                 loss = mlm_loss
+
+         mep_loss = None
+         entity_logits = None
+         if outputs.entity_last_hidden_state is not None:
+             entity_logits = self.entity_predictions(outputs.entity_last_hidden_state)
+             if entity_labels is not None:
+                 mep_loss = self.loss_fn(
+                     entity_logits.view(-1, self.config.entity_vocab_size)
+                     / self.config.entity_temperature,
+                     entity_labels.view(-1),
+                 )
+                 if loss is None:
+                     loss = mep_loss
+                 else:
+                     loss = loss + mep_loss
+
+         topic_entity_logits = self.entity_predictions(outputs.last_hidden_state[:, 0])
+         topic_category_logits = None
+         if self.config.num_category_entities > 0:
+             topic_category_logits = topic_entity_logits[
+                 :, -self.config.num_category_entities :
+             ]
+             topic_entity_logits = topic_entity_logits[
+                 :, : -self.config.num_category_entities
+             ]
+
+         topic_category_labels = None
+         if topic_entity_labels is not None and self.config.num_category_entities > 0:
+             topic_category_labels = topic_entity_labels[
+                 :, -self.config.num_category_entities :
+             ]
+             topic_entity_labels = topic_entity_labels[
+                 :, : -self.config.num_category_entities
+             ]
+
+         tep_loss = None
+         if topic_entity_labels is not None:
+             num_topic_entity_labels = topic_entity_labels.sum(dim=1)
+             if (num_topic_entity_labels > 0).any():
+                 topic_entity_labels = topic_entity_labels.to(
+                     topic_entity_logits.dtype
+                 ) / num_topic_entity_labels.unsqueeze(-1)
+                 tep_loss = self.loss_fn(
+                     topic_entity_logits[num_topic_entity_labels > 0]
+                     / self.config.entity_temperature,
+                     topic_entity_labels[num_topic_entity_labels > 0],
+                 )
+                 if loss is None:
+                     loss = tep_loss
+                 else:
+                     loss = loss + tep_loss
+
+         tcp_loss = None
+         if topic_category_labels is not None:
+             num_topic_category_labels = topic_category_labels.sum(dim=1)
+             if (num_topic_category_labels > 0).any():
+                 topic_category_labels = topic_category_labels.to(
+                     topic_category_logits.dtype
+                 ) / num_topic_category_labels.unsqueeze(-1)
+                 tcp_loss = self.loss_fn(
+                     topic_category_logits[num_topic_category_labels > 0]
+                     / self.config.entity_temperature,
+                     topic_category_labels[num_topic_category_labels > 0],
+                 )
+                 if loss is None:
+                     loss = tcp_loss
+                 else:
+                     loss = loss + tcp_loss
+
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [
+                     logits,
+                     entity_logits,
+                     topic_entity_logits,
+                     topic_category_logits,
+                     outputs.last_hidden_state,
+                     outputs.entity_last_hidden_state,
+                     outputs.hidden_states,
+                     outputs.entity_hidden_states,
+                     outputs.attentions,
+                 ]
+                 if v is not None
+             )
+
+         return LuxeMaskedLMOutput(
+             loss=loss,
+             mlm_loss=mlm_loss,
+             mep_loss=mep_loss,
+             tep_loss=tep_loss,
+             tcp_loss=tcp_loss,
+             logits=logits,
+             entity_logits=entity_logits,
+             topic_entity_logits=topic_entity_logits,
+             topic_category_logits=topic_category_logits,
+             last_hidden_state=outputs.last_hidden_state,
+             hidden_states=outputs.hidden_states,
+             entity_last_hidden_state=outputs.entity_last_hidden_state,
+             entity_hidden_states=outputs.entity_hidden_states,
+             attentions=outputs.attentions,
+         )
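A minimal forward-pass sketch for LuxeForMaskedLM, not part of the commit. It assumes both files above are importable from the working directory; every size, id, and label tensor below is made up purely for illustration, and a trained checkpoint would normally be loaded with from_pretrained instead of random initialization. The sketch exercises all four loss terms that forward() sums into out.loss: masked language modeling (labels), masked entity prediction (entity_labels), and topic entity / topic category prediction (a multi-hot topic_entity_labels tensor over the full entity vocabulary, whose last num_category_entities columns are split off as category targets).

import torch

from configuration_luxe import LuxeConfig
from modeling_luxe import LuxeForMaskedLM

# Tiny, randomly initialized model, just to show the call signature and the
# returned loss breakdown.
config = LuxeConfig(
    vocab_size=1000,
    entity_vocab_size=100,
    num_category_entities=10,
    hidden_size=64,
    entity_emb_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=128,
    max_position_embeddings=64,
)
model = LuxeForMaskedLM(config)
model.eval()  # disable dropout for a deterministic illustration

batch, seq_len, num_ents = 2, 16, 3
input_ids = torch.randint(3, config.vocab_size, (batch, seq_len))
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
entity_ids = torch.randint(1, config.entity_vocab_size, (batch, num_ents))
entity_attention_mask = torch.ones(batch, num_ents, dtype=torch.long)
# Each entity is anchored to one word position; -1 is the padding value LUKE
# uses for unused mention positions.
entity_position_ids = torch.full((batch, num_ents, 5), -1, dtype=torch.long)
entity_position_ids[:, :, 0] = torch.arange(num_ents)

labels = torch.randint(0, config.vocab_size, (batch, seq_len))                  # MLM targets
entity_labels = torch.randint(0, config.entity_vocab_size, (batch, num_ents))   # masked-entity targets
# Multi-hot topic targets over the full entity vocabulary; the last
# num_category_entities columns become topic-category targets.
topic_entity_labels = torch.zeros(batch, config.entity_vocab_size, dtype=torch.long)
topic_entity_labels[:, 5] = 1    # one topic entity per example
topic_entity_labels[:, -3] = 1   # one topic category per example

out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    entity_ids=entity_ids,
    entity_attention_mask=entity_attention_mask,
    entity_position_ids=entity_position_ids,
    labels=labels,
    entity_labels=entity_labels,
    topic_entity_labels=topic_entity_labels,
)
print(out.loss, out.mlm_loss, out.mep_loss, out.tep_loss, out.tcp_loss)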