biplab2008 committed
Commit c1941df · verified · 1 Parent(s): 19eea80

Create cnn3d_model.py

Files changed (1): cnn3d_model.py +311 -0
cnn3d_model.py ADDED
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch import nn
from typing import NamedTuple, List, Callable, Tuple, Optional

class LinData(NamedTuple):
    in_dim: int                      # input dimension
    hidden_layers: List[int]         # hidden layers, including the output layer
    activations: List[Optional[Callable[[torch.Tensor], torch.Tensor]]]  # per-layer activations (or None)
    bns: List[bool]                  # per-layer batch-norm flags
    dropouts: List[Optional[float]]  # per-layer dropout probabilities

class CNNData(NamedTuple):
    in_dim: int                      # input channels
    n_f: List[int]                   # number of filters per layer
    kernel_size: List[Tuple]         # kernel sizes, e.g. [(5, 5, 5), (3, 3, 3), (3, 3, 3)]
    activations: List[Optional[Callable[[torch.Tensor], torch.Tensor]]]  # per-layer activations (or None)
    bns: List[bool]                  # batch-norm flags, e.g. [True, True, False]
    dropouts: List[Optional[float]]  # dropout probabilities, e.g. [0.5, 0, 0]
    # dropouts_ps : list             # [0.5, 0.7, 0]
    paddings: List[Optional[Tuple]]  # e.g. [(0, 0, 0), (0, 0, 0), (0, 0, 0)]
    strides: List[Optional[Tuple]]   # e.g. [(1, 1, 1), (1, 1, 1), (1, 1, 1)]

class NetData(NamedTuple):
    cnn3d: CNNData
    lin: LinData

def conv3D_output_size(args, img_size):
    """Flattened feature size after all Conv3d layers described by `args`.

    `img_size` is the (frames, height, width) of the input clip; each dimension
    follows floor((n + 2*p - k) / s + 1) per layer.
    """
    if not isinstance(args, CNNData):
        raise TypeError("args must be a CNNData instance")

    (t, h, w) = img_size
    # propagate the shape through every conv layer
    for idx, kernel in enumerate(args.kernel_size):
        padding = args.paddings[idx]
        stride = args.strides[idx]
        (t, h, w) = (np.floor((t + 2 * padding[0] - kernel[0]) / stride[0] + 1).astype(int),
                     np.floor((h + 2 * padding[1] - kernel[1]) / stride[1] + 1).astype(int),
                     np.floor((w + 2 * padding[2] - kernel[2]) / stride[2] + 1).astype(int))

    # flattened size = channels of the last conv layer * remaining (t, h, w)
    final_dim = int(args.n_f[-1] * t * h * w)

    return final_dim

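# Worked example (illustrative; mirrors the configuration built in load_model()
# below): two conv layers with kernels (5, 5, 5) and (3, 3, 3), stride (2, 2, 2),
# no padding, applied to a (30, 256, 342) clip:
#   layer 1: (30, 256, 342) -> (13, 126, 169)
#   layer 2: (13, 126, 169) -> (6, 62, 84)
#   flattened size = 48 * 6 * 62 * 84 = 1_499_904
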
class CNN3D_Mike(nn.Module):
    def __init__(self, t_dim=30, img_x=256, img_y=342, drop_p=0, fc_hidden1=256, fc_hidden2=256):
        super(CNN3D_Mike, self).__init__()
        # video dimensions
        self.t_dim = t_dim
        self.img_x = img_x
        self.img_y = img_y
        # fully connected layer hidden nodes
        self.fc_hidden1, self.fc_hidden2 = fc_hidden1, fc_hidden2
        self.drop_p = drop_p
        # self.num_classes = num_classes
        self.ch1, self.ch2 = 32, 48
        self.k1, self.k2 = (5, 5, 5), (3, 3, 3)    # 3d kernel sizes
        self.s1, self.s2 = (2, 2, 2), (2, 2, 2)    # 3d strides
        self.pd1, self.pd2 = (0, 0, 0), (0, 0, 0)  # 3d paddings
        # compute conv1 & conv2 output shapes
        self.conv1_outshape = self._conv3d_output_shape((self.t_dim, self.img_x, self.img_y), self.pd1, self.k1, self.s1)
        self.conv2_outshape = self._conv3d_output_shape(self.conv1_outshape, self.pd2, self.k2, self.s2)
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=self.ch1, kernel_size=self.k1, stride=self.s1,
                               padding=self.pd1)
        self.bn1 = nn.BatchNorm3d(self.ch1)
        self.conv2 = nn.Conv3d(in_channels=self.ch1, out_channels=self.ch2, kernel_size=self.k2, stride=self.s2,
                               padding=self.pd2)
        self.bn2 = nn.BatchNorm3d(self.ch2)
        self.relu = nn.ReLU(inplace=True)
        self.drop = nn.Dropout3d(self.drop_p)
        self.pool = nn.MaxPool3d(2)
        self.fc1 = nn.Linear(self.ch2 * self.conv2_outshape[0] * self.conv2_outshape[1] * self.conv2_outshape[2],
                             self.fc_hidden1)  # fully connected hidden layer
        self.fc2 = nn.Linear(self.fc_hidden1, self.fc_hidden2)
        self.fc3 = nn.Linear(self.fc_hidden2, 1)  # output layer (single value)

    @staticmethod
    def _conv3d_output_shape(img_size, padding, kernel_size, stride):
        # Per-layer (D, H, W) output shape of a Conv3d; this legacy class needs the
        # shape one layer at a time, unlike the CNNData-based conv3D_output_size() above.
        return (np.floor((img_size[0] + 2 * padding[0] - kernel_size[0]) / stride[0] + 1).astype(int),
                np.floor((img_size[1] + 2 * padding[1] - kernel_size[1]) / stride[1] + 1).astype(int),
                np.floor((img_size[2] + 2 * padding[2] - kernel_size[2]) / stride[2] + 1).astype(int))

    def forward(self, x_3d):
        # Conv 1
        x = self.conv1(x_3d)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.drop(x)
        # Conv 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.drop(x)
        # FC 1 and 2
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.dropout(x, p=self.drop_p, training=self.training)
        # x = F.relu(self.fc3(x))
        # x = F.softmax(self.fc2(x))
        x = self.fc3(x)

        return x

class CNNLayers(nn.Module):

    def __init__(self, args):
        super(CNNLayers, self).__init__()

        self.in_dim = args.in_dim            # input channels, e.g. 1 or 3
        self.n_f = args.n_f                  # e.g. [32, 64]
        self.kernel_size = args.kernel_size  # e.g. [(5, 5, 5), (3, 3, 3)]
        self.activations = args.activations  # e.g. [nn.ReLU(), nn.ReLU()]
        self.bns = args.bns                  # e.g. [True, True]
        self.dropouts = args.dropouts        # e.g. [0.5, 0]
        self.paddings = args.paddings        # e.g. [(0, 0, 0), (0, 0, 0)]
        self.strides = args.strides          # e.g. [(1, 1, 1), (1, 1, 1)]
        # self.poolings = args.poolings

        assert len(self.n_f) == len(self.activations) == len(self.bns) == len(self.dropouts), \
            'dimensions mismatch: check dimensions!'

        # generate the sequence of conv blocks
        self._get_layers()

    def _get_layers(self):

        layers = nn.ModuleList()
        in_channels = self.in_dim

        for idx, chans in enumerate(self.n_f):
            sub_layers = nn.ModuleList()

            sub_layers.append(nn.Conv3d(in_channels=in_channels,
                                        out_channels=chans,
                                        kernel_size=self.kernel_size[idx],
                                        stride=self.strides[idx],
                                        padding=self.paddings[idx]))

            if self.bns[idx]:
                sub_layers.append(nn.BatchNorm3d(num_features=chans))

            if self.dropouts[idx]:
                sub_layers.append(nn.Dropout3d(p=self.dropouts[idx]))

            if self.activations[idx]:
                sub_layers.append(self.activations[idx])

            layers.append(nn.Sequential(*sub_layers))

            in_channels = chans

        self.layers = nn.Sequential(*layers)

    @staticmethod
    def get_activation(activation):
        # map an activation name to a module; the main path passes module
        # instances directly via CNNData.activations
        if activation == 'relu':
            activation = nn.ReLU()
        elif activation == 'leakyrelu':
            activation = nn.LeakyReLU(negative_slope=0.1)
        elif activation == 'selu':
            activation = nn.SELU()

        return activation

    def forward(self, x):

        x = self.layers(x)

        return x

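# Stand-alone example for CNNLayers (illustrative; the config values below are
# arbitrary and not taken from this file):
#   cfg = CNNData(in_dim=1, n_f=[8, 16], kernel_size=[(3, 3, 3), (3, 3, 3)],
#                 activations=[nn.ReLU(), nn.ReLU()], bns=[True, True],
#                 dropouts=[0, 0], paddings=[(0, 0, 0), (0, 0, 0)],
#                 strides=[(1, 1, 1), (1, 1, 1)])
#   CNNLayers(cfg)(torch.randn(1, 1, 8, 32, 32)).shape  # -> torch.Size([1, 16, 4, 28, 28])
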
class CNN3D(nn.Module):

    def __init__(self, args):
        super(CNN3D, self).__init__()
        # check datatype
        if not isinstance(args, NetData):
            raise TypeError("args must be a NetData instance")

        self.cnn3d = CNNLayers(args.cnn3d)

        self.lin = LinLayers(args.lin)

        self.in_dim = args.lin.in_dim  # flattened feature size expected by the head

    def forward(self, x):

        # 3D conv backbone
        x = self.cnn3d(x)

        # flatten to (batch, in_dim)
        x = x.view(-1, self.in_dim)

        # feed-forward head
        x = self.lin(x)

        return x

class LinLayers(nn.Module):

    def __init__(self, args):
        super(LinLayers, self).__init__()

        in_dim = args.in_dim                # e.g. 16
        hidden_layers = args.hidden_layers  # e.g. [512, 256, 128, 2]
        activations = args.activations      # e.g. [nn.LeakyReLU(0.2), nn.LeakyReLU(0.2), nn.LeakyReLU(0.2), None]
        batchnorms = args.bns               # e.g. [True, True, True, False]
        dropouts = args.dropouts            # e.g. [None, 0.2, 0.2, None]

        layers = nn.ModuleList()

        if isinstance(hidden_layers, list):
            assert len(hidden_layers) == len(activations) == len(batchnorms) == len(dropouts), 'dimensions mismatch!'

            old_dim = in_dim
            for idx, layer in enumerate(hidden_layers):
                sub_layers = nn.ModuleList()
                sub_layers.append(nn.Linear(old_dim, layer))
                if batchnorms[idx]:
                    sub_layers.append(nn.BatchNorm1d(num_features=layer))
                if activations[idx]:
                    sub_layers.append(activations[idx])
                if dropouts[idx]:
                    sub_layers.append(nn.Dropout(p=dropouts[idx]))
                old_dim = layer

                layers.append(nn.Sequential(*sub_layers))

        else:
            # single-layer case: hidden_layers is taken to be a single int giving the output width
            layers.append(nn.Linear(in_dim, hidden_layers))
            if batchnorms:
                layers.append(nn.BatchNorm1d(num_features=hidden_layers))
            if activations:
                layers.append(activations)
            if dropouts:
                layers.append(nn.Dropout(p=dropouts))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):

        x = self.layers(x)

        return x

    '''
    def _check_dimensions(self):
        if isinstance(self.hidden_layers, list):
            assert len(self.hidden_layers) == len(self.activations)
            assert len(self.hidden_layers) == len(self.batchnorms)
            assert len(self.hidden_layers) == len(self.dropouts)
    '''

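# Stand-alone example for LinLayers (illustrative; arbitrary small dimensions):
#   cfg = LinData(in_dim=16, hidden_layers=[8, 4, 2],
#                 activations=[nn.ReLU(), nn.ReLU(), None],
#                 bns=[False, False, False], dropouts=[0, 0, 0])
#   LinLayers(cfg)(torch.randn(5, 16)).shape  # -> torch.Size([5, 2])
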
def load_model():
    # CNN3D layers' architecture
    cnndata = CNNData(in_dim=1,
                      n_f=[32, 48],
                      kernel_size=[(5, 5, 5), (3, 3, 3)],
                      activations=[nn.ReLU(), nn.ReLU()],
                      bns=[True, True],
                      dropouts=[0, 0],
                      paddings=[(0, 0, 0), (0, 0, 0)],
                      strides=[(2, 2, 2), (2, 2, 2)])

    # feed-forward layers' architecture
    lindata = LinData(in_dim=conv3D_output_size(cnndata, [30, 256, 342]),
                      hidden_layers=[256, 256, 1],
                      activations=[nn.ReLU(), nn.ReLU(), None],
                      bns=[False, False, False],
                      dropouts=[0.2, 0, 0])

    # combined architecture
    args = NetData(cnndata, lindata)

    # weight file
    # weight_file = 'cnn3d_epoch_300.pt'

    # CNN3D model
    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device('cpu')
    cnn3d = CNN3D(args).to(device)
    # cnn3d.load_state_dict(torch.load(os.path.join(base_path, 'weights', weight_file), map_location=device))
    cnn3d.eval()
    # print(cnn3d)

    return cnn3d
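
# ---------------------------------------------------------------------------
# Usage sketch (illustrative): build the model via load_model() and push a
# dummy clip through it. The first Linear layer has ~1.5M input features, so
# this allocates a sizeable amount of memory.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = load_model()
    dummy = torch.randn(1, 1, 30, 256, 342)  # (batch, channels, frames, height, width)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # torch.Size([1, 1])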