thak123 commited on
Commit
e1cfebe
1 Parent(s): e5ffa90

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +36 -0
model.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import config
2
+ import transformers
3
+ import torch.nn as nn
4
+
5
+
6
+ class BERTBaseUncased(nn.Module):
7
+ def __init__(self):
8
+ super(BERTBaseUncased, self).__init__()
9
+ self.bert = transformers.BertModel.from_pretrained(config.BERT_PATH)
10
+
11
+ self.bert_drop = nn.Dropout(0.3)
12
+
13
+ self.out = nn.Linear(768, 3)
14
+ # self.out = nn.Linear(256, 3)
15
+
16
+ nn.init.xavier_uniform_(self.out.weight)
17
+
18
+ def forward(self, ids, mask, token_type_ids):
19
+ _, o2 = self.bert(
20
+ ids,
21
+ attention_mask=mask,
22
+ token_type_ids=token_type_ids
23
+ )
24
+ bo = self.bert_drop(o2)
25
+ # bo = self.tanh(self.fc(bo)) # to be commented if original
26
+ output = self.out(bo)
27
+ return output
28
+
29
+ def extract_features(self, ids, mask, token_type_ids):
30
+ _, o2 = self.bert(
31
+ ids,
32
+ attention_mask=mask,
33
+ token_type_ids=token_type_ids
34
+ )
35
+ bo = self.bert_drop(o2)
36
+ return bo