import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class DaedalusMobile(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # A single seq2seq checkpoint already contains both the encoder and
        # the decoder; loading 't5-small' twice would duplicate ~60M
        # parameters, and the two copies would never be wired together.
        self.model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
        # Extra dropout on top of the LM head logits (T5 also applies its
        # own internal dropout, controlled by the checkpoint config).
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids, attention_mask, labels):
        # Passing `labels` lets the HF model shift them right to build the
        # decoder inputs (teacher forcing), so the returned logits line up
        # position-for-position with `labels`.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return self.dropout(outputs.logits)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.config.lr)

    def train_step(self, batch):
        input_ids, attention_mask, labels = batch
        logits = self(input_ids, attention_mask, labels)
        # CrossEntropyLoss expects (N, C) logits and (N,) targets, so flatten
        # the batch and sequence dimensions; ignore_index=-100 masks padding
        # positions out of the loss.
        loss = nn.CrossEntropyLoss(ignore_index=-100)(
            logits.view(-1, logits.size(-1)), labels.view(-1)
        )
        return loss

    def eval_step(self, batch):
        # Same loss as training, but without building a computation graph.
        with torch.no_grad():
            return self.train_step(batch)
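
# Minimal usage sketch. Assumptions not in the original module: `config` is a
# plain namespace carrying the `dropout` and `lr` fields the class reads, and
# the example sentences / batch construction are illustrative only.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(dropout=0.1, lr=1e-4)
    model = DaedalusMobile(config)
    tokenizer = AutoTokenizer.from_pretrained('t5-small')

    # T5 checkpoints were trained with task prefixes, so the source text
    # states the task explicitly.
    source = tokenizer(
        ["translate English to German: The house is wonderful."],
        return_tensors="pt",
        padding=True,
    )
    target = tokenizer(
        ["Das Haus ist wunderbar."],
        return_tensors="pt",
        padding=True,
    )
    labels = target.input_ids
    labels[labels == tokenizer.pad_token_id] = -100  # exclude pad from loss

    # One optimization step over the single example batch.
    optimizer = model.configure_optimizers()
    loss = model.train_step((source.input_ids, source.attention_mask, labels))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"train loss: {loss.item():.4f}")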