"""Evaluate a DaedalusMobile model and print its mean loss."""

import torch

from daedalus_mobile import DaedalusMobile
from tokenizer import DaedalusTokenizer
from config import config


def evaluate(model, device, eval_loader):
    """Return the mean per-batch loss of ``model`` over ``eval_loader``.

    The model is switched to eval mode and evaluated without gradient tracking.

    Args:
        model: Object exposing ``eval()`` and ``eval_step(batch)``. The value
            ``eval_step`` returns must be convertible with ``float()`` (a
            one-element tensor or a Python number); presumably a scalar loss,
            confirm against DaedalusMobile.
        device: ``torch.device`` the batch tensors are moved to.
        eval_loader: Iterable yielding ``(input_ids, attention_mask, labels)``
            tensor triples, e.g. a ``torch.utils.data.DataLoader``.

    Returns:
        The unweighted mean of the per-batch losses, or ``float('nan')`` if
        ``eval_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for input_ids, attention_mask, labels in eval_loader:
            batch = (
                input_ids.to(device),
                attention_mask.to(device),
                labels.to(device),
            )
            loss = model.eval_step(batch)
            # float() handles both one-element tensors and plain numbers.
            total_loss += float(loss)
            # Count batches actually seen instead of relying on len(), so
            # unsized iterables work and an empty loader cannot divide by 0.
            num_batches += 1
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches


def main(eval_dataset=None):
    """Build a DaedalusMobile from ``config``, evaluate it, and print the loss.

    Args:
        eval_dataset: Dataset passed to ``torch.utils.data.DataLoader``. This
            module does not construct one, so the caller must supply it.

    Raises:
        ValueError: If ``eval_dataset`` is not provided.
    """
    if eval_dataset is None:
        # TODO: build the evaluation dataset here (presumably tokenized with
        # DaedalusTokenizer(config)) once a dataset loader is available.
        raise ValueError("main() requires an eval_dataset to evaluate on")

    device = torch.device(config.device)
    model = DaedalusMobile(config)
    # NOTE(review): no checkpoint is loaded here, so this evaluates whatever
    # weights DaedalusMobile(config) initializes — confirm this is intended.
    model.to(device)

    eval_loader = torch.utils.data.DataLoader(
        dataset=eval_dataset,
        batch_size=config.batch_size,
        shuffle=False,
    )
    loss = evaluate(model, device, eval_loader)
    print(f'Loss: {loss:.4f}')


if __name__ == '__main__':
    main()