BathSalt-1 commited on
Commit
373d4b0
·
verified ·
1 Parent(s): 528f8d7

Create daedalus_mobile.py

Browse files
Files changed (1) hide show
  1. daedalus_mobile.py +34 -0
daedalus_mobile.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
5
+
6
+ class DaedalusMobile(nn.Module):
7
+ def __init__(self, config):
8
+ super(DaedalusMobile, self).__init__()
9
+ self.config = config
10
+ self.encoder = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
11
+ self.decoder = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
12
+ self.dropout = nn.Dropout(config.dropout)
13
+
14
+ def forward(self, input_ids, attention_mask):
15
+ encoder_output = self.encoder(input_ids, attention_mask)
16
+ decoder_output = self.decoder(encoder_output.last_hidden_state, attention_mask)
17
+ output = self.dropout(decoder_output.last_hidden_state)
18
+ return output
19
+
20
+ def configure_optimizers(self):
21
+ optimizer = optim.Adam(self.parameters(), lr=self.config.lr)
22
+ return optimizer
23
+
24
+ def train_step(self, batch):
25
+ input_ids, attention_mask, labels = batch
26
+ output = self(input_ids, attention_mask)
27
+ loss = nn.CrossEntropyLoss()(output, labels)
28
+ return loss
29
+
30
+ def eval_step(self, batch):
31
+ input_ids, attention_mask, labels = batch
32
+ output = self(input_ids, attention_mask)
33
+ loss = nn.CrossEntropyLoss()(output, labels)
34
+ return loss