import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
class DaedalusMobile(nn.Module):
    def __init__(self, config):
        super(DaedalusMobile, self).__init__()
        self.config = config
        # t5-small is already a full encoder-decoder model, so one instance
        # provides both halves; the previous two-copy setup fed one seq2seq
        # model's output into another, which does not run.
        self.model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids, attention_mask, labels=None):
        # Passing labels lets T5 build its decoder inputs internally and
        # return per-token logits over the vocabulary.
        output = self.model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=labels)
        return self.dropout(output.logits)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.config.lr)
        return optimizer

    def train_step(self, batch):
        input_ids, attention_mask, labels = batch
        logits = self(input_ids, attention_mask, labels)
        # CrossEntropyLoss expects (N, C) logits and (N,) targets, so flatten
        # the batch and sequence dimensions; -100 marks ignored positions.
        loss = nn.CrossEntropyLoss(ignore_index=-100)(
            logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss

    def eval_step(self, batch):
        # Evaluation mirrors the training step; the caller is expected to
        # wrap it in model.eval() and torch.no_grad().
        input_ids, attention_mask, labels = batch
        logits = self(input_ids, attention_mask, labels)
        loss = nn.CrossEntropyLoss(ignore_index=-100)(
            logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss