Sifal committed
Commit f3edc31 · verified · 1 Parent(s): 43433dd

Create automodel.py

Files changed (1):
  1. automodel.py +272 -0
automodel.py ADDED
@@ -0,0 +1,272 @@
+from typing import Optional, Tuple, Union
+import torch.nn as nn
+import torch
+from utils.bert_layers_mosa import BertModel
+from transformers import BertPreTrainedModel
+from transformers.modeling_outputs import SequenceClassifierOutput
+from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
+import logging
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+
+logger = logging.getLogger(__name__)
+
+
+class MosaicBertForSequenceClassification(BertPreTrainedModel):
+    """Bert Model transformer with a sequence classification/regression head.
+
+    This head is just a linear layer on top of the pooled output.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+        self.bert = BertModel(config, add_pooling_layer=True)
+        classifier_dropout = (config.classifier_dropout
+                              if config.classifier_dropout is not None else
+                              config.hidden_dropout_prob)
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # this resets the weights
+        self.post_init()
+
+    @classmethod
+    def from_pretrained(cls,
+                        pretrained_checkpoint,
+                        state_dict=None,
+                        config=None,
+                        *inputs,
+                        **kwargs):
+        """Load from a pre-trained checkpoint."""
+        # Start from a freshly initialized model ...
+        model = cls(config, *inputs, **kwargs)
+
+        # ... and load the checkpoint's state_dict into it.
+        state_dict = torch.load(pretrained_checkpoint)
+        # Remove the `model.` prefix to avoid key-mismatch errors.
+        consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict,
+                                                              strict=False)
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Found {len(missing_keys)} missing keys in the checkpoint: "
+                f"{', '.join(missing_keys)}")
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Found {len(unexpected_keys)} unexpected keys in the checkpoint: "
+                f"{', '.join(unexpected_keys)}")
+
+        return model
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        pooled_output = outputs[1]
+
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            # Infer the problem type from num_labels and the label dtype if it is not set.
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=None,
+            attentions=None,
+        )
+
+
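+# Usage sketch for MosaicBertForSequenceClassification. The checkpoint path and the way the
+# config is built are placeholders/assumptions; `from_pretrained` here expects an explicit
+# config object plus a raw torch checkpoint path rather than a Hugging Face model id.
+#
+#   from transformers import BertConfig
+#   config = BertConfig.from_pretrained('bert-base-uncased', num_labels=2)  # assumed config source
+#   model = MosaicBertForSequenceClassification.from_pretrained(
+#       'path/to/checkpoint.pt', config=config)
+#   out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
+#   loss, logits = out.loss, out.logits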
+
+
+class MosaicBertForEmbeddingGeneration(BertPreTrainedModel):
+
+    def __init__(self, config, add_pooling_layer=False):
+        """Initializes the MosaicBertForEmbeddingGeneration model.
+
+        Args:
+            config (BertConfig): The configuration for the BERT model.
+            add_pooling_layer (bool, optional): Whether to add a pooling layer. Defaults to False.
+        """
+        super().__init__(config)
+        assert config.num_hidden_layers >= config.num_embedding_layers, \
+            'num_hidden_layers should be greater than or equal to num_embedding_layers'
+        self.config = config
+        self.strategy = config.strategy
+        self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
+        # this resets the weights
+        self.post_init()
+
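+    # Note: `num_embedding_layers` and `strategy` are custom fields that this class reads from
+    # the config; they are not standard Hugging Face BertConfig attributes, so they must be set
+    # on the config before instantiation.
+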
+    @classmethod
+    def from_pretrained(cls,
+                        pretrained_checkpoint,
+                        state_dict=None,
+                        config=None,
+                        *inputs,
+                        **kwargs):
+        """Load from a pre-trained checkpoint."""
+        # Start from a freshly initialized model ...
+        model = cls(config, *inputs, **kwargs)
+
+        # ... and load the checkpoint's state_dict into it.
+        state_dict = torch.load(pretrained_checkpoint)
+        # Remove the `model.` prefix to avoid key-mismatch errors.
+        consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict,
+                                                              strict=False)
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Found {len(missing_keys)} missing keys in the checkpoint: "
+                f"{', '.join(missing_keys)}")
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Found {len(unexpected_keys)} unexpected keys in the checkpoint: "
+                f"{', '.join(unexpected_keys)}")
+
+        return model
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        subset_mask: Optional[torch.Tensor] = None,
+        hospital_ids_lens: Optional[list] = None,
+    ) -> torch.Tensor:
+
+        embedding_output = self.bert.embeddings(input_ids, token_type_ids,
+                                                position_ids)
+
+        encoder_outputs_all = self.bert.encoder(
+            embedding_output,
+            attention_mask,
+            output_all_encoded_layers=True,
+            subset_mask=subset_mask)
+
+        # batch_size, hidden_dim
+        return self.get_embeddings(encoder_outputs_all, hospital_ids_lens,
+                                   self.config.num_embedding_layers, self.config.strategy)
+
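+    # Summary of the pooling strategies implemented in get_embeddings (shapes follow the code below):
+    #   'mean'   : average over the last num_layers layers and over seq_len, then over each
+    #              patient's visits -> (true batch size, hidden_dim)
+    #   'concat' : average over seq_len only, then over each patient's visits, keeping the
+    #              layers separate -> (true batch size, num_layers, hidden_dim)
+    #   'all'    : raw stacked hidden states of the last num_layers layers
+    #              -> (num_layers, concatenated visits, seq_len, hidden_dim)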
+    def get_embeddings(self, encoder_outputs_all, hospital_ids_lens, num_layers, strategy):
+
+        batch_embeddings = []
+        start_idx = 0
+
+        # encoder_outputs_all[-num_layers:] stacks to:
+        # num_layers (we use default = 4), batch_size (concatenated visits), seq_len (clinical note sequences), hidden_dim.
+        if strategy == 'mean':
+            # Average across num_layers and seq_len -> batch_size (concatenated visits), hidden_dim.
+            sentence_representation = torch.stack(encoder_outputs_all[-num_layers:]).mean(dim=[0, 2])
+
+            for length in hospital_ids_lens:
+                # Then average across each patient's visits -> (hidden_dim,) per patient.
+                batch_embeddings.append(torch.mean(sentence_representation[start_idx:start_idx + length], dim=0))
+                start_idx += length
+
+            # (true batch size, hidden_dim)
+            return torch.stack(batch_embeddings)
+
+        elif strategy == 'concat':
+            # Average across seq_len only -> num_layers, batch_size (concatenated visits), hidden_dim.
+            sentence_representation = torch.stack(encoder_outputs_all[-num_layers:]).mean(dim=2)
+
+            for length in hospital_ids_lens:
+                # Then average across each patient's visits -> (num_layers, hidden_dim) per patient.
+                batch_embeddings.append(torch.mean(sentence_representation[:, start_idx:start_idx + length], dim=1))
+                start_idx += length
+
+            # (true batch size, num_layers, hidden_dim)
+            return torch.stack(batch_embeddings)
+
+        elif strategy == 'all':
+            # Keep everything: num_layers, batch_size (concatenated visits), seq_len, hidden_dim.
+            sentence_representation = torch.stack(encoder_outputs_all[-num_layers:])
+            return sentence_representation
+
+        else:
+            raise ValueError(f"{strategy} is not a valid strategy; choose 'mean', 'concat', or 'all'")
+
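+
+# Usage sketch for MosaicBertForEmbeddingGeneration. The config fields and inputs below are
+# placeholders/assumptions: `num_embedding_layers` and `strategy` must be present on the config,
+# and `hospital_ids_lens` lists how many concatenated visit sequences belong to each patient.
+#
+#   config.num_embedding_layers = 4      # pool over the last 4 encoder layers
+#   config.strategy = 'mean'             # or 'concat' / 'all'
+#   encoder = MosaicBertForEmbeddingGeneration.from_pretrained(
+#       'path/to/checkpoint.pt', config=config)
+#   # input_ids / attention_mask hold all visits of the batch concatenated along dim 0,
+#   # e.g. hospital_ids_lens = [3, 1] for a batch of two patients with 3 and 1 visits.
+#   embeddings = encoder(input_ids=input_ids, attention_mask=attention_mask,
+#                        hospital_ids_lens=[3, 1])   # -> (2, hidden_dim) with strategy='mean'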