BathSalt-1 commited on
Commit
0ccabc3
·
verified ·
1 Parent(s): 11bec05

Create evaluate.py

Browse files
Files changed (1) hide show
  1. evaluate.py +27 -0
evaluate.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from daedalus_mobile import DaedalusMobile
3
+ from tokenizer import DaedalusTokenizer
4
+ from config import config
5
+
6
+ def evaluate(model, device, eval_loader):
7
+ model.eval()
8
+ total_loss = 0
9
+ with torch.no_grad():
10
+ for batch in eval_loader:
11
+ input_ids, attention_mask, labels = batch
12
+ input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
13
+ loss = model.eval_step((input_ids, attention_mask, labels))
14
+ total_loss += loss.item()
15
+ return total_loss / len(eval_loader)
16
+
17
+ def main():
18
+ device = torch.device(config.device)
19
+ model = DaedalusMobile(config)
20
+ model.to(device)
21
+ tokenizer = DaedalusTokenizer(config)
22
+ eval_loader = torch.utils.data.DataLoader(dataset=eval_dataset, batch_size=config.batch_size, shuffle=False)
23
+ loss = evaluate(model, device, eval_loader)
24
+ print(f'Loss: {loss:.4f}')
25
+
26
+ if __name__ == '__main__':
27
+ main()