import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class DaedalusMobile(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # A single T5 checkpoint already contains both the encoder and the
        # decoder, so the seq2seq model is loaded once rather than twice.
        self.model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids, attention_mask, labels):
        # Passing labels lets the model build the shifted decoder inputs
        # internally; the returned logits have shape (batch, seq_len, vocab).
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        logits = self.dropout(output.logits)
        return logits

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.config.lr)
        return optimizer
    def train_step(self, batch):
        input_ids, attention_mask, labels = batch
        logits = self(input_ids, attention_mask, labels)
        # Flatten the (batch, seq_len, vocab) logits and (batch, seq_len)
        # labels for token-level cross-entropy; -100 marks padded positions.
        loss = nn.CrossEntropyLoss(ignore_index=-100)(
            logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
        )
        return loss

    def eval_step(self, batch):
        input_ids, attention_mask, labels = batch
        # Evaluation computes the same loss but without gradient tracking.
        with torch.no_grad():
            logits = self(input_ids, attention_mask, labels)
            loss = nn.CrossEntropyLoss(ignore_index=-100)(
                logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
            )
        return loss
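

# A minimal usage sketch, assuming `config` only needs the `dropout` and `lr`
# fields referenced above and that a batch carries (input_ids, attention_mask,
# labels) tensors produced by the matching T5 tokenizer.
if __name__ == '__main__':
    from types import SimpleNamespace

    config = SimpleNamespace(dropout=0.1, lr=1e-4)
    model = DaedalusMobile(config)
    optimizer = model.configure_optimizers()
    tokenizer = AutoTokenizer.from_pretrained('t5-small')

    # Encode one source/target pair; padded label positions are set to -100
    # so they are ignored by the cross-entropy loss.
    enc = tokenizer(['translate English to German: Hello world'],
                    return_tensors='pt', padding=True)
    tgt = tokenizer(['Hallo Welt'], return_tensors='pt', padding=True)
    labels = tgt.input_ids.masked_fill(tgt.input_ids == tokenizer.pad_token_id, -100)

    batch = (enc.input_ids, enc.attention_mask, labels)
    loss = model.train_step(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f'training loss: {loss.item():.4f}')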