surkovvv commited on
Commit
6d2f9da
1 Parent(s): ddd86a3

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +21 -0
model.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DistilBertModel
2
+ import torch
3
+
4
+
5
+ class DistillBERTClass(torch.nn.Module):
6
+ def __init__(self):
7
+ super(DistillBERTClass, self).__init__()
8
+ self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
9
+ self.pre_classifier = torch.nn.Linear(768, 512)
10
+ self.dropout = torch.nn.Dropout(0.3)
11
+ self.classifier = torch.nn.Linear(512, 126)
12
+
13
+ def forward(self, input_ids, attention_mask):
14
+ output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
15
+ hidden_state = output_1[0]
16
+ pooler = hidden_state[:, 0]
17
+ pooler = self.pre_classifier(pooler)
18
+ pooler = torch.nn.ReLU()(pooler)
19
+ pooler = self.dropout(pooler)
20
+ output = self.classifier(pooler)
21
+ return output