Update automodel.py
Browse files- automodel.py +101 -168
automodel.py
CHANGED
@@ -1,19 +1,91 @@
|
|
|
|
1 |
from typing import Optional, Tuple, Union
|
2 |
-
|
3 |
import torch
|
4 |
-
|
|
|
|
|
5 |
from transformers import BertPreTrainedModel
|
6 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
7 |
-
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
8 |
-
import logging
|
9 |
-
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
10 |
|
|
|
11 |
|
12 |
logger = logging.getLogger(__name__)
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
class MosaicBertForSequenceClassification(BertPreTrainedModel):
|
15 |
"""Bert Model transformer with a sequence classification/regression head.
|
16 |
-
|
17 |
This head is just a linear layer on top of the pooled output.
|
18 |
"""
|
19 |
|
@@ -22,48 +94,44 @@ class MosaicBertForSequenceClassification(BertPreTrainedModel):
|
|
22 |
self.num_labels = config.num_labels
|
23 |
self.config = config
|
24 |
self.bert = BertModel(config, add_pooling_layer=True)
|
25 |
-
classifier_dropout = (
|
26 |
-
|
27 |
-
|
|
|
|
|
28 |
self.dropout = nn.Dropout(classifier_dropout)
|
29 |
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
30 |
|
31 |
-
|
32 |
self.post_init()
|
33 |
|
34 |
-
|
35 |
@classmethod
|
36 |
-
def from_pretrained(
|
37 |
-
|
38 |
-
|
39 |
-
config=None,
|
40 |
-
*inputs,
|
41 |
-
**kwargs):
|
42 |
"""Load from pre-trained."""
|
43 |
# this gets a fresh init model
|
44 |
model = cls(config, *inputs, **kwargs)
|
45 |
-
|
46 |
# thus we need to load the state_dict
|
47 |
state_dict = torch.load(pretrained_checkpoint)
|
48 |
# remove `model` prefix to avoid error
|
49 |
-
consume_prefix_in_state_dict_if_present(state_dict, prefix=
|
50 |
-
missing_keys, unexpected_keys = model.load_state_dict(state_dict,
|
51 |
-
strict=False)
|
52 |
|
53 |
if len(missing_keys) > 0:
|
54 |
logger.warning(
|
55 |
-
f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
|
56 |
-
|
57 |
-
logger.warning(f"the number of which is equal to {len(missing_keys)}"
|
58 |
)
|
59 |
|
|
|
|
|
60 |
if len(unexpected_keys) > 0:
|
61 |
logger.warning(
|
62 |
f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
|
63 |
)
|
64 |
logger.warning(f"the number of which is equal to {len(unexpected_keys)}")
|
65 |
|
66 |
-
|
67 |
return model
|
68 |
|
69 |
def forward(
|
@@ -80,7 +148,9 @@ class MosaicBertForSequenceClassification(BertPreTrainedModel):
|
|
80 |
return_dict: Optional[bool] = None,
|
81 |
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
82 |
|
83 |
-
return_dict =
|
|
|
|
|
84 |
|
85 |
outputs = self.bert(
|
86 |
input_ids,
|
@@ -104,7 +174,9 @@ class MosaicBertForSequenceClassification(BertPreTrainedModel):
|
|
104 |
if self.config.problem_type is None:
|
105 |
if self.num_labels == 1:
|
106 |
self.config.problem_type = "regression"
|
107 |
-
elif self.num_labels > 1 and (
|
|
|
|
|
108 |
self.config.problem_type = "single_label_classification"
|
109 |
else:
|
110 |
self.config.problem_type = "multi_label_classification"
|
@@ -129,144 +201,5 @@ class MosaicBertForSequenceClassification(BertPreTrainedModel):
|
|
129 |
loss=loss,
|
130 |
logits=logits,
|
131 |
hidden_states=None,
|
132 |
-
attentions=None,
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
from typing import Optional
|
148 |
-
import torch.nn as nn
|
149 |
-
import torch
|
150 |
-
from utils.bert_layers_mosa import BertModel
|
151 |
-
from transformers import BertPreTrainedModel
|
152 |
-
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
153 |
-
import logging
|
154 |
-
|
155 |
-
logger = logging.getLogger(__name__)
|
156 |
-
|
157 |
-
class MosaicBertForEmbeddingGeneration(BertPreTrainedModel):
|
158 |
-
|
159 |
-
def __init__(self, config, add_pooling_layer=False):
|
160 |
-
"""
|
161 |
-
Initializes the BertEmbeddings class.
|
162 |
-
|
163 |
-
Args:
|
164 |
-
config (BertConfig): The configuration for the BERT model.
|
165 |
-
add_pooling_layer (bool, optional): Whether to add a pooling layer. Defaults to False.
|
166 |
-
"""
|
167 |
-
super().__init__(config)
|
168 |
-
assert config.num_hidden_layers >= config.num_embedding_layers, 'num_hidden_layers should be greater than or equal to num_embedding_layers'
|
169 |
-
self.config = config
|
170 |
-
self.strategy = config.strategy
|
171 |
-
self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
|
172 |
-
# this resets the weights
|
173 |
-
self.post_init()
|
174 |
-
|
175 |
-
|
176 |
-
@classmethod
|
177 |
-
def from_pretrained(cls,
|
178 |
-
pretrained_checkpoint,
|
179 |
-
state_dict=None,
|
180 |
-
config=None,
|
181 |
-
*inputs,
|
182 |
-
**kwargs):
|
183 |
-
"""Load from pre-trained."""
|
184 |
-
# this gets a fresh init model
|
185 |
-
model = cls(config, *inputs, **kwargs)
|
186 |
-
|
187 |
-
# thus we need to load the state_dict
|
188 |
-
state_dict = torch.load(pretrained_checkpoint)
|
189 |
-
# remove `model` prefix to avoid error
|
190 |
-
consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
|
191 |
-
missing_keys, unexpected_keys = model.load_state_dict(state_dict,
|
192 |
-
strict=False)
|
193 |
-
|
194 |
-
if len(missing_keys) > 0:
|
195 |
-
logger.warning(
|
196 |
-
f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}")
|
197 |
-
|
198 |
-
logger.warning(f"the number of which is equal to {len(missing_keys)}"
|
199 |
-
)
|
200 |
-
|
201 |
-
if len(unexpected_keys) > 0:
|
202 |
-
logger.warning(
|
203 |
-
f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
|
204 |
-
)
|
205 |
-
logger.warning(f"the number of which is equal to {len(unexpected_keys)}")
|
206 |
-
|
207 |
-
|
208 |
-
return model
|
209 |
-
|
210 |
-
def forward(
|
211 |
-
self,
|
212 |
-
input_ids: Optional[torch.Tensor] = None,
|
213 |
-
attention_mask: Optional[torch.Tensor] = None,
|
214 |
-
token_type_ids: Optional[torch.Tensor] = None,
|
215 |
-
position_ids: Optional[torch.Tensor] = None,
|
216 |
-
subset_mask : Optional[torch.Tensor] = None,
|
217 |
-
hospital_ids_lens: list = None,
|
218 |
-
) -> torch.Tensor:
|
219 |
-
|
220 |
-
embedding_output = self.bert.embeddings(input_ids, token_type_ids,
|
221 |
-
position_ids)
|
222 |
-
|
223 |
-
encoder_outputs_all = self.bert.encoder(
|
224 |
-
embedding_output,
|
225 |
-
attention_mask,
|
226 |
-
output_all_encoded_layers=True,
|
227 |
-
subset_mask=subset_mask)
|
228 |
-
|
229 |
-
# batch_size, hidden_dim
|
230 |
-
return self.get_embeddings(encoder_outputs_all, hospital_ids_lens, self.config.num_embedding_layers, self.config.strategy)
|
231 |
-
|
232 |
-
def get_embeddings(self, encoder_outputs_all, hospital_ids_lens, num_layers, strategy):
|
233 |
-
|
234 |
-
batch_embeddings = []
|
235 |
-
start_idx = 0
|
236 |
-
|
237 |
-
# num_layer (we use default = 4), batch_size (concatenated visits), seq_len (clinical note sequences), hidden_dim.
|
238 |
-
# average across num_layers and seq_len
|
239 |
-
if strategy == 'mean':
|
240 |
-
# batch_size (concatenated visits), hidden_dim.
|
241 |
-
sentence_representation = torch.stack(encoder_outputs_all[-num_layers:]).mean(dim=[0, 2])
|
242 |
-
|
243 |
-
for length in hospital_ids_lens:
|
244 |
-
# We then average across visits
|
245 |
-
# batch_size (true batch size), hidden_dim.
|
246 |
-
batch_embeddings.append(torch.mean(sentence_representation[start_idx:start_idx + length],dim=0))
|
247 |
-
start_idx += length
|
248 |
-
|
249 |
-
return torch.stack(batch_embeddings)
|
250 |
-
|
251 |
-
elif strategy == 'concat':
|
252 |
-
# num_layer, batch_size (concatenated visits), hidden_dim.
|
253 |
-
sentence_representation = torch.stack(encoder_outputs_all[-num_layers:]).mean(dim=2)
|
254 |
-
|
255 |
-
for length in hospital_ids_lens:
|
256 |
-
# We then average across visits
|
257 |
-
# num_layer, batch_size (true batch size), hidden_dim.
|
258 |
-
batch_embeddings.append(torch.mean(sentence_representation[:,start_idx:start_idx + length],dim=1))
|
259 |
-
start_idx += length
|
260 |
-
|
261 |
-
return torch.stack(batch_embeddings)
|
262 |
-
|
263 |
-
elif strategy == 'all':
|
264 |
-
# num_layer, batch_size (concatenated visits), seq_len (clinical note sequences), hidden_dim.
|
265 |
-
sentence_representation = torch.stack(encoder_outputs_all[-num_layers:])
|
266 |
-
return sentence_representation
|
267 |
-
else:
|
268 |
-
raise ValueError(f'{strategy} is not a valid strategy, choose between mean and concat')
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
|
|
1 |
+
import logging
|
2 |
from typing import Optional, Tuple, Union
|
3 |
+
|
4 |
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
7 |
+
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
8 |
from transformers import BertPreTrainedModel
|
9 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
|
|
|
|
|
|
10 |
|
11 |
+
from bert_layers_mosa import BertModel
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
15 |
+
|
16 |
+
class MosaicBertForEmbeddingGeneration(BertPreTrainedModel):
|
17 |
+
|
18 |
+
def __init__(self, config, add_pooling_layer=False):
|
19 |
+
"""
|
20 |
+
Initializes the BertEmbeddings class.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
config (BertConfig): The configuration for the BERT model.
|
24 |
+
add_pooling_layer (bool, optional): Whether to add a pooling layer. Defaults to False.
|
25 |
+
"""
|
26 |
+
super().__init__(config)
|
27 |
+
assert (
|
28 |
+
config.num_hidden_layers >= config.num_embedding_layers
|
29 |
+
), "num_hidden_layers should be greater than or equal to num_embedding_layers"
|
30 |
+
self.config = config
|
31 |
+
self.strategy = config.strategy
|
32 |
+
self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
|
33 |
+
# this resets the weights
|
34 |
+
self.post_init()
|
35 |
+
|
36 |
+
@classmethod
|
37 |
+
def from_pretrained(
|
38 |
+
cls, pretrained_checkpoint, state_dict=None, config=None, *inputs, **kwargs
|
39 |
+
):
|
40 |
+
"""Load from pre-trained."""
|
41 |
+
# this gets a fresh init model
|
42 |
+
model = cls(config, *inputs, **kwargs)
|
43 |
+
|
44 |
+
# thus we need to load the state_dict
|
45 |
+
state_dict = torch.load(pretrained_checkpoint)
|
46 |
+
# remove `model` prefix to avoid error
|
47 |
+
consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
|
48 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
49 |
+
|
50 |
+
if len(missing_keys) > 0:
|
51 |
+
logger.warning(
|
52 |
+
f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
|
53 |
+
)
|
54 |
+
|
55 |
+
logger.warning(f"the number of which is equal to {len(missing_keys)}")
|
56 |
+
|
57 |
+
if len(unexpected_keys) > 0:
|
58 |
+
logger.warning(
|
59 |
+
f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
|
60 |
+
)
|
61 |
+
logger.warning(f"the number of which is equal to {len(unexpected_keys)}")
|
62 |
+
|
63 |
+
return model
|
64 |
+
|
65 |
+
def forward(
|
66 |
+
self,
|
67 |
+
input_ids: Optional[torch.Tensor] = None,
|
68 |
+
attention_mask: Optional[torch.Tensor] = None,
|
69 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
70 |
+
position_ids: Optional[torch.Tensor] = None,
|
71 |
+
subset_mask: Optional[torch.Tensor] = None,
|
72 |
+
output_all_encoded_layers: Book = True,
|
73 |
+
) -> torch.Tensor:
|
74 |
+
|
75 |
+
embedding_output = self.bert.embeddings(input_ids, token_type_ids, position_ids)
|
76 |
+
|
77 |
+
encoder_outputs_all = self.bert.encoder(
|
78 |
+
embedding_output,
|
79 |
+
attention_mask,
|
80 |
+
output_all_encoded_layers=output_all_encoded_layers,
|
81 |
+
subset_mask=subset_mask,
|
82 |
+
)
|
83 |
+
|
84 |
+
# batch_size, hidden_dim
|
85 |
+
return encoder_outputs_all
|
86 |
+
|
87 |
class MosaicBertForSequenceClassification(BertPreTrainedModel):
|
88 |
"""Bert Model transformer with a sequence classification/regression head.
|
|
|
89 |
This head is just a linear layer on top of the pooled output.
|
90 |
"""
|
91 |
|
|
|
94 |
self.num_labels = config.num_labels
|
95 |
self.config = config
|
96 |
self.bert = BertModel(config, add_pooling_layer=True)
|
97 |
+
classifier_dropout = (
|
98 |
+
config.classifier_dropout
|
99 |
+
if config.classifier_dropout is not None
|
100 |
+
else config.hidden_dropout_prob
|
101 |
+
)
|
102 |
self.dropout = nn.Dropout(classifier_dropout)
|
103 |
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
104 |
|
105 |
+
# this resets the weights
|
106 |
self.post_init()
|
107 |
|
|
|
108 |
@classmethod
|
109 |
+
def from_pretrained(
|
110 |
+
cls, pretrained_checkpoint, state_dict=None, config=None, *inputs, **kwargs
|
111 |
+
):
|
|
|
|
|
|
|
112 |
"""Load from pre-trained."""
|
113 |
# this gets a fresh init model
|
114 |
model = cls(config, *inputs, **kwargs)
|
115 |
+
|
116 |
# thus we need to load the state_dict
|
117 |
state_dict = torch.load(pretrained_checkpoint)
|
118 |
# remove `model` prefix to avoid error
|
119 |
+
consume_prefix_in_state_dict_if_present(state_dict, prefix="model.")
|
120 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
|
|
121 |
|
122 |
if len(missing_keys) > 0:
|
123 |
logger.warning(
|
124 |
+
f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
|
|
|
|
|
125 |
)
|
126 |
|
127 |
+
logger.warning(f"the number of which is equal to {len(missing_keys)}")
|
128 |
+
|
129 |
if len(unexpected_keys) > 0:
|
130 |
logger.warning(
|
131 |
f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}",
|
132 |
)
|
133 |
logger.warning(f"the number of which is equal to {len(unexpected_keys)}")
|
134 |
|
|
|
135 |
return model
|
136 |
|
137 |
def forward(
|
|
|
148 |
return_dict: Optional[bool] = None,
|
149 |
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
150 |
|
151 |
+
return_dict = (
|
152 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
153 |
+
)
|
154 |
|
155 |
outputs = self.bert(
|
156 |
input_ids,
|
|
|
174 |
if self.config.problem_type is None:
|
175 |
if self.num_labels == 1:
|
176 |
self.config.problem_type = "regression"
|
177 |
+
elif self.num_labels > 1 and (
|
178 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
179 |
+
):
|
180 |
self.config.problem_type = "single_label_classification"
|
181 |
else:
|
182 |
self.config.problem_type = "multi_label_classification"
|
|
|
201 |
loss=loss,
|
202 |
logits=logits,
|
203 |
hidden_states=None,
|
204 |
+
attentions=None,
|
205 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|