File size: 1,294 Bytes
373d4b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

class DaedalusMobile(nn.Module):
    """Pretrained encoder-decoder (T5 by default) with simple train/eval step helpers.

    ``config`` must provide ``dropout`` (float, dropout probability) and
    ``lr`` (float, Adam learning rate). It may optionally provide
    ``model_name`` (a Hugging Face checkpoint id, default ``'t5-small'``).

    ``train_step`` / ``eval_step`` take batches of the form
    ``(input_ids, attention_mask, labels)``. Label positions equal to -100
    (the Hugging Face padding convention) are ignored by the loss.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        model_name = getattr(config, "model_name", "t5-small")
        # One checkpoint supplies both halves. Its encoder and decoder were
        # pretrained together, whereas two separate full models would double
        # the parameter count and pair an encoder with a foreign decoder.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids, attention_mask, labels=None, decoder_input_ids=None):
        """Run encoder -> dropout -> decoder and return vocabulary logits.

        Args:
            input_ids: source token ids (presumably ``(batch, src_len)``).
            attention_mask: padding mask for ``input_ids`` (Hugging Face
                convention: 1 for real tokens, 0 for padding).
            labels: optional target token ids. When given and
                ``decoder_input_ids`` is not, the underlying model shifts
                them right to build the teacher-forcing decoder inputs.
            decoder_input_ids: optional explicit decoder input ids.

        Returns:
            Decoder logits over the vocabulary,
            ``(batch, tgt_len, vocab_size)``.

        Raises:
            ValueError: if neither ``labels`` nor ``decoder_input_ids`` is
                given, since the decoder would have no input. Use
                ``self.model.generate(...)`` for free-running inference.
        """
        if labels is None and decoder_input_ids is None:
            raise ValueError(
                "forward() needs `labels` or `decoder_input_ids` to run the decoder; "
                "use `self.model.generate(...)` for inference"
            )
        encoder_outputs = self.model.get_encoder()(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # Regularise the encoder representation before the decoder
        # cross-attends to it.
        encoder_hidden = self.dropout(encoder_outputs.last_hidden_state)
        outputs = self.model(
            attention_mask=attention_mask,
            encoder_outputs=(encoder_hidden,),
            decoder_input_ids=decoder_input_ids,
            labels=labels,
            return_dict=True,
        )
        return outputs.logits

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr ``config.lr``."""
        return optim.Adam(self.parameters(), lr=self.config.lr)

    def train_step(self, batch):
        """Return the training loss (with autograd graph) for one batch."""
        return self._compute_loss(batch)

    def eval_step(self, batch):
        """Return the loss for one batch without building an autograd graph.

        Dropout still follows ``self.training``. Call ``self.eval()`` first
        for deterministic evaluation.
        """
        with torch.no_grad():
            return self._compute_loss(batch)

    def _compute_loss(self, batch):
        """Token-level cross-entropy between decoder logits and ``labels``."""
        input_ids, attention_mask, labels = batch
        logits = self(input_ids, attention_mask, labels=labels)
        # CrossEntropyLoss expects (N, C) scores and (N,) targets, so fold
        # the batch and sequence dimensions together. -100 marks padding.
        loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
        return loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))