Tej3 committed on
Commit
71bd54f
·
1 Parent(s): 2dc883d

Committing App

Browse files
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import gradio as gr
4
+ import numpy as np
5
+ import wfdb
6
+ import torch
7
+ from wfdb.plot.plot import plot_wfdb
8
+ from wfdb.io.record import Record, rdrecord
9
+
10
+ from models.CNN import CNN, MMCNN_CAT
11
+ from models.RNN import MMRNN
12
+ from utils.helper_functions import predict
13
+
14
+ import matplotlib
15
+ matplotlib.use('Agg')
16
+ import matplotlib.pyplot as plt
17
+
18
+ from transformers import AutoTokenizer, AutoModel
19
+ from langdetect import detect
20
+
21
# All data paths are resolved against the directory the app is launched from;
# edit CWD before running if that is not where demo_data/ lives.
CWD = os.getcwd()
_DEMO_DATA_DIR = f"{CWD}/demo_data"

# Checkpoint files for the two selectable multimodal models.
MMCNN_CAT_ckpt_path = f"{_DEMO_DATA_DIR}/model_MMCNN_CAT_epoch_30_acc_84.pt"
MMRNN_ckpt_path = f"{_DEMO_DATA_DIR}/model_MMRNN_undersampled_augmented_rn_epoch_20_acc_84.pt"

# Hugging Face model ids of the clinical BERT encoders (English / German).
en_clin_bert = 'emilyalsentzer/Bio_ClinicalBERT'
ger_clin_bert = 'smanjil/German-MedBERT'

# Tokenizers and encoders are loaded once, at import time, and reused by
# every call to embed().
en_tokenizer = AutoTokenizer.from_pretrained(en_clin_bert)
en_model = AutoModel.from_pretrained(en_clin_bert)

g_tokenizer = AutoTokenizer.from_pretrained(ger_clin_bert)
g_model = AutoModel.from_pretrained(ger_clin_bert)
36
+
37
def preprocess(data_file_path):
    """Read one WFDB record and return its samples with a leading axis of 1.

    Args:
        data_file_path: Record path as accepted by ``wfdb.rdsamp``
            (the caller passes the path without a file extension).

    Returns:
        A ``numpy.ndarray`` containing the record's signal array wrapped in
        an outer dimension of length one; the record metadata is discarded.
    """
    signal, _metadata = wfdb.rdsamp(data_file_path)
    return np.array([signal])
41
+
42
def embed(notes):
    """Embed a clinical note with the English or German clinical BERT encoder.

    The note's language is detected with ``langdetect``; English notes go
    through ``en_tokenizer``/``en_model`` and everything else through
    ``g_tokenizer``/``g_model`` (the original default branch).

    Args:
        notes: Free-text clinical note.

    Returns:
        A 1-D tensor: the encoder's last hidden state averaged over the token
        dimension, with the batch dimension squeezed away.
    """
    # Local import keeps this block self-contained; the module ships with the
    # already-imported ``langdetect`` package.
    from langdetect.lang_detect_exception import LangDetectException

    try:
        is_english = detect(notes) == 'en'
    except LangDetectException:
        # Raised for empty / feature-less text; fall back to the default
        # (non-English) branch instead of crashing the request.
        is_english = False

    if is_english:
        tokenizer, encoder = en_tokenizer, en_model
    else:
        tokenizer, encoder = g_tokenizer, g_model

    # BERT encoders accept at most 512 positions; without truncation a long
    # note raises an indexing error inside the position embeddings.
    tokens = tokenizer(notes, return_tensors='pt', truncation=True, max_length=512)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = encoder(**tokens)

    return torch.mean(outputs.last_hidden_state, dim=1).squeeze(0)
54
+ # return torch.load(f'{"./data/embeddings/"}1.pt')
55
def plot_ecg(path):
    """Load the WFDB record at *path* and return its plot as a figure.

    Args:
        path: Record path as accepted by ``wfdb.io.record.rdrecord``.

    Returns:
        The figure produced by ``plot_wfdb(..., return_fig=True)``.
    """
    record = rdrecord(path)
    figure = plot_wfdb(
        record=record,
        title='ECG Signal Graph',
        figsize=(12,10),
        return_fig=True,
    )
    return figure
58
+
59
def infer(model, data, notes):
    """Run the selected multimodal model on one ECG record and its note.

    Args:
        model: ``"CNN"`` (loads ``MMCNN_CAT``) or ``"RNN"`` (loads ``MMRNN``).
        data: Array-like ECG samples as returned by ``preprocess``.
        notes: Free-text clinical note, embedded via ``embed``.

    Returns:
        Dict mapping each of the five diagnostic class names to its sigmoid
        probability, rounded to two decimals.

    Raises:
        ValueError: If ``model`` is neither ``"CNN"`` nor ``"RNN"`` (for
            example ``None`` when no radio option was selected).
    """
    if model == "CNN":
        net = MMCNN_CAT()
        ckpt_path = MMCNN_CAT_ckpt_path
    elif model == "RNN":
        net = MMRNN(device='cpu')
        ckpt_path = MMRNN_ckpt_path
    else:
        # Previously this fell through to ``model.eval()`` on a str/None.
        raise ValueError(f"Unknown model {model!r}; expected 'CNN' or 'RNN'.")

    embed_notes = embed(notes).unsqueeze(0)
    data = torch.tensor(data).float()
    if model == "CNN":
        # The CNN branch swaps axes 1 and 2 before the Conv1d stack.
        data = data.transpose(1, 2)

    # map_location keeps checkpoints that were saved from GPU tensors
    # loadable on a CPU-only host.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    with torch.no_grad():
        outputs, _predicted = predict(net, data, embed_notes, device='cpu')
    probabilities = torch.sigmoid(outputs)[0]

    # Order matches the model's output indices 0..4.
    class_names = (
        'Conduction Disturbance',
        'Hypertrophy',
        'Myocardial Infarction',
        'Normal ECG',
        'ST/T Change',
    )
    return {
        name: round(probabilities[idx].item(), 2)
        for idx, name in enumerate(class_names)
    }
76
+
77
def run(model_name, header_file, data_file, notes):
    """Gradio callback: classify an uploaded WFDB record and plot it.

    The header and data files are copied side by side into a private
    temporary directory under their original basenames so the record can be
    read as a pair, then read and plotted.  The directory is removed
    automatically, even when reading or plotting raises.

    Args:
        model_name: Model selector forwarded to ``infer``.
        header_file: Uploaded ``.hea`` file object exposing a ``.name`` path.
        data_file: Uploaded ``.dat`` file object exposing a ``.name`` path.
        notes: Free-text clinical note.

    Returns:
        ``(class_probabilities, ecg_figure)``.

    Raises:
        gr.Error: If either file is missing.
    """
    import tempfile

    if header_file is None or data_file is None:
        raise gr.Error("Please upload both the .hea header file and the .dat data file.")

    hdr_basename = os.path.basename(header_file.name)
    data_basename = os.path.basename(data_file.name)
    # splitext (not split('.')[0]) so record names containing dots survive.
    record_name = os.path.splitext(hdr_basename)[0]

    with tempfile.TemporaryDirectory() as work_dir:
        shutil.copyfile(data_file.name, os.path.join(work_dir, data_basename))
        shutil.copyfile(header_file.name, os.path.join(work_dir, hdr_basename))
        record_path = os.path.join(work_dir, record_name)
        data = preprocess(record_path)
        ECG_graph = plot_ecg(record_path)

    output = infer(model_name, data, notes)
    return output, ECG_graph
89
+
90
# (record id, clinical note) pairs for the example records bundled under
# demo_data/test; each expands to [header path, data path, note].
_EXAMPLE_RECORDS = [
    ("00001_lr", "sinusrhythmus periphere niederspannung"),
    ("00008_lr", "sinusrhythmus linkstyp qrs(t) abnormal inferiorer infarkt alter unbest."),
    ("00045_lr", "sinusrhythmus unvollstÄndiger rechtsschenkelblock sonst normales ekg"),
    ("00257_lr", "premature atrial contraction(s). sinus rhythm. left atrial enlargement. qs complexes in v2. st segments are slightly elevated in v2,3. st segments are depressed in i, avl. t waves are low or flat in i, v5,6 and inverted in avl. consistent with ischaemic h"),
]

# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost; confirm the row/column layout against the running app.
with gr.Blocks() as demo:
    # Model selector.
    with gr.Row():
        model = gr.Radio(['CNN', 'RNN'], label= "Select Model")

    # Inputs on the left, class probabilities on the right.
    with gr.Row():
        with gr.Column(scale=1):
            header_file = gr.File(label = "header_file", file_types=[".hea"])
            data_file = gr.File(label = "data_file", file_types=[".dat"])
            notes = gr.Textbox(label = "Clinical Notes")
        with gr.Column(scale=1):
            output_prob = gr.Label(
                {
                    'Normal ECG': 0,
                    'Myocardial Infarction': 0,
                    'ST/T Change': 0,
                    'Conduction Disturbance': 0,
                    'Hypertrophy': 0,
                },
                show_label=False,
            )

    with gr.Row():
        ecg_graph = gr.Plot(label = "ECG Signal Visualisation")

    with gr.Row():
        predict_btn = gr.Button("Predict Class")
        predict_btn.click(
            fn=run,
            inputs=[model, header_file, data_file, notes],
            outputs=[output_prob, ecg_graph],
        )

    with gr.Row():
        gr.Examples(
            examples=[
                [f"{CWD}/demo_data/test/{record}.hea", f"{CWD}/demo_data/test/{record}.dat", note]
                for record, note in _EXAMPLE_RECORDS
            ],
            inputs=[header_file, data_file, notes],
        )

if __name__ == "__main__":
    demo.launch()
demo_data/.gitkeep ADDED
File without changes
demo_data/model_MMCNN_CAT_epoch_30_acc_84.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3735115ddc15ecab4844a13124616f339364795349aeef0476491accfa8b4eda
3
+ size 25392011
demo_data/model_MMRNN_undersampled_augmented_rn_epoch_20_acc_84.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7cfaa76908e6246b051fd5725152bca28b5111a83ada18fec5848816d8bd6e7a
3
+ size 1340343
demo_data/test/00001_lr.dat ADDED
Binary file (24 kB). View file
 
demo_data/test/00001_lr.hea ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 00001_lr 12 100 1000
2
+ 00001_lr.dat 16 1000.0(0)/mV 16 0 -119 1508 0 I
3
+ 00001_lr.dat 16 1000.0(0)/mV 16 0 -55 723 0 II
4
+ 00001_lr.dat 16 1000.0(0)/mV 16 0 64 64758 0 III
5
+ 00001_lr.dat 16 1000.0(0)/mV 16 0 86 64423 0 AVR
6
+ 00001_lr.dat 16 1000.0(0)/mV 16 0 -91 1211 0 AVL
7
+ 00001_lr.dat 16 1000.0(0)/mV 16 0 4 7 0 AVF
8
+ 00001_lr.dat 16 1000.0(0)/mV 16 0 -69 63827 0 V1
9
+ 00001_lr.dat 16 1000.0(0)/mV 16 0 -31 6999 0 V2
10
+ 00001_lr.dat 16 1000.0(0)/mV 16 0 0 63759 0 V3
11
+ 00001_lr.dat 16 1000.0(0)/mV 16 0 -26 61447 0 V4
12
+ 00001_lr.dat 16 1000.0(0)/mV 16 0 -39 64979 0 V5
13
+ 00001_lr.dat 16 1000.0(0)/mV 16 0 -79 832 0 V6
demo_data/test/00008_lr.dat ADDED
Binary file (24 kB). View file
 
demo_data/test/00008_lr.hea ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 00008_lr 12 100 1000
2
+ 00008_lr.dat 16 1000.0(0)/mV 16 0 -41 2321 0 I
3
+ 00008_lr.dat 16 1000.0(0)/mV 16 0 -80 4548 0 II
4
+ 00008_lr.dat 16 1000.0(0)/mV 16 0 -39 2234 0 III
5
+ 00008_lr.dat 16 1000.0(0)/mV 16 0 60 62047 0 AVR
6
+ 00008_lr.dat 16 1000.0(0)/mV 16 0 -1 0 0 AVL
7
+ 00008_lr.dat 16 1000.0(0)/mV 16 0 -60 3352 0 AVF
8
+ 00008_lr.dat 16 1000.0(0)/mV 16 0 45 232 0 V1
9
+ 00008_lr.dat 16 1000.0(0)/mV 16 0 -5 65262 0 V2
10
+ 00008_lr.dat 16 1000.0(0)/mV 16 0 5 63785 0 V3
11
+ 00008_lr.dat 16 1000.0(0)/mV 16 0 -55 58960 0 V4
12
+ 00008_lr.dat 16 1000.0(0)/mV 16 0 -70 3471 0 V5
13
+ 00008_lr.dat 16 1000.0(0)/mV 16 0 -40 2065 0 V6
demo_data/test/00045_lr.dat ADDED
Binary file (24 kB). View file
 
demo_data/test/00045_lr.hea ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 00045_lr 12 100 1000
2
+ 00045_lr.dat 16 1000.0(0)/mV 16 0 -181 1318 0 I
3
+ 00045_lr.dat 16 1000.0(0)/mV 16 0 -438 5652 0 II
4
+ 00045_lr.dat 16 1000.0(0)/mV 16 0 -257 4356 0 III
5
+ 00045_lr.dat 16 1000.0(0)/mV 16 0 310 62008 0 AVR
6
+ 00045_lr.dat 16 1000.0(0)/mV 16 0 38 64012 0 AVL
7
+ 00045_lr.dat 16 1000.0(0)/mV 16 0 -347 4979 0 AVF
8
+ 00045_lr.dat 16 1000.0(0)/mV 16 0 121 3953 0 V1
9
+ 00045_lr.dat 16 1000.0(0)/mV 16 0 51 64138 0 V2
10
+ 00045_lr.dat 16 1000.0(0)/mV 16 0 82 61158 0 V3
11
+ 00045_lr.dat 16 1000.0(0)/mV 16 0 -58 63682 0 V4
12
+ 00045_lr.dat 16 1000.0(0)/mV 16 0 -52 65025 0 V5
13
+ 00045_lr.dat 16 1000.0(0)/mV 16 0 -134 193 0 V6
demo_data/test/00257_lr.dat ADDED
Binary file (24 kB). View file
 
demo_data/test/00257_lr.hea ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 00257_lr 12 100 1000
2
+ 00257_lr.dat 16 1000.0(0)/mV 16 0 -8 8043 0 I
3
+ 00257_lr.dat 16 1000.0(0)/mV 16 0 24 3049 0 II
4
+ 00257_lr.dat 16 1000.0(0)/mV 16 0 32 60557 0 III
5
+ 00257_lr.dat 16 1000.0(0)/mV 16 0 -9 59959 0 AVR
6
+ 00257_lr.dat 16 1000.0(0)/mV 16 0 -20 6506 0 AVL
7
+ 00257_lr.dat 16 1000.0(0)/mV 16 0 28 64558 0 AVF
8
+ 00257_lr.dat 16 1000.0(0)/mV 16 0 -29 60014 0 V1
9
+ 00257_lr.dat 16 1000.0(0)/mV 16 0 -24 64087 0 V2
10
+ 00257_lr.dat 16 1000.0(0)/mV 16 0 138 1192 0 V3
11
+ 00257_lr.dat 16 1000.0(0)/mV 16 0 34 65087 0 V4
12
+ 00257_lr.dat 16 1000.0(0)/mV 16 0 26 65386 0 V5
13
+ 00257_lr.dat 16 1000.0(0)/mV 16 0 32 59612 0 V6
models/CNN.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchinfo import summary
5
+
6
class Conv1d_layer(nn.Module):
    """1-D convolution followed by batch normalisation and channel dropout.

    Used as the convolutional building block of ``MMCNN_SUM``.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Convolution kernel width.
    """

    def __init__(self, in_channel, out_channel, kernel_size) -> None:
        super().__init__()
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size)
        self.batch_norm = nn.BatchNorm1d(out_channel)
        self.dropout = nn.Dropout1d(p=0.5)

    def forward(self, x):
        """Apply conv -> batch norm -> dropout and return the result."""
        return self.dropout(self.batch_norm(self.conv(x)))
19
+
20
class CNN(nn.Module):
    """ECG-only classifier: three conv/pool stages and three linear layers.

    ``fc0`` expects 5856 flattened features (48 channels x 122 steps), which
    is what the conv stack yields for e.g. 1000-sample inputs.

    Args:
        ecg_channels: Number of input ECG channels.
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        self.name = "CNN"
        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)
        self.activation = nn.ReLU()

    def forward(self, x, notes=None):
        """Return the five class logits; ``notes`` is accepted but ignored."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(self.activation(conv(x)))

        x = x.view(x.size(0), -1)
        for hidden in (self.fc0, self.fc1):
            x = self.activation(hidden(x))

        logits = self.fc2(x)
        return logits.squeeze(1)
44
+
45
+
46
class MMCNN_SUM(nn.Module):
    """Multimodal CNN that fuses ECG and note features by addition.

    The ECG branch uses three ``Conv1d_layer`` blocks; the note embedding is
    projected from 768 to 128 features, added to the ECG features,
    layer-normalised and classified into five logits.

    Args:
        ecg_channels: Number of input ECG channels.
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        # ECG branch
        self.name = "MMCNN_SUM"
        self.conv1 = Conv1d_layer(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = Conv1d_layer(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = Conv1d_layer(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)

        # Clinical-note branch
        self.fc_emb = nn.Linear(768, 128)
        self.norm = nn.LayerNorm(128)

        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Return class logits for ECG batch ``x`` and note embeddings ``notes``."""
        # ECG branch
        for conv, pool in (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        ):
            x = pool(self.activation(conv(x)))
        x = x.view(x.size(0), -1)
        for hidden in (self.fc0, self.fc1):
            x = self.activation(hidden(x))

        # Note branch
        note_features = self.activation(self.fc_emb(notes.view(notes.size(0), -1)))

        # Additive fusion
        logits = self.fc2(self.norm(x + note_features))
        return logits.squeeze(1)
83
+
84
class MMCNN_CAT(nn.Module):
    """Multimodal CNN that fuses ECG and note features by concatenation.

    ECG features (128) and projected note features (128) are concatenated
    into 256 features and mapped to five logits by ``fc2``.

    Args:
        ecg_channels: Number of input ECG channels.
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        # ECG branch
        self.name = "MMCNN_CAT"
        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(256, 5)

        # Clinical-note branch
        self.fc_emb = nn.Linear(768, 128)
        # Not used in forward(), but it is part of the state dict, so it must
        # stay for saved checkpoints to load.
        self.norm = nn.LayerNorm(128)

        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Return class logits for ECG batch ``x`` and note embeddings ``notes``."""
        # ECG branch
        for conv, pool in (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        ):
            x = pool(self.activation(conv(x)))
        x = x.view(x.size(0), -1)
        for hidden in (self.fc0, self.fc1):
            x = self.activation(hidden(x))

        # Note branch
        note_features = self.activation(self.fc_emb(notes.view(notes.size(0), -1)))

        # Concatenation fusion
        fused = torch.cat((x, note_features), dim=1)
        return self.fc2(fused).squeeze(1)
121
class MMCNN_ATT(nn.Module):
    """Multimodal CNN that fuses modalities with multi-head attention.

    The normalised note features act as the attention query and the
    normalised ECG features as key and value; the attended vector is mapped
    to five logits.

    Args:
        ecg_channels: Number of input ECG channels.
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        # ECG branch
        self.name = "MMCNN_ATT"
        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)

        # Clinical-note branch
        self.fc_emb = nn.Linear(768, 128)
        self.norm1 = nn.LayerNorm(128)
        self.norm2 = nn.LayerNorm(128)

        self.attention = nn.MultiheadAttention(128, 8, batch_first=True)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Return class logits for ECG batch ``x`` and note embeddings ``notes``."""
        # ECG branch
        for conv, pool in (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        ):
            x = pool(self.activation(conv(x)))
        x = x.view(x.size(0), -1)
        for hidden in (self.fc0, self.fc1):
            x = self.activation(hidden(x))
        ecg_tokens = self.norm1(x).unsqueeze(1)

        # Note branch
        note_features = self.activation(self.fc_emb(notes.view(notes.size(0), -1)))
        query = self.norm2(note_features).unsqueeze(1)

        # Attention: query = notes, key = value = ECG, each a length-1 sequence.
        attended, _weights = self.attention(query, ecg_tokens, ecg_tokens)
        return self.fc2(attended).squeeze(1)
164
+
165
class MMCNN_SUM_ATT(nn.Module):
    """Multimodal CNN: additive fusion followed by self-attention.

    ECG and projected note features are summed, layer-normalised, passed
    through multi-head self-attention as a length-1 sequence and mapped to
    five logits.

    Args:
        ecg_channels: Number of input ECG channels.
    """

    def __init__(self, ecg_channels=12):
        super().__init__()
        # ECG branch
        self.name = "MMCNN_SUM_ATT"
        self.conv1 = nn.Conv1d(ecg_channels, 16, 7)
        self.pool1 = nn.MaxPool1d(2, 2)
        self.conv2 = nn.Conv1d(16, 32, 5)
        self.pool2 = nn.MaxPool1d(2, 2)
        self.conv3 = nn.Conv1d(32, 48, 3)
        self.pool3 = nn.MaxPool1d(2, 2)
        self.fc0 = nn.Linear(5856, 512)
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 5)

        # Clinical-note branch
        self.fc_emb = nn.Linear(768, 128)
        self.norm = nn.LayerNorm(128)

        self.attention = nn.MultiheadAttention(128, 8, batch_first=True)
        self.activation = nn.ReLU()

    def forward(self, x, notes):
        """Return class logits for ECG batch ``x`` and note embeddings ``notes``."""
        # ECG branch
        for conv, pool in (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        ):
            x = pool(self.activation(conv(x)))
        x = x.view(x.size(0), -1)
        for hidden in (self.fc0, self.fc1):
            x = self.activation(hidden(x))

        # Note branch and additive fusion
        note_features = self.activation(self.fc_emb(notes.view(notes.size(0), -1)))
        fused = self.norm(x + note_features).unsqueeze(1)

        # Self-attention over the fused vector
        attended, _weights = self.attention(fused, fused, fused)
        return self.fc2(attended).squeeze(1)
208
+
209
if __name__ == "__main__":
    # Print a layer summary of the ECG-only CNN for a single 12-channel,
    # 1000-sample input. Swap in Conv1d_layer(12, 16, 7) to inspect the
    # building block instead.
    input_shape = (1, 12, 1000)
    model = CNN()
    summary(model, input_size=input_shape)
213
+
models/RNN.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class RNN(nn.Module):
    """ECG-only LSTM classifier.

    Args:
        input_dim: Features per time step.
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        cuda: Unused; kept so existing callers keep working.
        device: Device the input and initial states are moved to in
            ``forward``.
    """

    def __init__(self, input_dim=12, hidden_dim=64, num_layers=2, num_classes=5, cuda=True, device='cuda'):
        super(RNN, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device

        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=self.hidden_dim,
                            num_layers=self.num_layers, batch_first=True)
        self.fc1 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, num_classes)
        self.relu = nn.ReLU()

    def _initial_state(self, batch_size):
        """Return the (h0, c0) pair for a batch, already on ``self.device``."""
        shape = (self.num_layers, batch_size, self.hidden_dim)
        h = torch.zeros(shape)
        c = torch.zeros(shape)
        if self.training:
            # Training keeps the original Xavier-normal random start.
            nn.init.xavier_normal_(h)
            nn.init.xavier_normal_(c)
        # In eval mode the states stay zero so that the same input always
        # yields the same prediction.
        return h.to(self.device), c.to(self.device)

    def forward(self, x, notes):
        """Return class logits for ``x``; ``notes`` is accepted but ignored.

        ``x`` is batch-first, as configured on the LSTM.
        """
        h, c = self._initial_state(x.size(0))
        x = x.to(self.device)

        output, _ = self.lstm(x, (h, c))

        # Classify from the LSTM output at the last time step.
        out = self.fc2(self.relu(self.fc1(output[:, -1, :])))

        return out
33
+
34
+
35
class MMRNN(nn.ModuleList):
    """Multimodal LSTM classifier fusing ECG and note embeddings by addition.

    The last LSTM output is projected to ``embed_size``, layer-normalised,
    added to the layer-normalised note embedding and mapped to
    ``num_classes`` logits.

    NOTE(review): the base class is ``nn.ModuleList`` rather than
    ``nn.Module``; it is left unchanged so existing behaviour is preserved.

    Args:
        input_dim: Features per time step.
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        embed_size: Size of the note embedding.
        device: Device inputs and initial states are moved to in ``forward``.
    """

    def __init__(self, input_dim=12, hidden_dim=64, num_layers=2, num_classes=5, embed_size=768, device="cuda"):
        super(MMRNN, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.device = device

        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=self.hidden_dim,
                            num_layers=self.num_layers, batch_first=True)
        self.fc1 = nn.Linear(self.hidden_dim, embed_size)
        self.fc2 = nn.Linear(embed_size, num_classes)

        self.lnorm_out = nn.LayerNorm(embed_size)
        self.lnorm_embed = nn.LayerNorm(embed_size)

    def _initial_state(self, batch_size):
        """Return the (h0, c0) pair for a batch, already on ``self.device``."""
        shape = (self.num_layers, batch_size, self.hidden_dim)
        h = torch.zeros(shape)
        c = torch.zeros(shape)
        if self.training:
            # Training keeps the original Xavier-normal random start.
            nn.init.xavier_normal_(h)
            nn.init.xavier_normal_(c)
        # In eval mode the states stay zero so that the same input always
        # yields the same prediction.
        return h.to(self.device), c.to(self.device)

    def forward(self, x, note):
        """Return class logits for ECG batch ``x`` and note embeddings ``note``."""
        h, c = self._initial_state(x.size(0))
        x = x.to(self.device)
        note = note.to(self.device)

        output, _ = self.lstm(x, (h, c))
        # Project the LSTM output at the last time step to the embedding size.
        out = self.fc1(output[:, -1, :])

        note = self.lnorm_embed(note)
        out = self.lnorm_out(out)
        out = note + out

        out = self.fc2(out)

        return out.squeeze(1)
models/__pycache__/CNN.cpython-39.pyc ADDED
Binary file (6.49 kB). View file
 
models/__pycache__/RNN.cpython-39.pyc ADDED
Binary file (2.34 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==3.25.0
2
+ langdetect==1.0.9
3
+ matplotlib==3.6.3
4
+ numpy==1.24.2
5
+ pandas==1.5.3
6
+ PyWavelets==1.4.1
7
+ scikit_learn==1.2.1
8
+ torch==1.12.1
9
+ torchinfo==1.7.2
10
+ torchvision==0.13.1
11
+ tqdm==4.64.1
12
+ transformers==4.28.1
13
+ wfdb==4.1.0
utils/RNN_utils.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import matplotlib.pyplot as plt
3
+ from tqdm.autonotebook import tqdm
4
+ import pywt
5
+ import os
6
+
7
+
8
def display_eval(epoch, epochs, tlength, global_step, tcorrect, tsamples, t_valid_samples, average_train_loss, average_valid_loss, total_acc_val):
    """Write a one-line train/validation progress report via ``tqdm.write``.

    Args:
        epoch: Zero-based index of the current epoch.
        epochs: Total number of epochs.
        tlength: Steps per epoch, used for the ``Step [x/total]`` counter.
        global_step: Steps taken so far.
        tcorrect: Correct training predictions so far.
        tsamples: Training samples seen so far.
        t_valid_samples: Number of validation samples.
        average_train_loss: Mean training loss to report.
        average_valid_loss: Mean validation loss to report.
        total_acc_val: Correct validation predictions.
    """
    # Adjacent f-strings instead of backslash continuations inside one
    # literal, which embedded the source indentation in the message.
    message = (
        f'Epoch: [{epoch + 1}/{epochs}], Step [{global_step}/{epochs*tlength}]'
        f' | Train Loss: {average_train_loss: .3f}'
        f' | Train Accuracy: {tcorrect / tsamples: .3f}'
        f' | Val Loss: {average_valid_loss: .3f}'
        f' | Val Accuracy: {total_acc_val / t_valid_samples: .3f}'
    )
    tqdm.write(message)
14
+
15
+
16
def save_model(model, optimizer, valid_loss, epoch, path='model.pt'):
    """Save a training checkpoint to *path* and report where it went.

    The checkpoint holds the validation loss, the model and optimizer state
    dicts, and the one-based epoch number (``epoch + 1``).
    """
    checkpoint = {
        'valid_loss': valid_loss,
        'model_state_dict': model.state_dict(),
        'epoch': epoch + 1,
        'optimizer': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)
    tqdm.write(f'Model saved to ==> {path}')
23
+
24
+
25
def save_metrics(train_loss_list, valid_loss_list, global_steps_list, path='metrics.pt'):
    """Save the loss histories and their step indices to *path*.

    The file is a dict with keys ``train_loss_list``, ``valid_loss_list``
    and ``global_steps_list``, written with ``torch.save``.
    """
    history = {
        'train_loss_list': train_loss_list,
        'valid_loss_list': valid_loss_list,
        'global_steps_list': global_steps_list,
    }
    torch.save(history, path)
30
+
31
+
32
def plot_losses(metrics_save_name='metrics', save_dir='./'):
    """Plot train and validation loss curves from a saved metrics file.

    Loads ``{save_dir}metrics_{metrics_save_name}.pt`` (``save_dir`` is
    concatenated as-is, so it needs its trailing separator) and shows both
    curves against the global step.
    """
    state = torch.load(f'{save_dir}metrics_{metrics_save_name}.pt')
    steps = state['global_steps_list']

    for key, label in (('train_loss_list', 'Train'), ('valid_loss_list', 'Valid')):
        plt.plot(steps, state[key], label=label)

    plt.xlabel('Global Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
46
+
47
+
48
def _wavelet_denoise(signals, wavelet='db4', level=3):
    """Hard-threshold wavelet denoising along axis 1; returns a NumPy array.

    The threshold is 0.1 times the median absolute value of the last
    coefficient array returned by ``pywt.wavedec`` and is applied to every
    coefficient array.
    """
    coeffs = pywt.wavedec(signals, wavelet, level=level, axis=1)
    threshold = 0.1 * torch.median(torch.abs(torch.from_numpy(coeffs[-1])))
    denoised_coeffs = [pywt.threshold(data=c, mode='hard', value=threshold)
                       for c in coeffs]
    return pywt.waverec(denoised_coeffs, wavelet, axis=1)


def train_RNN(epochs, train_loader, valid_loader, model, loss_fn, optimizer, eval_every=0.25, best_valid_loss=float("Inf"), device='cuda', model_save_name='', save_dir='./'):
    """Train *model* with wavelet-denoised inputs and periodic validation.

    NOTE(review): the nesting of this function was reconstructed from a
    source whose indentation was lost; confirm against the repository.

    Args:
        epochs: Number of passes over ``train_loader``.
        train_loader: Yields ``(signals, labels, notes)`` training batches.
        valid_loader: Yields validation batches in the same format.
        model: Called as ``model(signals, notes)``.
        loss_fn: Called as ``loss_fn(output, labels.float())``.
        optimizer: Optimizer stepping the model parameters.
        eval_every: Fraction of the training set between validation passes.
        best_valid_loss: Loss a validation pass must beat to save a checkpoint.
        device: Device tensors are moved to.
        model_save_name: Suffix used in checkpoint / metrics file names.
        save_dir: Output directory prefix, concatenated as-is (needs its
            trailing separator).

    Returns:
        The trained model.
    """
    model.train()

    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []

    model_path = f'{save_dir}model_{model_save_name}.pt'
    metrics_path = f'{save_dir}metrics_{model_save_name}.pt'

    # int() truncates to 0 for tiny datasets, which used to raise
    # ZeroDivisionError in the modulo below; clamp to at least one sample.
    eval_interval = max(1, int(eval_every * len(train_loader.dataset)))

    for epoch in tqdm(range(epochs)):
        running_loss = 0.0
        t_correct = 0
        t_samples = 0
        for images, labels, notes in train_loader:
            optimizer.zero_grad()

            images = torch.tensor(_wavelet_denoise(images)).float().to(device)
            labels = labels.to(device)
            notes = notes.to(device)

            output = model(images, notes)

            loss = loss_fn(output, labels.float())
            running_loss += loss.item()*len(labels)
            loss.backward()
            # global_step counts samples, not batches.
            global_step += 1*len(images)

            optimizer.step()

            # A sample counts as correct when its top-scoring class is one
            # of its positive labels.
            _values, indices = torch.max(output, dim=1)
            t_correct += sum(1 for s, i in enumerate(indices)
                             if labels[s][i] == 1)
            t_samples += len(indices)

            # Validate roughly every `eval_interval` samples.
            if (global_step % eval_interval) < train_loader.batch_size:
                model.eval()
                valid_running_loss = 0.0
                total_acc_val = 0
                with torch.no_grad():
                    for v_images, v_labels, v_notes in valid_loader:
                        v_images = torch.tensor(
                            _wavelet_denoise(v_images)).float().to(device)
                        v_labels = v_labels.to(device)
                        v_notes = v_notes.to(device)
                        v_output = model(v_images, v_notes)

                        v_loss = loss_fn(v_output, v_labels.float()).item()
                        valid_running_loss += v_loss*len(v_images)
                        _values, v_indices = torch.max(v_output, dim=1)
                        total_acc_val += sum(
                            1 for s, i in enumerate(v_indices)
                            if v_labels[s][i] == 1)

                # evaluation
                average_train_loss = running_loss / t_samples
                average_valid_loss = valid_running_loss / \
                    len(valid_loader.dataset)
                train_loss_list.append(average_train_loss)
                valid_loss_list.append(average_valid_loss)
                global_steps_list.append(global_step)

                display_eval(epoch, epochs, len(train_loader.dataset), global_step, t_correct, t_samples, len(
                    valid_loader.dataset), average_train_loss, average_valid_loss, total_acc_val)

                # back to training mode for the next batch
                model.train()

                if best_valid_loss > average_valid_loss:
                    best_valid_loss = average_valid_loss
                    save_model(model, optimizer, best_valid_loss, epoch,
                               path=model_path)
                    save_metrics(train_loss_list, valid_loss_list,
                                 global_steps_list, path=metrics_path)

    save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                 path=metrics_path)
    print("Training complete.")
    return model
146
+
147
+
148
def evaluate_RNN(model, test_loader, device="cuda"):
    """Compute and print top-1 accuracy of *model* on ``test_loader``.

    Each batch is wavelet-denoised ('db4', level 3, hard threshold) before
    the forward pass. A sample counts as correct when its highest-scoring
    class is one of its positive labels.

    Args:
        model: Called as ``model(signals, notes)``.
        test_loader: Yields ``(signals, labels, notes)`` batches.
        device: Device tensors are moved to.

    Returns:
        Accuracy as correct predictions divided by ``len(test_loader.dataset)``.
    """
    wavelet, level = 'db4', 3

    model.eval()
    # Collected for parity with the original implementation; not returned.
    y_pred, y_true = [], []
    hits = 0

    with torch.no_grad():
        for signals, targets, note_emb in test_loader:
            coeffs = pywt.wavedec(signals, wavelet, level=level, axis=1)
            cutoff = 0.1 * torch.median(torch.abs(torch.from_numpy(coeffs[-1])))
            kept = [pywt.threshold(data=band, mode='hard', value=cutoff)
                    for band in coeffs]
            signals = pywt.waverec(kept, wavelet, axis=1)

            signals = torch.tensor(signals).float().to(device)
            targets = targets.to(device)
            note_emb = note_emb.to(device)
            scores = model(signals, note_emb)

            _best, top_class = torch.max(scores, dim=1)
            y_pred.extend(top_class.tolist())
            y_true.extend(targets.tolist())
            for row, col in enumerate(top_class):
                if targets[row][col] == 1:
                    hits += 1

    test_accuracy = hits / len(test_loader.dataset)
    print(f'Test Accuracy: {test_accuracy: .3f}')

    return test_accuracy
181
+
182
+
183
def rename_with_acc(save_name, save_dir, acc):
    """Append the rounded accuracy to the saved model and metrics file names.

    ``{save_dir}model_{save_name}.pt`` and ``{save_dir}metrics_{save_name}.pt``
    are renamed to ``..._acc_{NN}.pt`` where ``NN = round(acc * 100)``.
    Existing files with the target names are overwritten.

    Args:
        save_name: Name suffix the files were saved under.
        save_dir: Directory prefix, concatenated as-is (needs its trailing
            separator).
        acc: Accuracy as a fraction.

    Raises:
        FileNotFoundError: If a source file does not exist.
    """
    acc = round(acc*100)

    for kind in ('model', 'metrics'):
        source = f'{save_dir}{kind}_{save_name}.pt'
        target = f'{save_dir}{kind}_{save_name}_acc_{acc}.pt'
        # os.replace overwrites an existing target on every platform, so no
        # separate isfile/remove step is needed.
        os.replace(source, target)
+ f'{save_dir}metrics_{save_name}_acc_{acc}.pt')
utils/__pycache__/RNN_utils.cpython-39.pyc ADDED
Binary file (5.71 kB). View file
 
utils/__pycache__/helper_functions.cpython-39.pyc ADDED
Binary file (2.91 kB). View file
 
utils/__pycache__/trainer.cpython-39.pyc ADDED
Binary file (3.28 kB). View file
 
utils/helper_functions.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
def define_optimizer(model, lr, alpha):
    """Build an RMSprop optimizer over *model*'s parameters.

    Args:
        model: Module whose ``parameters()`` are optimised.
        lr: Learning rate.
        alpha: RMSprop smoothing constant.

    Returns:
        The optimizer, with ``zero_grad()`` already called on it.
    """
    rmsprop = torch.optim.RMSprop(params=model.parameters(), lr=lr, alpha=alpha)
    rmsprop.zero_grad()
    return rmsprop
8
+
9
def tuple_of_tensors_to_tensor(tuple_of_tensors):
    """Stack an iterable of equally-shaped tensors along a new first axis."""
    tensors = [tensor for tensor in tuple_of_tensors]
    return torch.stack(tensors, dim=0)
11
+
12
def predict(model, inputs, notes, device):
    """Run *model* and threshold its sigmoid outputs at 0.5.

    Args:
        model: Callable module invoked as ``model(inputs, notes)``.
        inputs: Model input tensor.
        notes: Note-embedding tensor.
        device: Unused here; kept for interface compatibility.

    Returns:
        ``(outputs, predicted)``: the raw model outputs (logits) and a float
        tensor of 0/1 decisions (1 where ``sigmoid(outputs) > 0.5``).
    """
    # Call the module itself rather than ``.forward`` so registered
    # forward hooks run.
    outputs = model(inputs, notes)
    predicted = (torch.sigmoid(outputs) > 0.5).float()
    return outputs, predicted
17
+
18
def display_train(epoch, num_epochs, i, model, correct, total, loss, train_loader, valid_loader, device):
    """Print training loss/accuracy for the current step, then validate.

    Args:
        epoch: Zero-based epoch index.
        num_epochs: Total number of epochs.
        i: Zero-based step index within the epoch.
        model: Model passed on to ``eval_valid``.
        correct: Correct training predictions so far.
        total: Training samples seen so far.
        loss: Loss tensor of the current step.
        train_loader: Training loader (its length is the step total).
        valid_loader: Validation loader passed on to ``eval_valid``.
        device: Device passed on to ``eval_valid``.

    Returns:
        ``(train_accuracy, valid_accuracy, valid_loss)``.
    """
    epoch_tag = f'Epoch [{epoch+1}/{num_epochs}]'
    print (f'{epoch_tag}, Step [{i+1}/{len(train_loader)}], Train Loss: {loss.item():.4f}')

    train_accuracy = correct/total
    print(f'{epoch_tag}, Train Accuracy: {train_accuracy:.4f}')

    valid_loss, valid_accuracy = eval_valid(model, valid_loader, epoch, num_epochs, device)
    return train_accuracy, valid_accuracy, valid_loss
24
+
25
def eval_valid(model, valid_loader, epoch, num_epochs, device):
    """Evaluate *model* on the validation set and print the result.

    The model is switched to eval mode (and left in it). Inputs have axes 1
    and 2 swapped before the forward pass. Loss is sample-weighted binary
    cross-entropy with logits; a sample counts as correct when its
    highest-scoring class is one of its positive labels.

    Returns:
        ``(validation_loss, valid_accuracy)``.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    loss_sum = 0
    with torch.no_grad():
        for inputs, labels, notes in valid_loader:
            inputs = inputs.transpose(1,2).float().to(device)
            labels = labels.float().to(device)
            notes = notes.to(device)

            outputs, _predicted = predict(model, inputs, notes, device)
            batch_loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, labels)
            loss_sum += batch_loss.item()*len(labels)

            n_seen += labels.size(0)
            #TODO: change acc criteria
            _values, top_class = torch.max(outputs,dim=1)
            # Look up each sample's label at its top-scoring class.
            label_at_top = labels.gather(1, top_class.unsqueeze(1))
            n_correct += int((label_at_top == 1).sum().item())

    valid_accuracy = n_correct/n_seen
    validation_loss = loss_sum/n_seen
    print(f'Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {valid_accuracy:.4f}, Validation Loss: {validation_loss:.4f}')
    return validation_loss, valid_accuracy
57
+
58
+
59
def eval_test(model, test_loader, device):
    """Evaluate *model* on the test set and print its top-1 accuracy.

    The model is switched to eval mode (and left in it). Inputs have axes 1
    and 2 swapped before the forward pass; a sample counts as correct when
    its highest-scoring class is one of its positive labels.

    Returns:
        Test accuracy as a fraction.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels, notes in test_loader:
            inputs = inputs.transpose(1,2).float().to(device)
            labels = labels.float().to(device)
            notes = notes.to(device)

            outputs, _predicted = predict(model, inputs, notes, device)

            n_seen += labels.size(0)
            #TODO: change acc criteria
            _values, top_class = torch.max(outputs,dim=1)
            # Look up each sample's label at its top-scoring class.
            label_at_top = labels.gather(1, top_class.unsqueeze(1))
            n_correct += int((label_at_top == 1).sum().item())

    test_accuracy = n_correct/n_seen
    print(f'Ended Training, Test Accuracy: {test_accuracy:.4f}')
    return test_accuracy
utils/trainer.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .helper_functions import define_optimizer, predict, display_train, eval_test
3
+ from tqdm import tqdm
4
+ import matplotlib.pyplot as plt
5
+
6
+
7
def save_model(model, optimizer, valid_loss, epoch, path='model.pt'):
    """Write a training checkpoint to ``path`` and report where it went.

    The checkpoint is a dict with the validation loss, the model's
    ``state_dict``, ``epoch + 1`` and the optimizer's ``state_dict``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        valid_loss: Validation loss recorded alongside the weights.
        epoch: Epoch index; stored as ``epoch + 1``.
        path: Destination file passed to ``torch.save``.
    """
    checkpoint = dict(
        valid_loss=valid_loss,
        model_state_dict=model.state_dict(),
        epoch=epoch + 1,
        optimizer=optimizer.state_dict(),
    )
    torch.save(checkpoint, path)
    # tqdm.write is used for the status message.
    tqdm.write(f'Model saved to ==> {path}')
14
+
15
+
16
def save_metrics(train_loss_list, valid_loss_list, global_steps_list, path='metrics.pt'):
    """Persist the loss history to ``path`` with ``torch.save``.

    Args:
        train_loss_list: Training losses, stored under ``'train_loss_list'``.
        valid_loss_list: Validation losses, stored under ``'valid_loss_list'``.
        global_steps_list: Step counters, stored under ``'global_steps_list'``.
        path: Destination file passed to ``torch.save``.
    """
    history = dict(
        train_loss_list=train_loss_list,
        valid_loss_list=valid_loss_list,
        global_steps_list=global_steps_list,
    )
    torch.save(history, path)
21
+
22
def plot_losses(metrics_save_name='metrics', save_dir='./'):
    """Plot the train and validation loss curves stored in a metrics file.

    The file is read from ``f'{save_dir}metrics_{metrics_save_name}.pt'``;
    ``save_dir`` is concatenated as-is, so it must carry its own trailing
    separator.

    Args:
        metrics_save_name: Name suffix of the metrics file to load.
        save_dir: Directory prefix of the metrics file.
    """
    history = torch.load(f'{save_dir}metrics_{metrics_save_name}.pt')

    # Pull the three series out in one go, in a fixed key order.
    train_curve, valid_curve, steps = (
        history[key]
        for key in ('train_loss_list', 'valid_loss_list', 'global_steps_list')
    )

    # Both curves share the same x axis (global steps).
    for curve, label in ((train_curve, 'Train'), (valid_curve, 'Valid')):
        plt.plot(steps, curve, label=label)

    plt.xlabel('Global Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
36
+
37
def trainer(model, train_loader, test_loader, valid_loader, num_epochs = 10, lr = 0.01, alpha = 0.99, eval_interval = 10, model_save_name='', save_dir='./'):
    """Train ``model``, checkpoint on the best validation loss, then test it.

    Every ``eval_interval`` batches, and once more at the end of an epoch
    whose batch count is not a multiple of ``eval_interval``, the model is
    evaluated through ``display_train``. The losses are recorded, and the
    checkpoint plus metrics are written whenever the validation loss improves.
    After the last epoch the best checkpoint is reloaded and scored on
    ``test_loader``.

    Args:
        model: Network to train.
        train_loader: Iterable yielding ``(inputs, labels, notes)`` batches.
        test_loader: Loader handed to ``eval_test`` after training.
        valid_loader: Loader handed to ``display_train`` for validation.
        num_epochs: Number of passes over ``train_loader``.
        lr: Learning rate forwarded to ``define_optimizer``.
        alpha: Forwarded unchanged to ``define_optimizer``.
        eval_interval: Number of batches between evaluations.
        model_save_name: Suffix used in the checkpoint and metrics file names.
        save_dir: Prefix for the output files; concatenated as-is, so it must
            carry its own trailing separator.

    Returns:
        ``(train_accs, valid_accs, test_acc)``: the per-epoch train and
        validation accuracies and the final test accuracy.
    """
    # Use GPU if available, else use CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # History for accuracies, losses and the step counter they were taken at
    train_accs = []
    valid_accs = []
    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []
    best_valid_loss = float("inf")

    model_path = f'{save_dir}model_{model_save_name}.pt'
    metrics_path = f'{save_dir}metrics_{model_save_name}.pt'

    # Define optimizer
    optimizer = define_optimizer(model, lr, alpha)

    def evaluate_and_checkpoint(epoch, batch_idx, loss, running_loss, correct, total, step):
        """Evaluate, record the losses and save if validation loss improved."""
        nonlocal best_valid_loss
        train_accuracy, valid_accuracy, valid_loss = display_train(
            epoch, num_epochs, batch_idx, model,
            correct, total, loss,
            train_loader, valid_loader, device)

        train_loss_list.append(running_loss / total)
        valid_loss_list.append(valid_loss)
        global_steps_list.append(step)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            save_model(model, optimizer, best_valid_loss, epoch, path=model_path)
            save_metrics(train_loss_list, valid_loss_list, global_steps_list, path=metrics_path)

        # NOTE(review): display_train is not visible here; if it puts the model
        # in eval mode (as eval_test does), the remaining batches of the epoch
        # would otherwise train with dropout/batch-norm in inference behaviour.
        model.train()
        return train_accuracy, valid_accuracy

    # Training model
    for epoch in range(num_epochs):
        # Go through all samples in the train dataset
        model.train()
        running_loss = 0
        correct = 0
        total = 0
        for batch_idx, (inputs, labels, notes) in enumerate(train_loader):
            # Get from dataloader and send to device
            inputs = inputs.transpose(1, 2).float().to(device)
            labels = labels.float().to(device)
            notes = notes.to(device)

            # Forward pass; only the raw outputs are used below
            outputs, _ = predict(model, inputs, notes, device)

            # A sample is correct when its highest-scoring class is labelled 1
            total += labels.size(0)
            # TODO: change acc criteria
            _, top_class = torch.max(outputs, dim=1)
            # One vectorized gather instead of a per-sample Python loop.
            # assumes outputs and labels are (batch, num_classes) — TODO confirm
            picked = labels.gather(1, top_class.to(labels.device).unsqueeze(1))
            correct += int(picked.eq(1).sum().item())

            # Compute loss on the raw outputs: the loss applies the sigmoid itself
            loss = torch.nn.functional.binary_cross_entropy_with_logits(outputs, labels)
            running_loss += loss.item() * len(labels)
            # The step counter advances by the number of samples seen
            global_step += len(inputs)

            # Backward and optimize
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Display losses over iterations and evaluate on validation set
            if (batch_idx + 1) % eval_interval == 0:
                train_accuracy, valid_accuracy = evaluate_and_checkpoint(
                    epoch, batch_idx, loss, running_loss, correct, total, global_step)

        # Evaluate once more when the epoch did not end on an interval boundary
        if len(train_loader) % eval_interval != 0:
            train_accuracy, valid_accuracy = evaluate_and_checkpoint(
                epoch, batch_idx, loss, running_loss, correct, total, global_step)

        # Append accuracies to list at the end of each epoch
        train_accs.append(train_accuracy)
        valid_accs.append(valid_accuracy)

    save_metrics(train_loss_list, valid_loss_list, global_steps_list,
                 path=metrics_path)

    # Load best_model
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Evaluate on test after training has completed
    test_acc = eval_test(model, test_loader, device)
    return train_accs, valid_accs, test_acc