Sifal committed
Commit 133928a · verified · 1 Parent(s): 4f1265c

Update automodel.py

Files changed (1)
  1. automodel.py +101 -168
automodel.py CHANGED
@@ -1,19 +1,91 @@
+import logging
 from typing import Optional, Tuple, Union
-import torch.nn as nn
+
 import torch
-from utils.bert_layers_mosa import BertModel
+import torch.nn as nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
 from transformers import BertPreTrainedModel
 from transformers.modeling_outputs import SequenceClassifierOutput
-from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
-import logging
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from bert_layers_mosa import BertModel
 
 logger = logging.getLogger(__name__)
 
+
+class MosaicBertForEmbeddingGeneration(BertPreTrainedModel):
+
+    def __init__(self, config, add_pooling_layer=False):
+        """
+        Initializes the MosaicBertForEmbeddingGeneration model.
+
+        Args:
+            config (BertConfig): The configuration for the BERT model.
+            add_pooling_layer (bool, optional): Whether to add a pooling layer. Defaults to False.
+        """
+        super().__init__(config)
+        assert (
+            config.num_hidden_layers >= config.num_embedding_layers
+        ), "num_hidden_layers should be greater than or equal to num_embedding_layers"
+        self.config = config
+        self.strategy = config.strategy
+        self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
+        # this resets the weights
+        self.post_init()
+
+    @classmethod
+    def from_pretrained(
+        cls, pretrained_checkpoint, state_dict=None, config=None, *inputs, **kwargs
+    ):
+        """Load from pre-trained."""
+        # this gets a fresh init model
+        model = cls(config, *inputs, **kwargs)
+
+        # thus we need to load the state_dict
+        state_dict = torch.load(pretrained_checkpoint)
+        # remove `model` prefix to avoid error
+        consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
+            )
+
+            logger.warning(f"the number of which is equal to {len(missing_keys)}")
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
+            )
+            logger.warning(f"the number of which is equal to {len(unexpected_keys)}")
+
+        return model
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        subset_mask: Optional[torch.Tensor] = None,
+        output_all_encoded_layers: bool = True,
+    ) -> torch.Tensor:
+
+        embedding_output = self.bert.embeddings(input_ids, token_type_ids, position_ids)
+
+        encoder_outputs_all = self.bert.encoder(
+            embedding_output,
+            attention_mask,
+            output_all_encoded_layers=output_all_encoded_layers,
+            subset_mask=subset_mask,
+        )
+
+        # encoder outputs (all encoded layers when output_all_encoded_layers=True)
+        return encoder_outputs_all
+
 class MosaicBertForSequenceClassification(BertPreTrainedModel):
     """Bert Model transformer with a sequence classification/regression head.
-
     This head is just a linear layer on top of the pooled output.
     """
 
@@ -22,48 +94,44 @@ class MosaicBertForSequenceClassification(BertPreTrainedModel):
         self.num_labels = config.num_labels
         self.config = config
         self.bert = BertModel(config, add_pooling_layer=True)
-        classifier_dropout = (config.classifier_dropout
-                              if config.classifier_dropout is not None else
-                              config.hidden_dropout_prob)
+        classifier_dropout = (
+            config.classifier_dropout
+            if config.classifier_dropout is not None
+            else config.hidden_dropout_prob
+        )
         self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
 
-        # this resets the weights
+        # this resets the weights
         self.post_init()
 
-
     @classmethod
-    def from_pretrained(cls,
-                        pretrained_checkpoint,
-                        state_dict=None,
-                        config=None,
-                        *inputs,
-                        **kwargs):
+    def from_pretrained(
+        cls, pretrained_checkpoint, state_dict=None, config=None, *inputs, **kwargs
+    ):
         """Load from pre-trained."""
         # this gets a fresh init model
         model = cls(config, *inputs, **kwargs)
-
+
         # thus we need to load the state_dict
         state_dict = torch.load(pretrained_checkpoint)
         # remove `model` prefix to avoid error
-        consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
-        missing_keys, unexpected_keys = model.load_state_dict(state_dict,
-                                                              strict=False)
+        consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
 
         if len(missing_keys) > 0:
             logger.warning(
-                f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
-
-        logger.warning(f"the number of which is equal to {len(missing_keys)}"
+                f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
             )
 
+            logger.warning(f"the number of which is equal to {len(missing_keys)}")
+
         if len(unexpected_keys) > 0:
             logger.warning(
                 f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
             )
             logger.warning(f"the number of which is equal to {len(unexpected_keys)}")
 
-
         return model
 
     def forward(
@@ -80,7 +148,9 @@ class MosaicBertForSequenceClassification(BertPreTrainedModel):
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
 
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         outputs = self.bert(
             input_ids,
@@ -104,7 +174,9 @@ class MosaicBertForSequenceClassification(BertPreTrainedModel):
         if self.config.problem_type is None:
             if self.num_labels == 1:
                 self.config.problem_type = "regression"
-            elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+            elif self.num_labels > 1 and (
+                labels.dtype == torch.long or labels.dtype == torch.int
+            ):
                 self.config.problem_type = "single_label_classification"
             else:
                 self.config.problem_type = "multi_label_classification"
@@ -129,144 +201,5 @@ class MosaicBertForSequenceClassification(BertPreTrainedModel):
             loss=loss,
             logits=logits,
             hidden_states=None,
-            attentions=None,)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-from typing import Optional
-import torch.nn as nn
-import torch
-from utils.bert_layers_mosa import BertModel
-from transformers import BertPreTrainedModel
-from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
-import logging
-
-logger = logging.getLogger(__name__)
-
-class MosaicBertForEmbeddingGeneration(BertPreTrainedModel):
-
-    def __init__(self, config, add_pooling_layer=False):
-        """
-        Initializes the BertEmbeddings class.
-
-        Args:
-            config (BertConfig): The configuration for the BERT model.
-            add_pooling_layer (bool, optional): Whether to add a pooling layer. Defaults to False.
-        """
-        super().__init__(config)
-        assert config.num_hidden_layers >= config.num_embedding_layers, 'num_hidden_layers should be greater than or equal to num_embedding_layers'
-        self.config = config
-        self.strategy = config.strategy
-        self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
-        # this resets the weights
-        self.post_init()
-
-
-    @classmethod
-    def from_pretrained(cls,
-                        pretrained_checkpoint,
-                        state_dict=None,
-                        config=None,
-                        *inputs,
-                        **kwargs):
-        """Load from pre-trained."""
-        # this gets a fresh init model
-        model = cls(config, *inputs, **kwargs)
-
-        # thus we need to load the state_dict
-        state_dict = torch.load(pretrained_checkpoint)
-        # remove `model` prefix to avoid error
-        consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
-        missing_keys, unexpected_keys = model.load_state_dict(state_dict,
-                                                              strict=False)
-
-        if len(missing_keys) > 0:
-            logger.warning(
-                f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
-
-        logger.warning(f"the number of which is equal to {len(missing_keys)}"
-            )
-
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
-            )
-            logger.warning(f"the number of which is equal to {len(unexpected_keys)}")
-
-
-        return model
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        subset_mask : Optional[torch.Tensor] = None,
-        hospital_ids_lens: list = None,
-    ) -> torch.Tensor:
-
-        embedding_output = self.bert.embeddings(input_ids, token_type_ids,
-                                                position_ids)
-
-        encoder_outputs_all = self.bert.encoder(
-            embedding_output,
-            attention_mask,
-            output_all_encoded_layers=True,
-            subset_mask=subset_mask)
-
-        # batch_size, hidden_dim
-        return self.get_embeddings(encoder_outputs_all, hospital_ids_lens, self.config.num_embedding_layers, self.config.strategy)
-
-    def get_embeddings(self, encoder_outputs_all, hospital_ids_lens, num_layers, strategy):
-
-        batch_embeddings = []
-        start_idx = 0
-
-        # num_layer (we use default = 4), batch_size (concatenated visits), seq_len (clinical note sequences), hidden_dim.
-        # average across num_layers and seq_len
-        if strategy == 'mean':
-            # batch_size (concatenated visits), hidden_dim.
-            sentence_representation = torch.stack(encoder_outputs_all[-num_layers:]).mean(dim=[0, 2])
-
-            for length in hospital_ids_lens:
-                # We then average across visits
-                # batch_size (true batch size), hidden_dim.
-                batch_embeddings.append(torch.mean(sentence_representation[start_idx:start_idx + length],dim=0))
-                start_idx += length
-
-            return torch.stack(batch_embeddings)
-
-        elif strategy == 'concat':
-            # num_layer, batch_size (concatenated visits), hidden_dim.
-            sentence_representation = torch.stack(encoder_outputs_all[-num_layers:]).mean(dim=2)
-
-            for length in hospital_ids_lens:
-                # We then average across visits
-                # num_layer, batch_size (true batch size), hidden_dim.
-                batch_embeddings.append(torch.mean(sentence_representation[:,start_idx:start_idx + length],dim=1))
-                start_idx += length
-
-            return torch.stack(batch_embeddings)
-
-        elif strategy == 'all':
-            # num_layer, batch_size (concatenated visits), seq_len (clinical note sequences), hidden_dim.
-            sentence_representation = torch.stack(encoder_outputs_all[-num_layers:])
-            return sentence_representation
-        else:
-            raise ValueError(f'{strategy} is not a valid strategy, choose between mean and concat')
-
-
-
-
+            attentions=None,
+        )
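
A minimal usage sketch for the MosaicBertForEmbeddingGeneration class this commit adds. It is not part of the commit: the checkpoint path, the bert-base-uncased config/tokenizer, and the num_embedding_layers/strategy values are assumptions for illustration, and it presumes automodel.py and the custom bert_layers_mosa module are importable.

# Hypothetical usage sketch -- not part of the commit. Paths, config values,
# and the tokenizer choice are assumptions for illustration only.
import torch
from transformers import AutoTokenizer, BertConfig

from automodel import MosaicBertForEmbeddingGeneration  # this file

config = BertConfig.from_pretrained("bert-base-uncased")  # assumed base config
config.num_embedding_layers = 4  # custom field read by __init__ (assumed value)
config.strategy = "mean"         # custom field read by __init__ (assumed value)

# from_pretrained takes a raw torch checkpoint path plus a config, per the class above
model = MosaicBertForEmbeddingGeneration.from_pretrained(
    "path/to/checkpoint.pt",  # hypothetical checkpoint path
    config=config,
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed tokenizer
batch = tokenizer(["example clinical note"], return_tensors="pt", padding=True)

with torch.no_grad():
    # after this commit, forward returns the encoder outputs for all encoded layers
    # rather than the pooled per-visit embeddings of the removed get_embeddings helper
    encoder_outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
    )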