thak123 committed on
Commit
3e1334e
1 Parent(s): 2938546

Create engine.py

Browse files
Files changed (1) hide show
  1. engine.py +116 -0
engine.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from tqdm import tqdm
4
+ from utils import categorical_accuracy
5
+
6
+
7
def loss_fn(outputs, targets):
    """Return the cross-entropy loss of ``outputs`` against ``targets``.

    A new ``nn.CrossEntropyLoss`` with default settings is created on each
    call and applied to the two arguments.
    """
    criterion = nn.CrossEntropyLoss()
    return criterion(outputs, targets)
9
+
10
+
11
def train_fn(data_loader, model, optimizer, device, scheduler):
    """Run one training pass over ``data_loader`` and return mean metrics.

    Args:
        data_loader: Sized iterable of batches. Each batch is indexed with the
            keys "ids", "token_type_ids", "mask" and "targets", and every
            value must support ``.to(device, dtype=torch.long)``.
        model: Called as ``model(ids=..., mask=..., token_type_ids=...)``;
            put into training mode before the loop.
        optimizer: Gradients are zeroed before each forward pass and the
            optimizer is stepped after each backward pass.
        device: Device the batch tensors are moved to.
        scheduler: Stepped once per batch after the optimizer. ``None`` is
            accepted and disables scheduling.

    Returns:
        Tuple ``(train_loss, train_acc)``: the per-batch loss and the
        per-batch ``categorical_accuracy`` value, each averaged over the
        number of batches (not weighted by batch size). ``(0.0, 0.0)`` is
        returned when the loader yields no batches.
    """
    model.train()

    num_batches = len(data_loader)
    if num_batches == 0:
        # Nothing to train on; returning early avoids a ZeroDivisionError
        # in the averaging step below.
        return 0.0, 0.0

    train_loss, train_acc = 0.0, 0.0

    for batch in tqdm(data_loader, total=num_batches):
        ids = batch["ids"].to(device, dtype=torch.long)
        token_type_ids = batch["token_type_ids"].to(device, dtype=torch.long)
        mask = batch["mask"].to(device, dtype=torch.long)
        targets = batch["targets"].to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(
            ids=ids,
            mask=mask,
            token_type_ids=token_type_ids,
        )

        loss = loss_fn(outputs, targets)
        loss.backward()

        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        train_loss += loss.item()
        train_acc += categorical_accuracy(outputs, targets).item()

    train_loss /= num_batches
    train_acc /= num_batches
    return train_loss, train_acc
46
+
47
+
48
def eval_fn(data_loader, model, device):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        data_loader: Sized iterable of batches. Each batch is indexed with the
            keys "ids", "token_type_ids", "mask" and "targets", and every
            value must support ``.to(device, dtype=torch.long)``.
        model: Called as ``model(ids=..., mask=..., token_type_ids=...)``;
            put into evaluation mode before the loop.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(fin_outputs, fin_targets, eval_loss, eval_acc)``:
        ``fin_outputs`` is the list of argmax (over dim 1) predictions for
        every example, ``fin_targets`` the matching list of targets, and
        ``eval_loss`` / ``eval_acc`` are the per-batch loss and
        ``categorical_accuracy`` values averaged over the number of batches
        (not weighted by batch size). For a loader with no batches the
        result is ``([], [], 0.0, 0.0)``.
    """
    model.eval()

    fin_targets = []
    fin_outputs = []

    num_batches = len(data_loader)
    if num_batches == 0:
        # Nothing to evaluate; returning early avoids a ZeroDivisionError
        # in the averaging step below.
        return fin_outputs, fin_targets, 0.0, 0.0

    eval_loss, eval_acc = 0.0, 0.0

    with torch.no_grad():
        for batch in tqdm(data_loader, total=num_batches):
            ids = batch["ids"].to(device, dtype=torch.long)
            token_type_ids = batch["token_type_ids"].to(
                device, dtype=torch.long)
            mask = batch["mask"].to(device, dtype=torch.long)
            targets = batch["targets"].to(device, dtype=torch.long)

            outputs = model(
                ids=ids,
                mask=mask,
                token_type_ids=token_type_ids,
            )

            eval_loss += loss_fn(outputs, targets).item()
            eval_acc += categorical_accuracy(outputs, targets).item()

            # Predicted class per example, computed once and collected as
            # plain Python ints alongside the targets.
            pred_labels = torch.argmax(outputs, dim=1)
            fin_targets.extend(targets.cpu().tolist())
            fin_outputs.extend(pred_labels.cpu().tolist())

    eval_loss /= num_batches
    eval_acc /= num_batches
    return fin_outputs, fin_targets, eval_loss, eval_acc
81
+
82
+
83
+
84
def predict_fn(data_loader, model, device, extract_features=False):
    """Predict a class index for every example in ``data_loader``.

    Args:
        data_loader: Sized iterable of batches. Each batch is indexed with the
            keys "ids", "token_type_ids" and "mask", and every value must
            support ``.to(device, dtype=torch.long)``. Targets are not read.
        model: Called as ``model(ids=..., mask=..., token_type_ids=...)``;
            put into evaluation mode before the loop.
        device: Device the batch tensors are moved to.
        extract_features: When true, ``model.extract_features`` is also
            called with the same keyword arguments for every batch and its
            result is collected.

    Returns:
        Tuple ``(fin_outputs, extracted_features)``: the list of argmax
        (over dim 1) predictions, and the list of collected feature rows,
        which stays empty when ``extract_features`` is false.
    """
    model.eval()

    predictions = []
    features = []

    with torch.no_grad():
        for _, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
            model_inputs = {
                "ids": batch["ids"].to(device, dtype=torch.long),
                "mask": batch["mask"].to(device, dtype=torch.long),
                "token_type_ids": batch["token_type_ids"].to(
                    device, dtype=torch.long),
            }

            outputs = model(**model_inputs)

            if extract_features:
                batch_features = model.extract_features(**model_inputs)
                features.extend(
                    batch_features.cpu().detach().numpy().tolist())

            batch_predictions = torch.argmax(outputs, dim=1)
            predictions.extend(
                batch_predictions.cpu().detach().numpy().tolist())

    return predictions, features
116
+