thak123 commited on
Commit
babc7da
1 Parent(s): 241abc0

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +8 -9
model.py CHANGED
@@ -8,23 +8,22 @@ class BERTBaseUncased(nn.Module):
8
  super(BERTBaseUncased, self).__init__()
9
  self.bert = transformers.BertModel.from_pretrained(config.BERT_PATH)
10
 
11
- # self.bert_drop = nn.Dropout(0.3)
12
 
13
- # self.out = nn.Linear(768, 3)
14
- # # self.out = nn.Linear(256, 3)
15
 
16
- # nn.init.xavier_uniform_(self.out.weight)
17
- print("Model")
18
 
19
  def forward(self, ids, mask, token_type_ids):
20
- output = self.bert(
21
  ids,
22
  attention_mask=mask,
23
  token_type_ids=token_type_ids
24
  )
25
- # bo = self.bert_drop(o2)
26
- # # bo = self.tanh(self.fc(bo)) # to be commented if original
27
- # output = self.out(bo)
28
  return output
29
 
30
  def extract_features(self, ids, mask, token_type_ids):
 
8
  super(BERTBaseUncased, self).__init__()
9
  self.bert = transformers.BertModel.from_pretrained(config.BERT_PATH)
10
 
11
+ self.bert_drop = nn.Dropout(0.3)
12
 
13
+ self.out = nn.Linear(768, 3)
14
+ # self.out = nn.Linear(256, 3)
15
 
16
+ nn.init.xavier_uniform_(self.out.weight)
 
17
 
18
  def forward(self, ids, mask, token_type_ids):
19
+ _, o2 = self.bert(
20
  ids,
21
  attention_mask=mask,
22
  token_type_ids=token_type_ids
23
  )
24
+ bo = self.bert_drop(o2)
25
+ # bo = self.tanh(self.fc(bo)) # to be commented if original
26
+ output = self.out(bo)
27
  return output
28
 
29
  def extract_features(self, ids, mask, token_type_ids):