import torch

from daedalus_mobile import DaedalusMobile
from tokenizer import DaedalusTokenizer
from config import config


def train(model, device, train_loader, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module exposing ``train()`` and ``train_step(batch) -> loss``
            (a scalar tensor supporting ``backward()`` and ``item()``).
        device: ``torch.device`` the batch tensors are moved to.
        train_loader: iterable yielding ``(input_ids, attention_mask, labels)``
            tuples of tensors.
        optimizer: torch optimizer stepping the model's parameters.

    Returns:
        float: total loss divided by the number of batches.
    """
    model.train()
    total_loss = 0.0  # float from the start; avoids int/float mixing
    for batch in train_loader:
        input_ids, attention_mask, labels = (t.to(device) for t in batch)
        optimizer.zero_grad()
        loss = model.train_step((input_ids, attention_mask, labels))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # NOTE(review): assumes train_loader is non-empty — an empty loader would
    # divide by zero here. Confirm upstream dataset is never empty.
    return total_loss / len(train_loader)


def main(train_dataset=None):
    """Build the model/optimizer and run ``config.epochs`` training epochs.

    Args:
        train_dataset: dataset passed to ``torch.utils.data.DataLoader``.
            BUG FIX: the original script referenced an undefined global
            ``train_dataset`` (guaranteed ``NameError`` at runtime); it is now
            an explicit parameter, and omitting it raises a clear error.

    Raises:
        ValueError: if no ``train_dataset`` is supplied.
    """
    if train_dataset is None:
        raise ValueError(
            "main() requires a train_dataset; the original script referenced "
            "an undefined global 'train_dataset'"
        )
    device = torch.device(config.device)
    model = DaedalusMobile(config)
    model.to(device)
    # NOTE(review): tokenizer is constructed but never used below — presumably
    # needed for side effects or by the dataset pipeline; kept for parity.
    tokenizer = DaedalusTokenizer(config)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
    )
    optimizer = model.configure_optimizers()
    for epoch in range(config.epochs):
        loss = train(model, device, train_loader, optimizer)
        print(f'Epoch {epoch+1}, Loss: {loss:.4f}')


if __name__ == '__main__':
    main()