# daedalus_mobile / daedalus_mobile.py
# Uploaded by BathSalt-1 ("Create daedalus_mobile.py", commit 373d4b0, verified)
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
class DaedalusMobile(nn.Module):
    """Thin wrapper around a pretrained t5-small seq2seq model.

    The previous version loaded *two* full seq2seq checkpoints and fed the
    first model's float hidden states into the second model's ``input_ids``
    slot (T5 expects integer token ids), so ``forward`` raised at runtime.
    It also applied ``CrossEntropyLoss`` to a 3-D logits tensor without
    flattening. Both defects are fixed here; the public interface
    (class name, method names, two-argument ``forward`` call) is preserved.

    Args:
        config: object exposing ``dropout`` (float) and ``lr`` (float)
            attributes. Exact type is not visible from this file —
            presumably a small namespace/dataclass; confirm against callers.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # One checkpoint supplies both halves; loading two full t5-small
        # models (as before) doubled memory and left the second mis-wired.
        self.model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
        # Backward-compatible aliases for the former two-model attributes.
        self.encoder = self.model.get_encoder()
        self.decoder = self.model.get_decoder()
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids, attention_mask, labels=None):
        """Return dropout-regularised decoder logits.

        Args:
            input_ids: (batch, src_len) integer source token ids.
            attention_mask: (batch, src_len) mask for ``input_ids``.
            labels: optional (batch, tgt_len) target token ids. When given,
                the model derives shifted decoder inputs from them
                (standard T5 teacher forcing). When omitted, the source ids
                are reused as decoder inputs so the original two-argument
                call keeps working.

        Returns:
            (batch, tgt_len, vocab_size) logits tensor after dropout.
        """
        if labels is not None:
            outputs = self.model(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 labels=labels)
        else:
            outputs = self.model(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 decoder_input_ids=input_ids)
        return self.dropout(outputs.logits)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``config.lr``."""
        return optim.Adam(self.parameters(), lr=self.config.lr)

    def _step(self, batch):
        """Shared train/eval logic: forward pass + flattened CE loss.

        ``batch`` is an (input_ids, attention_mask, labels) triple.
        """
        input_ids, attention_mask, labels = batch
        logits = self(input_ids, attention_mask, labels=labels)
        # CrossEntropyLoss wants (N, C) vs (N,); -100 is the Hugging Face
        # padding sentinel so masked label positions do not contribute.
        loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
        return loss_fn(logits.reshape(-1, logits.size(-1)),
                       labels.reshape(-1))

    def train_step(self, batch):
        """Compute the training loss for one batch."""
        return self._step(batch)

    def eval_step(self, batch):
        """Compute the evaluation loss; gradients are disabled."""
        with torch.no_grad():
            return self._step(batch)