add missing parts
Browse files- automodel.py +19 -0
automodel.py
CHANGED
@@ -5,6 +5,7 @@ import torch
|
|
5 |
import torch.nn as nn
|
6 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
7 |
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
|
|
8 |
from transformers import BertPreTrainedModel
|
9 |
from transformers.modeling_outputs import (MaskedLMOutput,
|
10 |
SequenceClassifierOutput)
|
@@ -233,6 +234,24 @@ class ClinicalMosaicForSequenceClassification(BertPreTrainedModel):
|
|
233 |
attentions=None,
|
234 |
)
|
235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
class BertLMPredictionHead(nn.Module):
|
237 |
|
238 |
def __init__(self, config, bert_model_embedding_weights):
|
|
|
5 |
import torch.nn as nn
|
6 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
7 |
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
8 |
+
from transformers.activations import ACT2FN
|
9 |
from transformers import BertPreTrainedModel
|
10 |
from transformers.modeling_outputs import (MaskedLMOutput,
|
11 |
SequenceClassifierOutput)
|
|
|
234 |
attentions=None,
|
235 |
)
|
236 |
|
237 |
+
|
238 |
+
class BertPredictionHeadTransform(nn.Module):
|
239 |
+
|
240 |
+
def __init__(self, config):
|
241 |
+
super().__init__()
|
242 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
243 |
+
if isinstance(config.hidden_act, str):
|
244 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
245 |
+
else:
|
246 |
+
self.transform_act_fn = config.hidden_act
|
247 |
+
self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)
|
248 |
+
|
249 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
250 |
+
hidden_states = self.dense(hidden_states)
|
251 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
252 |
+
hidden_states = self.LayerNorm(hidden_states)
|
253 |
+
return hidden_states
|
254 |
+
|
255 |
class BertLMPredictionHead(nn.Module):
|
256 |
|
257 |
def __init__(self, config, bert_model_embedding_weights):
|