haritsahm committed
Commit 861e32a · 1 Parent(s): 0fd3229

Add model files

models/__init__.py ADDED
File without changes
models/hypercomplex_layers.py ADDED
@@ -0,0 +1,523 @@
1
+ # These layers are borrowed from: https://github.com/eleGAN23/HyperNets
2
+ # by Eleonora Grassucci.
3
+ # Please check the original repository for further explanations.
4
+
5
+ import math
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from numpy.random import RandomState
12
+ from torch.nn import Module, init
13
+ from torch.nn.parameter import Parameter
14
+
15
+ from models import hypercomplex_ops as hp_ops
16
+
17
+ ########################
18
+ ## STANDARD PHM LAYER ##
19
+ ########################
20
+
21
+
22
+ class PHMLinear(nn.Module):
23
+ def __init__(self, n, in_features, out_features, cuda=True):
24
+ super().__init__()
25
+ self.n = n
26
+ self.in_features = in_features
27
+ self.out_features = out_features
28
+ self.cuda = cuda
29
+
30
+ self.bias = nn.Parameter(torch.Tensor(out_features))
31
+
32
+ self.A = nn.Parameter(
33
+ torch.nn.init.xavier_uniform_(torch.zeros((n, n, n))))
34
+
35
+ self.S = nn.Parameter(torch.nn.init.xavier_uniform_(
36
+ torch.zeros((n, self.out_features//n, self.in_features//n))))
37
+
38
+ self.weight = torch.zeros((self.out_features, self.in_features))
39
+
40
+ fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
41
+ bound = 1 / math.sqrt(fan_in)
42
+ init.uniform_(self.bias, -bound, bound)
43
+
44
+ # adapted from Bayer Research's implementation
45
+ def kronecker_product1(self, a, b):
46
+ siz1 = torch.Size(torch.tensor(
47
+ a.shape[-2:]) * torch.tensor(b.shape[-2:]))
48
+ res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
49
+ siz0 = res.shape[:-4]
50
+ out = res.reshape(siz0 + siz1)
51
+ return out
52
+
53
+ def kronecker_product2(self):
54
+ H = torch.zeros((self.out_features, self.in_features))
55
+ for i in range(self.n):
56
+ H = H + torch.kron(self.A[i], self.S[i])
57
+ return H
58
+
59
+ def forward(self, input):
60
+ self.weight = torch.sum(self.kronecker_product1(self.A, self.S), dim=0)
61
+ # self.weight = self.kronecker_product2() <- SLOWER
62
+ input = input.type(dtype=self.weight.type())
63
+ return F.linear(input, weight=self.weight, bias=self.bias)
64
+
65
+ def extra_repr(self) -> str:
66
+ return 'in_features={}, out_features={}, bias={}'.format(
67
+ self.in_features, self.out_features, self.bias is not None)
68
+
69
+ def reset_parameters(self) -> None:
70
+ init.kaiming_uniform_(self.A, a=math.sqrt(5))
71
+ init.kaiming_uniform_(self.S, a=math.sqrt(5))
72
+ fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
73
+ bound = 1 / math.sqrt(fan_in)
74
+ init.uniform_(self.bias, -bound, bound)
75
+
76
+ #############################
77
+ ## CONVOLUTIONAL PH LAYER ##
78
+ #############################
79
+
80
+
81
+ class PHConv(Module):
82
+ def __init__(self, n, in_features, out_features, kernel_size, padding=0, stride=1, cuda=True):
83
+ super().__init__()
84
+ self.n = n
85
+ self.in_features = in_features
86
+ self.out_features = out_features
87
+ self.padding = padding
88
+ self.stride = stride
89
+ self.cuda = cuda
90
+
91
+ self.bias = nn.Parameter(torch.Tensor(out_features))
92
+ self.A = nn.Parameter(
93
+ torch.nn.init.xavier_uniform_(torch.zeros((n, n, n))))
94
+ self.F = nn.Parameter(torch.nn.init.xavier_uniform_(
95
+ torch.zeros((n, self.out_features//n, self.in_features//n, kernel_size, kernel_size))))
96
+ self.weight = torch.zeros((self.out_features, self.in_features))
97
+ self.kernel_size = kernel_size
98
+
99
+ fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
100
+ bound = 1 / math.sqrt(fan_in)
101
+ init.uniform_(self.bias, -bound, bound)
102
+
103
+ def kronecker_product1(self, A, F):
104
+ siz1 = torch.Size(torch.tensor(
105
+ A.shape[-2:]) * torch.tensor(F.shape[-4:-2]))
106
+ siz2 = torch.Size(torch.tensor(F.shape[-2:]))
107
+ res = A.unsqueeze(-1).unsqueeze(-3).unsqueeze(-1).unsqueeze(-1) * \
108
+ F.unsqueeze(-4).unsqueeze(-6)
109
+ siz0 = res.shape[:1]
110
+ out = res.reshape(siz0 + siz1 + siz2)
111
+ return out
112
+
113
+ def kronecker_product2(self):
114
+ H = torch.zeros((self.out_features, self.in_features,
115
+ self.kernel_size, self.kernel_size))
116
+ if self.cuda:
117
+ H = H.cuda()
118
+ for i in range(self.n):
119
+ kron_prod = torch.kron(self.A[i], self.F[i]).view(
120
+ self.out_features, self.in_features, self.kernel_size, self.kernel_size)
121
+ H = H + kron_prod
122
+ return H
123
+
124
+ def forward(self, input):
125
+ self.weight = torch.sum(self.kronecker_product1(self.A, self.F), dim=0)
126
+ # self.weight = self.kronecker_product2()
127
+ # if self.cuda:
128
+ # self.weight = self.weight.cuda()
129
+
130
+ input = input.type(dtype=self.weight.type())
131
+
132
+ return F.conv2d(input, weight=self.weight, stride=self.stride, padding=self.padding)
133
+
134
+ def extra_repr(self) -> str:
135
+ return 'in_features={}, out_features={}, bias={}'.format(
136
+ self.in_features, self.out_features, self.bias is not None)
137
+
138
+ def reset_parameters(self) -> None:
139
+ init.kaiming_uniform_(self.A, a=math.sqrt(5))
140
+ init.kaiming_uniform_(self.F, a=math.sqrt(5))
141
+ fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
142
+ bound = 1 / math.sqrt(fan_in)
143
+ init.uniform_(self.bias, -bound, bound)
144
+
145
+
146
+ class KroneckerConv(Module):
147
+ r"""Applies a Quaternion Convolution to the incoming data."""
148
+
149
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
150
+ dilatation=1, padding=0, groups=1, bias=True, init_criterion='glorot',
151
+ weight_init='quaternion', seed=None, operation='convolution2d', rotation=False,
152
+ quaternion_format=True, scale=False, learn_A=False, cuda=True, first_layer=False):
153
+
154
+ super().__init__()
155
+
156
+ self.in_channels = in_channels // 4
157
+ self.out_channels = out_channels // 4
158
+ self.stride = stride
159
+ self.padding = padding
160
+ self.groups = groups
161
+ self.dilatation = dilatation
162
+ self.init_criterion = init_criterion
163
+ self.weight_init = weight_init
164
+ self.seed = seed if seed is not None else np.random.randint(0, 1234)
165
+ self.rng = RandomState(self.seed)
166
+ self.operation = operation
167
+ self.rotation = rotation
168
+ self.quaternion_format = quaternion_format
169
+ self.winit = {'quaternion': hp_ops.quaternion_init,
170
+ 'unitary': hp_ops.unitary_init,
171
+ 'random': hp_ops.random_init}[self.weight_init]
172
+ self.scale = scale
173
+ self.learn_A = learn_A
174
+ self.cuda = cuda
175
+ self.first_layer = first_layer
176
+
177
+ (self.kernel_size, self.w_shape) = hp_ops.get_kernel_and_weight_shape(self.operation,
178
+ self.in_channels, self.out_channels, kernel_size)
179
+
180
+ self.r_weight = Parameter(torch.Tensor(*self.w_shape))
181
+ self.i_weight = Parameter(torch.Tensor(*self.w_shape))
182
+ self.j_weight = Parameter(torch.Tensor(*self.w_shape))
183
+ self.k_weight = Parameter(torch.Tensor(*self.w_shape))
184
+
185
+ if self.scale:
186
+ self.scale_param = Parameter(torch.Tensor(self.r_weight.shape))
187
+ else:
188
+ self.scale_param = None
189
+
190
+ if self.rotation:
191
+ self.zero_kernel = Parameter(torch.zeros(
192
+ self.r_weight.shape), requires_grad=False)
193
+ if bias:
194
+ self.bias = Parameter(torch.Tensor(out_channels))
195
+ else:
196
+ self.register_parameter('bias', None)
197
+ self.reset_parameters()
198
+
199
+ def reset_parameters(self):
200
+ hp_ops.affect_init_conv(self.r_weight, self.i_weight, self.j_weight, self.k_weight,
201
+ self.kernel_size, self.winit, self.rng, self.init_criterion)
202
+ if self.scale_param is not None:
203
+ torch.nn.init.xavier_uniform_(self.scale_param.data)
204
+ if self.bias is not None:
205
+ self.bias.data.zero_()
206
+
207
+ def forward(self, input):
208
+ if self.rotation:
209
+ # return quaternion_conv_rotation(input, self.zero_kernel, self.r_weight, self.i_weight, self.j_weight,
210
+ # self.k_weight, self.bias, self.stride, self.padding, self.groups, self.dilatation,
211
+ # self.quaternion_format, self.scale_param)
212
+ pass
213
+ else:
214
+ return hp_ops.kronecker_conv(input, self.r_weight, self.i_weight, self.j_weight,
215
+ self.k_weight, self.bias, self.stride, self.padding, self.groups, self.dilatation, self.learn_A, self.cuda, self.first_layer)
216
+
217
+ def __repr__(self):
218
+ return self.__class__.__name__ + '(' \
219
+ + 'in_channels=' + str(self.in_channels) \
220
+ + ', out_channels=' + str(self.out_channels) \
221
+ + ', bias=' + str(self.bias is not None) \
222
+ + ', kernel_size=' + str(self.kernel_size) \
223
+ + ', stride=' + str(self.stride) \
224
+ + ', padding=' + str(self.padding) \
225
+ + ', init_criterion=' + str(self.init_criterion) \
226
+ + ', weight_init=' + str(self.weight_init) \
227
+ + ', seed=' + str(self.seed) \
228
+ + ', rotation=' + str(self.rotation) \
229
+ + ', q_format=' + str(self.quaternion_format) \
230
+ + ', operation=' + str(self.operation) + ')'
231
+
232
+
233
+ class QuaternionTransposeConv(Module):
234
+ r"""Applies a Quaternion Transposed Convolution (or Deconvolution) to the incoming data."""
235
+
236
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
237
+ dilatation=1, padding=0, output_padding=0, groups=1, bias=True, init_criterion='he',
238
+ weight_init='quaternion', seed=None, operation='convolution2d', rotation=False,
239
+ quaternion_format=False):
240
+
241
+ super().__init__()
242
+
243
+ self.in_channels = in_channels // 4
244
+ self.out_channels = out_channels // 4
245
+ self.stride = stride
246
+ self.padding = padding
247
+ self.output_padding = output_padding
248
+ self.groups = groups
249
+ self.dilatation = dilatation
250
+ self.init_criterion = init_criterion
251
+ self.weight_init = weight_init
252
+ self.seed = seed if seed is not None else np.random.randint(0, 1234)
253
+ self.rng = RandomState(self.seed)
254
+ self.operation = operation
255
+ self.rotation = rotation
256
+ self.quaternion_format = quaternion_format
257
+ self.winit = {'quaternion': hp_ops.quaternion_init,
258
+ 'unitary': hp_ops.unitary_init,
259
+ 'random': hp_ops.random_init}[self.weight_init]
260
+
261
+ (self.kernel_size, self.w_shape) = hp_ops.get_kernel_and_weight_shape(self.operation,
262
+ self.out_channels, self.in_channels, kernel_size)
263
+
264
+ self.r_weight = Parameter(torch.Tensor(*self.w_shape))
265
+ self.i_weight = Parameter(torch.Tensor(*self.w_shape))
266
+ self.j_weight = Parameter(torch.Tensor(*self.w_shape))
267
+ self.k_weight = Parameter(torch.Tensor(*self.w_shape))
268
+
269
+ if bias:
270
+ self.bias = Parameter(torch.Tensor(out_channels))
271
+ else:
272
+ self.register_parameter('bias', None)
273
+ self.reset_parameters()
274
+
275
+ def reset_parameters(self):
276
+ hp_ops.affect_init_conv(self.r_weight, self.i_weight, self.j_weight, self.k_weight,
277
+ self.kernel_size, self.winit, self.rng, self.init_criterion)
278
+ if self.bias is not None:
279
+ self.bias.data.zero_()
280
+
281
+ def forward(self, input):
282
+ if self.rotation:
283
+ return hp_ops.quaternion_transpose_conv_rotation(input, self.r_weight, self.i_weight,
284
+ self.j_weight, self.k_weight, self.bias, self.stride, self.padding,
285
+ self.output_padding, self.groups, self.dilatation, self.quaternion_format)
286
+ else:
287
+ return hp_ops.quaternion_transpose_conv(input, self.r_weight, self.i_weight, self.j_weight,
288
+ self.k_weight, self.bias, self.stride, self.padding, self.output_padding,
289
+ self.groups, self.dilatation)
290
+
291
+ def __repr__(self):
292
+ return self.__class__.__name__ + '(' \
293
+ + 'in_channels=' + str(self.in_channels) \
294
+ + ', out_channels=' + str(self.out_channels) \
295
+ + ', bias=' + str(self.bias is not None) \
296
+ + ', kernel_size=' + str(self.kernel_size) \
297
+ + ', stride=' + str(self.stride) \
298
+ + ', padding=' + str(self.padding) \
299
+ + ', dilation=' + str(self.dilatation) \
300
+ + ', init_criterion=' + str(self.init_criterion) \
301
+ + ', weight_init=' + str(self.weight_init) \
302
+ + ', seed=' + str(self.seed) \
303
+ + ', operation=' + str(self.operation) + ')'
304
+
305
+
306
+ class QuaternionConv(Module):
307
+ r"""Applies a Quaternion Convolution to the incoming data."""
308
+
309
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
310
+ dilatation=1, padding=0, groups=1, bias=True, init_criterion='glorot',
311
+ weight_init='quaternion', seed=None, operation='convolution2d', rotation=False, quaternion_format=True, scale=False):
312
+
313
+ super().__init__()
314
+
315
+ self.in_channels = in_channels // 4
316
+ self.out_channels = out_channels // 4
317
+ self.stride = stride
318
+ self.padding = padding
319
+ self.groups = groups
320
+ self.dilatation = dilatation
321
+ self.init_criterion = init_criterion
322
+ self.weight_init = weight_init
323
+ self.seed = seed if seed is not None else np.random.randint(0, 1234)
324
+ self.rng = RandomState(self.seed)
325
+ self.operation = operation
326
+ self.rotation = rotation
327
+ self.quaternion_format = quaternion_format
328
+ self.winit = {'quaternion': hp_ops.quaternion_init,
329
+ 'unitary': hp_ops.unitary_init,
330
+ 'random': hp_ops.random_init}[self.weight_init]
331
+ self.scale = scale
332
+
333
+ (self.kernel_size, self.w_shape) = hp_ops.get_kernel_and_weight_shape(self.operation,
334
+ self.in_channels, self.out_channels, kernel_size)
335
+
336
+ self.r_weight = Parameter(torch.Tensor(*self.w_shape))
337
+ self.i_weight = Parameter(torch.Tensor(*self.w_shape))
338
+ self.j_weight = Parameter(torch.Tensor(*self.w_shape))
339
+ self.k_weight = Parameter(torch.Tensor(*self.w_shape))
340
+
341
+ if self.scale:
342
+ self.scale_param = Parameter(torch.Tensor(self.r_weight.shape))
343
+ else:
344
+ self.scale_param = None
345
+
346
+ if self.rotation:
347
+ self.zero_kernel = Parameter(torch.zeros(
348
+ self.r_weight.shape), requires_grad=False)
349
+ if bias:
350
+ self.bias = Parameter(torch.Tensor(out_channels))
351
+ else:
352
+ self.register_parameter('bias', None)
353
+ self.reset_parameters()
354
+
355
+ def reset_parameters(self):
356
+ hp_ops.affect_init_conv(self.r_weight, self.i_weight, self.j_weight, self.k_weight,
357
+ self.kernel_size, self.winit, self.rng, self.init_criterion)
358
+ if self.scale_param is not None:
359
+ torch.nn.init.xavier_uniform_(self.scale_param.data)
360
+ if self.bias is not None:
361
+ self.bias.data.zero_()
362
+
363
+ def forward(self, input):
364
+ if self.rotation:
365
+ return hp_ops.quaternion_conv_rotation(input, self.zero_kernel, self.r_weight, self.i_weight, self.j_weight,
366
+ self.k_weight, self.bias, self.stride, self.padding, self.groups, self.dilatation,
367
+ self.quaternion_format, self.scale_param)
368
+ else:
369
+ return hp_ops.quaternion_conv(input, self.r_weight, self.i_weight, self.j_weight,
370
+ self.k_weight, self.bias, self.stride, self.padding, self.groups, self.dilatation)
371
+
372
+ def __repr__(self):
373
+ return self.__class__.__name__ + '(' \
374
+ + 'in_channels=' + str(self.in_channels) \
375
+ + ', out_channels=' + str(self.out_channels) \
376
+ + ', bias=' + str(self.bias is not None) \
377
+ + ', kernel_size=' + str(self.kernel_size) \
378
+ + ', stride=' + str(self.stride) \
379
+ + ', padding=' + str(self.padding) \
380
+ + ', init_criterion=' + str(self.init_criterion) \
381
+ + ', weight_init=' + str(self.weight_init) \
382
+ + ', seed=' + str(self.seed) \
383
+ + ', rotation=' + str(self.rotation) \
384
+ + ', q_format=' + str(self.quaternion_format) \
385
+ + ', operation=' + str(self.operation) + ')'
386
+
387
+
388
+ class QuaternionLinearAutograd(Module):
389
+ r"""Applies a quaternion linear transformation to the incoming data.
390
+
391
+ A custom Autograd function is called to drastically reduce the VRAM consumption. Nonetheless, computing time
392
+ is also slower compared to QuaternionLinear().
393
+ """
394
+
395
+ def __init__(self, in_features, out_features, bias=True,
396
+ init_criterion='glorot', weight_init='quaternion',
397
+ seed=None, rotation=False, quaternion_format=True, scale=False):
398
+
399
+ super().__init__()
400
+ self.in_features = in_features//4
401
+ self.out_features = out_features//4
402
+ self.rotation = rotation
403
+ self.quaternion_format = quaternion_format
404
+ self.r_weight = Parameter(torch.Tensor(
405
+ self.in_features, self.out_features))
406
+ self.i_weight = Parameter(torch.Tensor(
407
+ self.in_features, self.out_features))
408
+ self.j_weight = Parameter(torch.Tensor(
409
+ self.in_features, self.out_features))
410
+ self.k_weight = Parameter(torch.Tensor(
411
+ self.in_features, self.out_features))
412
+ self.scale = scale
413
+
414
+ if self.scale:
415
+ self.scale_param = Parameter(torch.Tensor(
416
+ self.in_features, self.out_features))
417
+ else:
418
+ self.scale_param = None
419
+
420
+ if self.rotation:
421
+ self.zero_kernel = Parameter(torch.zeros(
422
+ self.r_weight.shape), requires_grad=False)
423
+
424
+ if bias:
425
+ self.bias = Parameter(torch.Tensor(self.out_features*4))
426
+ else:
427
+ self.register_parameter('bias', None)
428
+ self.init_criterion = init_criterion
429
+ self.weight_init = weight_init
430
+ self.seed = seed if seed is not None else np.random.randint(0, 1234)
431
+ self.rng = RandomState(self.seed)
432
+ self.reset_parameters()
433
+
434
+ def reset_parameters(self):
435
+ winit = {'quaternion': hp_ops.quaternion_init, 'unitary': hp_ops.unitary_init,
436
+ 'random': hp_ops.random_init}[self.weight_init]
437
+ if self.scale_param is not None:
438
+ torch.nn.init.xavier_uniform_(self.scale_param.data)
439
+ if self.bias is not None:
440
+ self.bias.data.fill_(0)
441
+ hp_ops.affect_init(self.r_weight, self.i_weight, self.j_weight, self.k_weight, winit,
442
+ self.rng, self.init_criterion)
443
+
444
+ def forward(self, input):
445
+ # See the autograd section for explanation of what happens here.
446
+ if self.rotation:
447
+ return hp_ops.quaternion_linear_rotation(input, self.zero_kernel, self.r_weight, self.i_weight, self.j_weight, self.k_weight, self.bias, self.quaternion_format, self.scale_param)
448
+ else:
449
+ return hp_ops.quaternion_linear(input, self.r_weight, self.i_weight, self.j_weight, self.k_weight, self.bias)
450
+
451
+ def __repr__(self):
452
+ return self.__class__.__name__ + '(' \
453
+ + 'in_features=' + str(self.in_features) \
454
+ + ', out_features=' + str(self.out_features) \
455
+ + ', bias=' + str(self.bias is not None) \
456
+ + ', init_criterion=' + str(self.init_criterion) \
457
+ + ', weight_init=' + str(self.weight_init) \
458
+ + ', rotation=' + str(self.rotation) \
459
+ + ', seed=' + str(self.seed) + ')'
460
+
461
+
462
+ class QuaternionLinear(Module):
463
+ r"""Applies a quaternion linear transformation to the incoming data."""
464
+
465
+ def __init__(self, in_features, out_features, bias=True,
466
+ init_criterion='he', weight_init='quaternion',
467
+ seed=None):
468
+
469
+ super().__init__()
470
+ self.in_features = in_features//4
471
+ self.out_features = out_features//4
472
+ self.r_weight = Parameter(torch.Tensor(
473
+ self.in_features, self.out_features))
474
+ self.i_weight = Parameter(torch.Tensor(
475
+ self.in_features, self.out_features))
476
+ self.j_weight = Parameter(torch.Tensor(
477
+ self.in_features, self.out_features))
478
+ self.k_weight = Parameter(torch.Tensor(
479
+ self.in_features, self.out_features))
480
+
481
+ if bias:
482
+ self.bias = Parameter(torch.Tensor(self.out_features*4))
483
+ else:
484
+ self.register_parameter('bias', None)
485
+
486
+ self.init_criterion = init_criterion
487
+ self.weight_init = weight_init
488
+ self.seed = seed if seed is not None else np.random.randint(0, 1234)
489
+ self.rng = RandomState(self.seed)
490
+ self.reset_parameters()
491
+
492
+ def reset_parameters(self):
493
+ winit = {'quaternion': hp_ops.quaternion_init,
494
+ 'unitary': hp_ops.unitary_init}[self.weight_init]
495
+ if self.bias is not None:
496
+ self.bias.data.fill_(0)
497
+ hp_ops.affect_init(self.r_weight, self.i_weight, self.j_weight, self.k_weight, winit,
498
+ self.rng, self.init_criterion)
499
+
500
+ def forward(self, input):
501
+ # See the autograd section for explanation of what happens here.
502
+ if input.dim() == 3:
503
+ T, N, C = input.size()
504
+ input = input.view(T * N, C)
505
+ output = hp_ops.QuaternionLinearFunction.apply(
506
+ input, self.r_weight, self.i_weight, self.j_weight, self.k_weight, self.bias)
507
+ output = output.view(T, N, output.size(1))
508
+ elif input.dim() == 2:
509
+ output = hp_ops.QuaternionLinearFunction.apply(
510
+ input, self.r_weight, self.i_weight, self.j_weight, self.k_weight, self.bias)
511
+ else:
512
+ raise NotImplementedError
513
+
514
+ return output
515
+
516
+ def __repr__(self):
517
+ return self.__class__.__name__ + '(' \
518
+ + 'in_features=' + str(self.in_features) \
519
+ + ', out_features=' + str(self.out_features) \
520
+ + ', bias=' + str(self.bias is not None) \
521
+ + ', init_criterion=' + str(self.init_criterion) \
522
+ + ', weight_init=' + str(self.weight_init) \
523
+ + ', seed=' + str(self.seed) + ')'
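A minimal usage sketch of the PHM layers defined above (illustrative only, not part of the commit; it assumes the repository root is on PYTHONPATH so that `models` is importable, and it runs on CPU):

import torch
from models.hypercomplex_layers import PHMLinear, PHConv

# PHMLinear assembles its (out_features x in_features) weight as a sum of n
# Kronecker products A_i ⊗ S_i, so both feature sizes must be divisible by n.
x = torch.randn(8, 64)
fc = PHMLinear(n=4, in_features=64, out_features=128)
print(fc(x).shape)  # torch.Size([8, 128])

# PHConv does the same for a convolution kernel of shape (out, in, k, k).
img = torch.randn(2, 16, 32, 32)
conv = PHConv(n=4, in_features=16, out_features=32, kernel_size=3, padding=1, cuda=False)
print(conv(img).shape)  # torch.Size([2, 32, 32, 32])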
models/hypercomplex_ops.py ADDED
@@ -0,0 +1,905 @@
1
+ ##########################################################
2
+ # pytorch-qnn v1.0
3
+ # Titouan Parcollet
4
+ # LIA, Université d'Avignon et des Pays du Vaucluse
5
+ # ORKIS, Aix-en-provence
6
+ # October 2018
7
+ ##########################################################
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from numpy.random import RandomState
13
+ from scipy.stats import chi
14
+ from torch.autograd import Variable
15
+
16
+
17
+ def q_normalize(input, channel=1):
18
+ r = get_r(input)
19
+ i = get_i(input)
20
+ j = get_j(input)
21
+ k = get_k(input)
22
+
23
+ norm = torch.sqrt(r*r + i*i + j*j + k*k + 0.0001)
24
+ r = r / norm
25
+ i = i / norm
26
+ j = j / norm
27
+ k = k / norm
28
+
29
+ return torch.cat([r, i, j, k], dim=channel)
30
+
31
+
32
+ def check_input(input):
33
+ if input.dim() not in {2, 3, 4, 5}:
34
+ raise RuntimeError(
35
+ 'Quaternion linear accepts only inputs of dimension 2 or 3. Quaternion conv accepts up to 5 dimensions.'
36
+ ' input.dim = ' + str(input.dim())
37
+ )
38
+
39
+ if input.dim() < 4:
40
+ nb_hidden = input.size()[-1]
41
+ else:
42
+ nb_hidden = input.size()[1]
43
+
44
+ if nb_hidden % 4 != 0:
45
+ raise RuntimeError(
46
+ 'Quaternion Tensors must be divisible by 4.'
47
+ ' input.size()[1] = ' + str(nb_hidden)
48
+ )
49
+ #
50
+ # Getters
51
+ #
52
+
53
+
54
+ def get_r(input):
55
+ check_input(input)
56
+ if input.dim() < 4:
57
+ nb_hidden = input.size()[-1]
58
+ else:
59
+ nb_hidden = input.size()[1]
60
+
61
+ if input.dim() == 2:
62
+ return input.narrow(1, 0, nb_hidden // 4)
63
+ if input.dim() == 3:
64
+ return input.narrow(2, 0, nb_hidden // 4)
65
+ if input.dim() >= 4:
66
+ return input.narrow(1, 0, nb_hidden // 4)
67
+
68
+
69
+ def get_i(input):
70
+ if input.dim() < 4:
71
+ nb_hidden = input.size()[-1]
72
+ else:
73
+ nb_hidden = input.size()[1]
74
+ if input.dim() == 2:
75
+ return input.narrow(1, nb_hidden // 4, nb_hidden // 4)
76
+ if input.dim() == 3:
77
+ return input.narrow(2, nb_hidden // 4, nb_hidden // 4)
78
+ if input.dim() >= 4:
79
+ return input.narrow(1, nb_hidden // 4, nb_hidden // 4)
80
+
81
+
82
+ def get_j(input):
83
+ check_input(input)
84
+ if input.dim() < 4:
85
+ nb_hidden = input.size()[-1]
86
+ else:
87
+ nb_hidden = input.size()[1]
88
+ if input.dim() == 2:
89
+ return input.narrow(1, nb_hidden // 2, nb_hidden // 4)
90
+ if input.dim() == 3:
91
+ return input.narrow(2, nb_hidden // 2, nb_hidden // 4)
92
+ if input.dim() >= 4:
93
+ return input.narrow(1, nb_hidden // 2, nb_hidden // 4)
94
+
95
+
96
+ def get_k(input):
97
+ check_input(input)
98
+ if input.dim() < 4:
99
+ nb_hidden = input.size()[-1]
100
+ else:
101
+ nb_hidden = input.size()[1]
102
+ if input.dim() == 2:
103
+ return input.narrow(1, nb_hidden - nb_hidden // 4, nb_hidden // 4)
104
+ if input.dim() == 3:
105
+ return input.narrow(2, nb_hidden - nb_hidden // 4, nb_hidden // 4)
106
+ if input.dim() >= 4:
107
+ return input.narrow(1, nb_hidden - nb_hidden // 4, nb_hidden // 4)
108
+
109
+
110
+ def get_modulus(input, vector_form=False):
111
+ check_input(input)
112
+ r = get_r(input)
113
+ i = get_i(input)
114
+ j = get_j(input)
115
+ k = get_k(input)
116
+ if vector_form:
117
+ return torch.sqrt(r * r + i * i + j * j + k * k)
118
+ else:
119
+ return torch.sqrt((r * r + i * i + j * j + k * k).sum(dim=0))
120
+
121
+
122
+ def get_normalized(input, eps=0.0001):
123
+ check_input(input)
124
+ data_modulus = get_modulus(input)
125
+ if input.dim() == 2:
126
+ data_modulus_repeated = data_modulus.repeat(1, 4)
127
+ elif input.dim() == 3:
128
+ data_modulus_repeated = data_modulus.repeat(1, 1, 4)
129
+ return input / (data_modulus_repeated.expand_as(input) + eps)
130
+
131
+
132
+ def quaternion_exp(input):
133
+ r = get_r(input)
134
+ i = get_i(input)
135
+ j = get_j(input)
136
+ k = get_k(input)
137
+
138
+ norm_v = torch.sqrt(i*i+j*j+k*k) + 0.0001
139
+ exp = torch.exp(r)
140
+
141
+ r = torch.cos(norm_v)
142
+ i = (i / norm_v) * torch.sin(norm_v)
143
+ j = (j / norm_v) * torch.sin(norm_v)
144
+ k = (k / norm_v) * torch.sin(norm_v)
145
+
146
+ return torch.cat([exp*r, exp*i, exp*j, exp*k], dim=1)
147
+
148
+
149
+ def kronecker_conv(input, r_weight, i_weight, j_weight, k_weight, bias, stride,
150
+ padding, groups, dilatation, learn_A, cuda, first_layer=False): # ,
151
+ # mat1_learn, mat2_learn, mat3_learn, mat4_learn):
152
+ """Applies a quaternion convolution to the incoming data:"""
153
+ # Define the initial matrices to build the Hamilton product
154
+ if first_layer:
155
+ mat1 = torch.zeros((4, 4), requires_grad=False).view(4, 4, 1, 1)
156
+ else:
157
+ mat1 = torch.eye(4, requires_grad=False).view(4, 4, 1, 1)
158
+
159
+ # Define the four matrices that summed up build the Hamilton product rule.
160
+ mat2 = torch.tensor([[0, -1, 0, 0],
161
+ [1, 0, 0, 0],
162
+ [0, 0, 0, -1],
163
+ [0, 0, 1, 0]], requires_grad=False).view(4, 4, 1, 1)
164
+ mat3 = torch.tensor([[0, 0, -1, 0],
165
+ [0, 0, 0, 1],
166
+ [1, 0, 0, 0],
167
+ [0, -1, 0, 0]], requires_grad=False).view(4, 4, 1, 1)
168
+ mat4 = torch.tensor([[0, 0, 0, -1],
169
+ [0, 0, -1, 0],
170
+ [0, 1, 0, 0],
171
+ [1, 0, 0, 0]], requires_grad=False).view(4, 4, 1, 1)
172
+
173
+ if cuda:
174
+ mat1, mat2, mat3, mat4 = mat1.cuda(), mat2.cuda(), mat3.cuda(), mat4.cuda()
175
+
176
+ # Sum of kronecker product between the four matrices and the learnable weights.
177
+ cat_kernels_4_quaternion = torch.kron(mat1, r_weight) + \
178
+ torch.kron(mat2, i_weight) + \
179
+ torch.kron(mat3, j_weight) + \
180
+ torch.kron(mat4, k_weight)
181
+
182
+ if input.dim() == 3:
183
+ convfunc = F.conv1d
184
+ elif input.dim() == 4:
185
+ convfunc = F.conv2d
186
+ elif input.dim() == 5:
187
+ convfunc = F.conv3d
188
+ else:
189
+ raise Exception('The convolutional input is either 3, 4 or 5 dimensions.'
190
+ ' input.dim = ' + str(input.dim()))
191
+
192
+ return convfunc(input, cat_kernels_4_quaternion, bias, stride, padding, dilatation, groups)
193
+
194
+
195
+ def quaternion_conv(input, r_weight, i_weight, j_weight, k_weight, bias, stride,
196
+ padding, groups, dilatation):
197
+ """Applies a quaternion convolution to the incoming data:"""
198
+
199
+ cat_kernels_4_r = torch.cat(
200
+ [r_weight, -i_weight, -j_weight, -k_weight], dim=1)
201
+ cat_kernels_4_i = torch.cat(
202
+ [i_weight, r_weight, -k_weight, j_weight], dim=1)
203
+ cat_kernels_4_j = torch.cat(
204
+ [j_weight, k_weight, r_weight, -i_weight], dim=1)
205
+ cat_kernels_4_k = torch.cat(
206
+ [k_weight, -j_weight, i_weight, r_weight], dim=1)
207
+
208
+ cat_kernels_4_quaternion = torch.cat(
209
+ [cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k], dim=0)
210
+
211
+ if input.dim() == 3:
212
+ convfunc = F.conv1d
213
+ elif input.dim() == 4:
214
+ convfunc = F.conv2d
215
+ elif input.dim() == 5:
216
+ convfunc = F.conv3d
217
+ else:
218
+ raise Exception('The convolutional input is either 3, 4 or 5 dimensions.'
219
+ ' input.dim = ' + str(input.dim()))
220
+
221
+ return convfunc(input, cat_kernels_4_quaternion, bias, stride, padding, dilatation, groups)
222
+
223
+
224
+ def quaternion_transpose_conv(input, r_weight, i_weight, j_weight, k_weight, bias, stride,
225
+ padding, output_padding, groups, dilatation):
226
+ """Applies a quaternion transposed convolution to the incoming data:"""
227
+
228
+ cat_kernels_4_r = torch.cat(
229
+ [r_weight, -i_weight, -j_weight, -k_weight], dim=1)
230
+ cat_kernels_4_i = torch.cat(
231
+ [i_weight, r_weight, -k_weight, j_weight], dim=1)
232
+ cat_kernels_4_j = torch.cat(
233
+ [j_weight, k_weight, r_weight, -i_weight], dim=1)
234
+ cat_kernels_4_k = torch.cat(
235
+ [k_weight, -j_weight, i_weight, r_weight], dim=1)
236
+ cat_kernels_4_quaternion = torch.cat(
237
+ [cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k], dim=0)
238
+
239
+ if input.dim() == 3:
240
+ convfunc = F.conv_transpose1d
241
+ elif input.dim() == 4:
242
+ convfunc = F.conv_transpose2d
243
+ elif input.dim() == 5:
244
+ convfunc = F.conv_transpose3d
245
+ else:
246
+ raise Exception('The convolutional input is either 3, 4 or 5 dimensions.'
247
+ ' input.dim = ' + str(input.dim()))
248
+
249
+ return convfunc(input, cat_kernels_4_quaternion,
250
+ bias, stride, padding, output_padding, groups, dilatation)
251
+
252
+
253
+ def quaternion_conv_rotation(input, zero_kernel, r_weight, i_weight, j_weight, k_weight, bias, stride,
254
+ padding, groups, dilatation, quaternion_format, scale=None):
255
+ """Applies a quaternion rotation and convolution transformation to the incoming data:
256
+
257
+ The rotation W*x*W^t can be replaced by R*x following:
258
+ https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
259
+
260
+ Works for unitary and non unitary weights.
261
+
262
+ The initial size of the input must be a multiple of 3 if quaternion_format = False and
263
+ 4 if quaternion_format = True.
264
+ """
265
+
266
+ square_r = (r_weight*r_weight)
267
+ square_i = (i_weight*i_weight)
268
+ square_j = (j_weight*j_weight)
269
+ square_k = (k_weight*k_weight)
270
+
271
+ norm = torch.sqrt(square_r+square_i+square_j+square_k + 0.0001)
272
+
273
+ # print(norm)
274
+
275
+ r_n_weight = (r_weight / norm)
276
+ i_n_weight = (i_weight / norm)
277
+ j_n_weight = (j_weight / norm)
278
+ k_n_weight = (k_weight / norm)
279
+
280
+ norm_factor = 2.0
281
+
282
+ square_i = norm_factor*(i_n_weight*i_n_weight)
283
+ square_j = norm_factor*(j_n_weight*j_n_weight)
284
+ square_k = norm_factor*(k_n_weight*k_n_weight)
285
+
286
+ ri = (norm_factor*r_n_weight*i_n_weight)
287
+ rj = (norm_factor*r_n_weight*j_n_weight)
288
+ rk = (norm_factor*r_n_weight*k_n_weight)
289
+
290
+ ij = (norm_factor*i_n_weight*j_n_weight)
291
+ ik = (norm_factor*i_n_weight*k_n_weight)
292
+
293
+ jk = (norm_factor*j_n_weight*k_n_weight)
294
+
295
+ if quaternion_format:
296
+ if scale is not None:
297
+ rot_kernel_1 = torch.cat([zero_kernel, scale * (1.0 - (square_j + square_k)),
298
+ scale * (ij-rk), scale * (ik+rj)], dim=1)
299
+ rot_kernel_2 = torch.cat([zero_kernel, scale * (ij+rk), scale *
300
+ (1.0 - (square_i + square_k)), scale * (jk-ri)], dim=1)
301
+ rot_kernel_3 = torch.cat([zero_kernel, scale * (ik-rj), scale * (jk+ri),
302
+ scale * (1.0 - (square_i + square_j))], dim=1)
303
+ else:
304
+ rot_kernel_1 = torch.cat(
305
+ [zero_kernel, (1.0 - (square_j + square_k)), (ij-rk), (ik+rj)], dim=1)
306
+ rot_kernel_2 = torch.cat(
307
+ [zero_kernel, (ij+rk), (1.0 - (square_i + square_k)), (jk-ri)], dim=1)
308
+ rot_kernel_3 = torch.cat(
309
+ [zero_kernel, (ik-rj), (jk+ri), (1.0 - (square_i + square_j))], dim=1)
310
+
311
+ zero_kernel2 = torch.cat(
312
+ [zero_kernel, zero_kernel, zero_kernel, zero_kernel], dim=1)
313
+ global_rot_kernel = torch.cat(
314
+ [zero_kernel2, rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=0)
315
+
316
+ else:
317
+ if scale is not None:
318
+ rot_kernel_1 = torch.cat([scale * (1.0 - (square_j + square_k)),
319
+ scale * (ij-rk), scale * (ik+rj)], dim=0)
320
+ rot_kernel_2 = torch.cat(
321
+ [scale * (ij+rk), scale * (1.0 - (square_i + square_k)), scale * (jk-ri)], dim=0)
322
+ rot_kernel_3 = torch.cat([scale * (ik-rj), scale * (jk+ri), scale *
323
+ (1.0 - (square_i + square_j))], dim=0)
324
+ else:
325
+ rot_kernel_1 = torch.cat(
326
+ [1.0 - (square_j + square_k), (ij-rk), (ik+rj)], dim=0)
327
+ rot_kernel_2 = torch.cat(
328
+ [(ij+rk), 1.0 - (square_i + square_k), (jk-ri)], dim=0)
329
+ rot_kernel_3 = torch.cat(
330
+ [(ik-rj), (jk+ri), (1.0 - (square_i + square_j))], dim=0)
331
+
332
+ global_rot_kernel = torch.cat(
333
+ [rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=0)
334
+
335
+ # print(input.shape)
336
+ # print(square_r.shape)
337
+ # print(global_rot_kernel.shape)
338
+
339
+ if input.dim() == 3:
340
+ convfunc = F.conv1d
341
+ elif input.dim() == 4:
342
+ convfunc = F.conv2d
343
+ elif input.dim() == 5:
344
+ convfunc = F.conv3d
345
+ else:
346
+ raise Exception('The convolutional input is either 3, 4 or 5 dimensions.'
347
+ ' input.dim = ' + str(input.dim()))
348
+
349
+ return convfunc(input, global_rot_kernel, bias, stride, padding, dilatation, groups)
350
+
351
+
352
+ def quaternion_transpose_conv_rotation(
353
+ input, zero_kernel, r_weight, i_weight, j_weight, k_weight, bias, stride,
354
+ padding, output_padding, groups, dilatation, quaternion_format):
355
+ """Applies a quaternion rotation and transposed convolution transformation to the incoming data:
356
+
357
+ The rotation W*x*W^t can be replaced by R*x following:
358
+ https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
359
+
360
+ Works for unitary and non unitary weights.
361
+
362
+ The initial size of the input must be a multiple of 3 if quaternion_format = False and
363
+ 4 if quaternion_format = True.
364
+ """
365
+
366
+ square_r = (r_weight*r_weight)
367
+ square_i = (i_weight*i_weight)
368
+ square_j = (j_weight*j_weight)
369
+ square_k = (k_weight*k_weight)
370
+
371
+ norm = torch.sqrt(square_r+square_i+square_j+square_k + 0.0001)
372
+
373
+ r_weight = (r_weight / norm)
374
+ i_weight = (i_weight / norm)
375
+ j_weight = (j_weight / norm)
376
+ k_weight = (k_weight / norm)
377
+
378
+ norm_factor = 2.0
379
+
380
+ square_i = norm_factor*(i_weight*i_weight)
381
+ square_j = norm_factor*(j_weight*j_weight)
382
+ square_k = norm_factor*(k_weight*k_weight)
383
+
384
+ ri = (norm_factor*r_weight*i_weight)
385
+ rj = (norm_factor*r_weight*j_weight)
386
+ rk = (norm_factor*r_weight*k_weight)
387
+
388
+ ij = (norm_factor*i_weight*j_weight)
389
+ ik = (norm_factor*i_weight*k_weight)
390
+
391
+ jk = (norm_factor*j_weight*k_weight)
392
+
393
+ if quaternion_format:
394
+ rot_kernel_1 = torch.cat(
395
+ [zero_kernel, 1.0 - (square_j + square_k), ij-rk, ik+rj], dim=1)
396
+ rot_kernel_2 = torch.cat(
397
+ [zero_kernel, ij+rk, 1.0 - (square_i + square_k), jk-ri], dim=1)
398
+ rot_kernel_3 = torch.cat(
399
+ [zero_kernel, ik-rj, jk+ri, 1.0 - (square_i + square_j)], dim=1)
400
+
401
+ zero_kernel2 = torch.zeros(rot_kernel_1.shape).cuda()
402
+ global_rot_kernel = torch.cat(
403
+ [zero_kernel2, rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=0)
404
+ else:
405
+ rot_kernel_1 = torch.cat(
406
+ [1.0 - (square_j + square_k), ij-rk, ik+rj], dim=1)
407
+ rot_kernel_2 = torch.cat(
408
+ [ij+rk, 1.0 - (square_i + square_k), jk-ri], dim=1)
409
+ rot_kernel_3 = torch.cat(
410
+ [ik-rj, jk+ri, 1.0 - (square_i + square_j)], dim=1)
411
+ global_rot_kernel = torch.cat(
412
+ [rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=0)
413
+
414
+ if input.dim() == 3:
415
+ convfunc = F.conv_transpose1d
416
+ elif input.dim() == 4:
417
+ convfunc = F.conv_transpose2d
418
+ elif input.dim() == 5:
419
+ convfunc = F.conv_transpose3d
420
+ else:
421
+ raise Exception('The convolutional input is either 3, 4 or 5 dimensions.'
422
+ ' input.dim = ' + str(input.dim()))
423
+
424
+ return convfunc(input, global_rot_kernel, bias, stride, padding, output_padding, groups, dilatation)
425
+
426
+
427
+ def quaternion_linear(input, r_weight, i_weight, j_weight, k_weight, bias=True):
428
+ """Applies a quaternion linear transformation to the incoming data:
429
+
430
+ It is important to notice that the forward phase of a QNN is defined
431
+ as W * Inputs (with * equal to the Hamilton product). The constructed
432
+ cat_kernels_4_quaternion is a modified version of the quaternion representation
433
+ so when we do torch.mm(Input,W) it's equivalent to W * Inputs.
434
+ """
435
+
436
+ cat_kernels_4_r = torch.cat(
437
+ [r_weight, -i_weight, -j_weight, -k_weight], dim=0)
438
+ cat_kernels_4_i = torch.cat(
439
+ [i_weight, r_weight, -k_weight, j_weight], dim=0)
440
+ cat_kernels_4_j = torch.cat(
441
+ [j_weight, k_weight, r_weight, -i_weight], dim=0)
442
+ cat_kernels_4_k = torch.cat(
443
+ [k_weight, -j_weight, i_weight, r_weight], dim=0)
444
+ cat_kernels_4_quaternion = torch.cat(
445
+ [cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k], dim=1)
446
+
447
+ if input.dim() == 2:
448
+
449
+ if bias is not None:
450
+ return torch.addmm(bias, input, cat_kernels_4_quaternion)
451
+ else:
452
+ return torch.mm(input, cat_kernels_4_quaternion)
453
+ else:
454
+ output = torch.matmul(input, cat_kernels_4_quaternion)
455
+ if bias is not None:
456
+ return output+bias
457
+ else:
458
+ return output
459
+
460
+
461
+ def quaternion_linear_rotation(input, zero_kernel, r_weight, i_weight, j_weight, k_weight, bias=None,
462
+ quaternion_format=False, scale=None):
463
+ """Applies a quaternion rotation transformation to the incoming data:
464
+
465
+ The rotation W*x*W^t can be replaced by R*x following:
466
+ https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation
467
+
468
+ Works for unitary and non unitary weights.
469
+
470
+ The initial size of the input must be a multiple of 3 if quaternion_format = False and
471
+ 4 if quaternion_format = True.
472
+ """
473
+
474
+ square_r = (r_weight*r_weight)
475
+ square_i = (i_weight*i_weight)
476
+ square_j = (j_weight*j_weight)
477
+ square_k = (k_weight*k_weight)
478
+
479
+ norm = torch.sqrt(square_r+square_i+square_j+square_k + 0.0001)
480
+
481
+ r_n_weight = (r_weight / norm)
482
+ i_n_weight = (i_weight / norm)
483
+ j_n_weight = (j_weight / norm)
484
+ k_n_weight = (k_weight / norm)
485
+
486
+ norm_factor = 2.0
487
+
488
+ square_i = norm_factor*(i_n_weight*i_n_weight)
489
+ square_j = norm_factor*(j_n_weight*j_n_weight)
490
+ square_k = norm_factor*(k_n_weight*k_n_weight)
491
+
492
+ ri = (norm_factor*r_n_weight*i_n_weight)
493
+ rj = (norm_factor*r_n_weight*j_n_weight)
494
+ rk = (norm_factor*r_n_weight*k_n_weight)
495
+
496
+ ij = (norm_factor*i_n_weight*j_n_weight)
497
+ ik = (norm_factor*i_n_weight*k_n_weight)
498
+
499
+ jk = (norm_factor*j_n_weight*k_n_weight)
500
+
501
+ if quaternion_format:
502
+ if scale is not None:
503
+ rot_kernel_1 = torch.cat([zero_kernel, scale * (1.0 - (square_j + square_k)),
504
+ scale * (ij-rk), scale * (ik+rj)], dim=0)
505
+ rot_kernel_2 = torch.cat([zero_kernel, scale * (ij+rk), scale *
506
+ (1.0 - (square_i + square_k)), scale * (jk-ri)], dim=0)
507
+ rot_kernel_3 = torch.cat([zero_kernel, scale * (ik-rj), scale * (jk+ri),
508
+ scale * (1.0 - (square_i + square_j))], dim=0)
509
+ else:
510
+ rot_kernel_1 = torch.cat(
511
+ [zero_kernel, (1.0 - (square_j + square_k)), (ij-rk), (ik+rj)], dim=0)
512
+ rot_kernel_2 = torch.cat(
513
+ [zero_kernel, (ij+rk), (1.0 - (square_i + square_k)), (jk-ri)], dim=0)
514
+ rot_kernel_3 = torch.cat(
515
+ [zero_kernel, (ik-rj), (jk+ri), (1.0 - (square_i + square_j))], dim=0)
516
+
517
+ zero_kernel2 = torch.cat(
518
+ [zero_kernel, zero_kernel, zero_kernel, zero_kernel], dim=0)
519
+ global_rot_kernel = torch.cat(
520
+ [zero_kernel2, rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=1)
521
+
522
+ else:
523
+ if scale is not None:
524
+ rot_kernel_1 = torch.cat([scale * (1.0 - (square_j + square_k)),
525
+ scale * (ij-rk), scale * (ik+rj)], dim=0)
526
+ rot_kernel_2 = torch.cat(
527
+ [scale * (ij+rk), scale * (1.0 - (square_i + square_k)), scale * (jk-ri)], dim=0)
528
+ rot_kernel_3 = torch.cat([scale * (ik-rj), scale * (jk+ri), scale *
529
+ (1.0 - (square_i + square_j))], dim=0)
530
+ else:
531
+ rot_kernel_1 = torch.cat(
532
+ [1.0 - (square_j + square_k), (ij-rk), (ik+rj)], dim=0)
533
+ rot_kernel_2 = torch.cat(
534
+ [(ij+rk), 1.0 - (square_i + square_k), (jk-ri)], dim=0)
535
+ rot_kernel_3 = torch.cat(
536
+ [(ik-rj), (jk+ri), (1.0 - (square_i + square_j))], dim=0)
537
+
538
+ global_rot_kernel = torch.cat(
539
+ [rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=1)
540
+
541
+ if input.dim() == 2:
542
+ if bias is not None:
543
+ return torch.addmm(bias, input, global_rot_kernel)
544
+ else:
545
+ return torch.mm(input, global_rot_kernel)
546
+ else:
547
+ output = torch.matmul(input, global_rot_kernel)
548
+ if bias is not None:
549
+ return output+bias
550
+ else:
551
+ return output
552
+
553
+
554
+ # Custom AUTOGRAD for lower VRAM consumption
555
+ class QuaternionLinearFunction(torch.autograd.Function):
556
+ @staticmethod
557
+ def forward(ctx, input, r_weight, i_weight, j_weight, k_weight, bias=None):
558
+ ctx.save_for_backward(input, r_weight, i_weight,
559
+ j_weight, k_weight, bias)
560
+ check_input(input)
561
+ cat_kernels_4_r = torch.cat(
562
+ [r_weight, -i_weight, -j_weight, -k_weight], dim=0)
563
+ cat_kernels_4_i = torch.cat(
564
+ [i_weight, r_weight, -k_weight, j_weight], dim=0)
565
+ cat_kernels_4_j = torch.cat(
566
+ [j_weight, k_weight, r_weight, -i_weight], dim=0)
567
+ cat_kernels_4_k = torch.cat(
568
+ [k_weight, -j_weight, i_weight, r_weight], dim=0)
569
+ cat_kernels_4_quaternion = torch.cat(
570
+ [cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k], dim=1)
571
+ if input.dim() == 2:
572
+ if bias is not None:
573
+ return torch.addmm(bias, input, cat_kernels_4_quaternion)
574
+ else:
575
+ return torch.mm(input, cat_kernels_4_quaternion)
576
+ else:
577
+ output = torch.matmul(input, cat_kernels_4_quaternion)
578
+ if bias is not None:
579
+ return output+bias
580
+ else:
581
+ return output
582
+
583
+ # This function has only a single output, so it gets only one gradient
584
+ @staticmethod
585
+ def backward(ctx, grad_output):
586
+ input, r_weight, i_weight, j_weight, k_weight, bias = ctx.saved_tensors
587
+ grad_input = grad_weight_r = grad_weight_i = grad_weight_j = grad_weight_k = grad_bias = None
588
+
589
+ input_r = torch.cat([r_weight, -i_weight, -j_weight, -k_weight], dim=0)
590
+ input_i = torch.cat([i_weight, r_weight, -k_weight, j_weight], dim=0)
591
+ input_j = torch.cat([j_weight, k_weight, r_weight, -i_weight], dim=0)
592
+ input_k = torch.cat([k_weight, -j_weight, i_weight, r_weight], dim=0)
593
+ cat_kernels_4_quaternion_T = Variable(
594
+ torch.cat([input_r, input_i, input_j, input_k], dim=1).permute(1, 0), requires_grad=False)
595
+
596
+ r = get_r(input)
597
+ i = get_i(input)
598
+ j = get_j(input)
599
+ k = get_k(input)
600
+ input_r = torch.cat([r, -i, -j, -k], dim=0)
601
+ input_i = torch.cat([i, r, -k, j], dim=0)
602
+ input_j = torch.cat([j, k, r, -i], dim=0)
603
+ input_k = torch.cat([k, -j, i, r], dim=0)
604
+ input_mat = Variable(
605
+ torch.cat([input_r, input_i, input_j, input_k], dim=1), requires_grad=False)
606
+
607
+ r = get_r(grad_output)
608
+ i = get_i(grad_output)
609
+ j = get_j(grad_output)
610
+ k = get_k(grad_output)
611
+ input_r = torch.cat([r, i, j, k], dim=1)
612
+ input_i = torch.cat([-i, r, k, -j], dim=1)
613
+ input_j = torch.cat([-j, -k, r, i], dim=1)
614
+ input_k = torch.cat([-k, j, -i, r], dim=1)
615
+ grad_mat = torch.cat([input_r, input_i, input_j, input_k], dim=0)
616
+
617
+ if ctx.needs_input_grad[0]:
618
+ grad_input = grad_output.mm(cat_kernels_4_quaternion_T)
619
+ if ctx.needs_input_grad[1]:
620
+ grad_weight = grad_mat.permute(1, 0).mm(input_mat).permute(1, 0)
621
+ unit_size_x = r_weight.size(0)
622
+ unit_size_y = r_weight.size(1)
623
+ grad_weight_r = grad_weight.narrow(
624
+ 0, 0, unit_size_x).narrow(1, 0, unit_size_y)
625
+ grad_weight_i = grad_weight.narrow(
626
+ 0, 0, unit_size_x).narrow(1, unit_size_y, unit_size_y)
627
+ grad_weight_j = grad_weight.narrow(
628
+ 0, 0, unit_size_x).narrow(1, unit_size_y*2, unit_size_y)
629
+ grad_weight_k = grad_weight.narrow(
630
+ 0, 0, unit_size_x).narrow(1, unit_size_y*3, unit_size_y)
631
+ if ctx.needs_input_grad[5]:
632
+ grad_bias = grad_output.sum(0).squeeze(0)
633
+
634
+ return grad_input, grad_weight_r, grad_weight_i, grad_weight_j, grad_weight_k, grad_bias
635
+
636
+
637
+ def hamilton_product(q0, q1):
638
+ """
639
+ Applies a Hamilton product q0 * q1:
640
+ Shape:
641
+ - q0, q1 should be (batch_size, quaternion_number)
642
+ (rr' - xx' - yy' - zz') +
643
+ (rx' + xr' + yz' - zy')i +
644
+ (ry' - xz' + yr' + zx')j +
645
+ (rz' + xy' - yx' + zr')k +
646
+ """
647
+
648
+ q1_r = get_r(q1)
649
+ q1_i = get_i(q1)
650
+ q1_j = get_j(q1)
651
+ q1_k = get_k(q1)
652
+
653
+ # rr', xx', yy', and zz'
654
+ r_base = torch.mul(q0, q1)
655
+ # (rr' - xx' - yy' - zz')
656
+ r = get_r(r_base) - get_i(r_base) - get_j(r_base) - get_k(r_base)
657
+
658
+ # rx', xr', yz', and zy'
659
+ i_base = torch.mul(q0, torch.cat([q1_i, q1_r, q1_k, q1_j], dim=1))
660
+ # (rx' + xr' + yz' - zy')
661
+ i = get_r(i_base) + get_i(i_base) + get_j(i_base) - get_k(i_base)
662
+
663
+ # ry', xz', yr', and zx'
664
+ j_base = torch.mul(q0, torch.cat([q1_j, q1_k, q1_r, q1_i], dim=1))
665
+ # (ry' - xz' + yr' + zx')
666
+ j = get_r(j_base) - get_i(j_base) + get_j(j_base) + get_k(j_base)
667
+
668
+ # rz', xy', yx', and zr'
669
+ k_base = torch.mul(q0, torch.cat([q1_k, q1_j, q1_i, q1_r], dim=1))
670
+ # (rz' + xy' - yx' + zr')
671
+ k = get_r(k_base) + get_i(k_base) - get_j(k_base) + get_k(k_base)
672
+
673
+ return torch.cat([r, i, j, k], dim=1)
674
+
675
+ #
676
+ # PARAMETERS INITIALIZATION
677
+ #
678
+
679
+
680
+ def unitary_init(in_features, out_features, rng, kernel_size=None, criterion='he'):
681
+ if kernel_size is not None:
682
+ receptive_field = np.prod(kernel_size)
683
+ fan_in = in_features * receptive_field
684
+ fan_out = out_features * receptive_field
685
+ else:
686
+ fan_in = in_features
687
+ fan_out = out_features
688
+
689
+ if kernel_size is None:
690
+ kernel_shape = (in_features, out_features)
691
+ else:
692
+ if type(kernel_size) is int:
693
+ kernel_shape = (out_features, in_features) + tuple((kernel_size,))
694
+ else:
695
+ kernel_shape = (out_features, in_features) + (*kernel_size,)
696
+
697
+ number_of_weights = np.prod(kernel_shape)
698
+ v_r = np.random.uniform(-1.0, 1.0, number_of_weights)
699
+ v_i = np.random.uniform(-1.0, 1.0, number_of_weights)
700
+ v_j = np.random.uniform(-1.0, 1.0, number_of_weights)
701
+ v_k = np.random.uniform(-1.0, 1.0, number_of_weights)
702
+
703
+ # Unitary quaternion
704
+ for i in range(0, number_of_weights):
705
+ norm = np.sqrt(v_r[i]**2 + v_i[i]**2 + v_j[i]**2 + v_k[i]**2)+0.0001
706
+ v_r[i] /= norm
707
+ v_i[i] /= norm
708
+ v_j[i] /= norm
709
+ v_k[i] /= norm
710
+ v_r = v_r.reshape(kernel_shape)
711
+ v_i = v_i.reshape(kernel_shape)
712
+ v_j = v_j.reshape(kernel_shape)
713
+ v_k = v_k.reshape(kernel_shape)
714
+
715
+ return (v_r, v_i, v_j, v_k)
716
+
717
+
718
+ def random_init(in_features, out_features, rng, kernel_size=None, criterion='glorot'):
719
+ if kernel_size is not None:
720
+ receptive_field = np.prod(kernel_size)
721
+ fan_in = in_features * receptive_field
722
+ fan_out = out_features * receptive_field
723
+ else:
724
+ fan_in = in_features
725
+ fan_out = out_features
726
+
727
+ if criterion == 'glorot':
728
+ s = 1. / np.sqrt(2*(fan_in + fan_out))
729
+ elif criterion == 'he':
730
+ s = 1. / np.sqrt(2*fan_in)
731
+ else:
732
+ raise ValueError('Invalid criterion: ' + criterion)
733
+
734
+ if kernel_size is None:
735
+ kernel_shape = (in_features, out_features)
736
+ else:
737
+ if type(kernel_size) is int:
738
+ kernel_shape = (out_features, in_features) + tuple((kernel_size,))
739
+ else:
740
+ kernel_shape = (out_features, in_features) + (*kernel_size,)
741
+
742
+ number_of_weights = np.prod(kernel_shape)
743
+ v_r = np.random.uniform(-1.0, 1.0, number_of_weights)
744
+ v_i = np.random.uniform(-1.0, 1.0, number_of_weights)
745
+ v_j = np.random.uniform(-1.0, 1.0, number_of_weights)
746
+ v_k = np.random.uniform(-1.0, 1.0, number_of_weights)
747
+
748
+ v_r = v_r.reshape(kernel_shape)
749
+ v_i = v_i.reshape(kernel_shape)
750
+ v_j = v_j.reshape(kernel_shape)
751
+ v_k = v_k.reshape(kernel_shape)
752
+
753
+ weight_r = v_r
754
+ weight_i = v_i
755
+ weight_j = v_j
756
+ weight_k = v_k
757
+ return (weight_r, weight_i, weight_j, weight_k)
758
+
759
+
760
+ def quaternion_init(in_features, out_features, rng, kernel_size=None, criterion='glorot'):
761
+ if kernel_size is not None:
762
+ receptive_field = np.prod(kernel_size)
763
+ fan_in = in_features * receptive_field
764
+ fan_out = out_features * receptive_field
765
+ else:
766
+ fan_in = in_features
767
+ fan_out = out_features
768
+
769
+ if criterion == 'glorot':
770
+ s = 1. / np.sqrt(2*(fan_in + fan_out))
771
+ elif criterion == 'he':
772
+ s = 1. / np.sqrt(2*fan_in)
773
+ else:
774
+ raise ValueError('Invalid criterion: ' + criterion)
775
+
776
+ rng = RandomState(np.random.randint(1, 1234))
777
+
778
+ # Generating randoms and purely imaginary quaternions :
779
+ if kernel_size is None:
780
+ kernel_shape = (in_features, out_features)
781
+ else:
782
+ if type(kernel_size) is int:
783
+ kernel_shape = (out_features, in_features) + tuple((kernel_size,))
784
+ else:
785
+ kernel_shape = (out_features, in_features) + (*kernel_size,)
786
+
787
+ modulus = chi.rvs(4, loc=0, scale=s, size=kernel_shape)
788
+ number_of_weights = np.prod(kernel_shape)
789
+ v_i = np.random.uniform(-1.0, 1.0, number_of_weights)
790
+ v_j = np.random.uniform(-1.0, 1.0, number_of_weights)
791
+ v_k = np.random.uniform(-1.0, 1.0, number_of_weights)
792
+
793
+ # Purely imaginary quaternions unitary
794
+ for i in range(0, number_of_weights):
795
+ norm = np.sqrt(v_i[i]**2 + v_j[i]**2 + v_k[i]**2 + 0.0001)
796
+ v_i[i] /= norm
797
+ v_j[i] /= norm
798
+ v_k[i] /= norm
799
+ v_i = v_i.reshape(kernel_shape)
800
+ v_j = v_j.reshape(kernel_shape)
801
+ v_k = v_k.reshape(kernel_shape)
802
+
803
+ phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape)
804
+
805
+ weight_r = modulus * np.cos(phase)
806
+ weight_i = modulus * v_i*np.sin(phase)
807
+ weight_j = modulus * v_j*np.sin(phase)
808
+ weight_k = modulus * v_k*np.sin(phase)
809
+
810
+ return (weight_r, weight_i, weight_j, weight_k)
811
+
812
+
813
+ def create_dropout_mask(dropout_p, size, rng, as_type, operation='linear'):
814
+ if operation == 'linear':
815
+ mask = rng.binomial(n=1, p=1-dropout_p, size=size)
816
+ return Variable(torch.from_numpy(mask).type(as_type))
817
+ else:
818
+ raise Exception("create_dropout_mask accepts only 'linear'. Found operation = "
819
+ + str(operation))
820
+
821
+
822
+ def affect_init(r_weight, i_weight, j_weight, k_weight, init_func, rng, init_criterion):
823
+ if r_weight.size() != i_weight.size() or r_weight.size() != j_weight.size() or \
824
+ r_weight.size() != k_weight.size():
825
+ raise ValueError('The real and imaginary weights '
826
+ 'should have the same size . Found: r:'
827
+ + str(r_weight.size()) + ' i:'
828
+ + str(i_weight.size()) + ' j:'
829
+ + str(j_weight.size()) + ' k:'
830
+ + str(k_weight.size()))
831
+
832
+ elif r_weight.dim() != 2:
833
+ raise Exception('affect_init accepts only matrices. Found dimension = '
834
+ + str(r_weight.dim()))
835
+ kernel_size = None
836
+ r, i, j, k = init_func(r_weight.size(0), r_weight.size(
837
+ 1), rng, kernel_size, init_criterion)
838
+ r, i, j, k = torch.from_numpy(r), torch.from_numpy(
839
+ i), torch.from_numpy(j), torch.from_numpy(k)
840
+ r_weight.data = r.type_as(r_weight.data)
841
+ i_weight.data = i.type_as(i_weight.data)
842
+ j_weight.data = j.type_as(j_weight.data)
843
+ k_weight.data = k.type_as(k_weight.data)
844
+
845
+
846
+ def affect_init_conv(r_weight, i_weight, j_weight, k_weight, kernel_size, init_func, rng,
847
+ init_criterion):
848
+ if r_weight.size() != i_weight.size() or r_weight.size() != j_weight.size() or \
849
+ r_weight.size() != k_weight.size():
850
+ raise ValueError('The real and imaginary weights '
851
+ 'should have the same size . Found: r:'
852
+ + str(r_weight.size()) + ' i:'
853
+ + str(i_weight.size()) + ' j:'
854
+ + str(j_weight.size()) + ' k:'
855
+ + str(k_weight.size()))
856
+
857
+ elif 2 >= r_weight.dim():
858
+ raise Exception('affect_init_conv accepts only tensors that have more than 2 dimensions. Found dimension = '
859
+ + str(r_weight.dim()))
860
+
861
+ r, i, j, k = init_func(
862
+ r_weight.size(1),
863
+ r_weight.size(0),
864
+ rng=rng,
865
+ kernel_size=kernel_size,
866
+ criterion=init_criterion
867
+ )
868
+ r, i, j, k = torch.from_numpy(r), torch.from_numpy(
869
+ i), torch.from_numpy(j), torch.from_numpy(k)
870
+ r_weight.data = r.type_as(r_weight.data)
871
+ i_weight.data = i.type_as(i_weight.data)
872
+ j_weight.data = j.type_as(j_weight.data)
873
+ k_weight.data = k.type_as(k_weight.data)
874
+
875
+
876
+ def get_kernel_and_weight_shape(operation, in_channels, out_channels, kernel_size):
877
+ if operation == 'convolution1d':
878
+ if type(kernel_size) is not int:
879
+ raise ValueError(
880
+ """An invalid kernel_size was supplied for a 1d convolution. The kernel size
881
+ must be an integer in this case. Found kernel_size = """ + str(kernel_size)
882
+ )
883
+ else:
884
+ ks = kernel_size
885
+ w_shape = (out_channels, in_channels) + tuple((ks,))
886
+ else: # in case it is 2d or 3d.
887
+ if operation == 'convolution2d' and type(kernel_size) is int:
888
+ ks = (kernel_size, kernel_size)
889
+ elif operation == 'convolution3d' and type(kernel_size) is int:
890
+ ks = (kernel_size, kernel_size, kernel_size)
891
+ elif type(kernel_size) is not int:
892
+ if operation == 'convolution2d' and len(kernel_size) != 2:
893
+ raise ValueError(
894
+ """An invalid kernel_size was supplied for a 2d convolution. The kernel size
895
+ must be either an integer or a tuple of 2. Found kernel_size = """ + str(kernel_size)
896
+ )
897
+ elif operation == 'convolution3d' and len(kernel_size) != 3:
898
+ raise ValueError(
899
+ """An invalid kernel_size was supplied for a 3d convolution. The kernel size
900
+ must be either an integer or a tuple of 3. Found kernel_size = """ + str(kernel_size)
901
+ )
902
+ else:
903
+ ks = kernel_size
904
+ w_shape = (out_channels, in_channels) + (*ks,)
905
+ return ks, w_shape
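For reference, a minimal usage sketch of get_kernel_and_weight_shape (illustrative only, not part of this commit): an integer kernel size for a 2d convolution is expanded to a tuple and combined with the channel counts.

from models.hypercomplex_ops import get_kernel_and_weight_shape

ks, w_shape = get_kernel_and_weight_shape(
    'convolution2d', in_channels=16, out_channels=32, kernel_size=3)
print(ks)        # (3, 3)
print(w_shape)   # (32, 16, 3, 3)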
models/phc_models.py ADDED
@@ -0,0 +1,365 @@
1
+ '''ResNet in PyTorch.
2
+ For Pre-activation ResNet, see 'preact_resnet.py'.
3
+ Reference:
4
+ [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5
+ Deep Residual Learning for Image Recognition. arXiv:1512.03385
6
+ '''
7
+ import sys
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from models.hypercomplex_layers import PHConv
14
+ from utils.utils import load_weights
15
+
16
+ sys.path.append('./models')
17
+
18
+
19
+ class BasicBlock(nn.Module):
20
+ expansion = 1
21
+
22
+ def __init__(self, in_planes, planes, stride=1, n=4):
23
+ super().__init__()
24
+ self.conv1 = PHConv(n,
25
+ in_planes, planes, kernel_size=3, stride=stride, padding=1)
26
+ self.bn1 = nn.BatchNorm2d(planes)
27
+ self.conv2 = PHConv(n, planes, planes, kernel_size=3,
28
+ stride=1, padding=1)
29
+ self.bn2 = nn.BatchNorm2d(planes)
30
+
31
+ self.shortcut = nn.Sequential()
32
+ if stride != 1 or in_planes != self.expansion*planes:
33
+ self.shortcut = nn.Sequential(
34
+ PHConv(n, in_planes, self.expansion*planes,
35
+ kernel_size=1, stride=stride,),
36
+ nn.BatchNorm2d(self.expansion*planes)
37
+ )
38
+
39
+ def forward(self, x):
40
+ out = F.relu(self.bn1(self.conv1(x)))
41
+ out = self.bn2(self.conv2(out))
42
+ out += self.shortcut(x)
43
+ out = F.relu(out)
44
+ return out
45
+
46
+
47
+ class Bottleneck(nn.Module):
48
+ expansion = 2
49
+
50
+ def __init__(self, in_planes, planes, stride=1, n=4):
51
+ super().__init__()
52
+ self.conv1 = PHConv(n, in_planes, planes, kernel_size=1, stride=1)
53
+ self.bn1 = nn.BatchNorm2d(planes)
54
+ self.conv2 = PHConv(n, planes, planes, kernel_size=3,
55
+ stride=stride, padding=1)
56
+ self.bn2 = nn.BatchNorm2d(planes)
57
+ self.conv3 = PHConv(n, planes, self.expansion *
58
+ planes, kernel_size=1, stride=1)
59
+ self.bn3 = nn.BatchNorm2d(self.expansion*planes)
60
+
61
+ self.shortcut = nn.Sequential()
62
+ if stride != 1 or in_planes != self.expansion*planes:
63
+ self.shortcut = nn.Sequential(
64
+ PHConv(n, in_planes, self.expansion*planes,
65
+ kernel_size=1, stride=stride),
66
+ nn.BatchNorm2d(self.expansion*planes)
67
+ )
68
+
69
+ def forward(self, x):
70
+ out = F.relu(self.bn1(self.conv1(x)))
71
+ out = F.relu(self.bn2(self.conv2(out)))
72
+ out = self.bn3(self.conv3(out))
73
+ out += self.shortcut(x)
74
+ out = F.relu(out)
75
+ return out
76
+
77
+
78
+ class PHCResNet(nn.Module):
79
+ """PHCResNet.
80
+
81
+ Parameters:
82
+ - before_gap_output: True to return the output before refiner blocks and gap
83
+ - gap_output: True to return the output after gap and before final linear layer
84
+ """
85
+
86
+ def __init__(self, block, num_blocks, channels=4, n=4, num_classes=10, before_gap_output=False, gap_output=False, visualize=False):
87
+ super().__init__()
88
+ self.block = block
89
+ self.num_blocks = num_blocks
90
+ self.in_planes = 64
91
+ self.n = n
92
+ self.before_gap_out = before_gap_output
93
+ self.gap_output = gap_output
94
+ self.visualize = visualize
95
+
96
+ self.conv1 = PHConv(n, channels, 64, kernel_size=3,
97
+ stride=1, padding=1)
98
+ self.bn1 = nn.BatchNorm2d(64)
99
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, n=n)
100
+ self.layer2 = self._make_layer(
101
+ block, 128, num_blocks[1], stride=2, n=n)
102
+ self.layer3 = self._make_layer(
103
+ block, 256, num_blocks[2], stride=2, n=n)
104
+ self.layer4 = self._make_layer(
105
+ block, 512, num_blocks[3], stride=2, n=n)
106
+
107
+ # Refiner blocks
108
+ self.layer5 = None
109
+ self.layer6 = None
110
+
111
+ if not before_gap_output and not gap_output:
112
+ self.linear = nn.Linear(512*block.expansion, num_classes)
113
+
114
+ def add_top_blocks(self, num_classes=1):
115
+ # print("Adding top blocks with n = ", self.n)
116
+ self.layer5 = self._make_layer(Bottleneck, 512, 2, stride=2, n=self.n)
117
+ self.layer6 = self._make_layer(Bottleneck, 512, 2, stride=2, n=self.n)
118
+
119
+ if not self.before_gap_out and not self.gap_output:
120
+ self.linear = nn.Linear(1024, num_classes)
121
+
122
+ def _make_layer(self, block, planes, num_blocks, stride, n):
123
+ strides = [stride] + [1]*(num_blocks-1)
124
+ layers = []
125
+ for stride in strides:
126
+ layers.append(block(self.in_planes, planes, stride, n))
127
+ self.in_planes = planes * block.expansion
128
+ return nn.Sequential(*layers)
129
+
130
+ def forward(self, x):
131
+ out = F.relu(self.bn1(self.conv1(x)))
132
+ out = self.layer1(out)
133
+ out = self.layer2(out)
134
+ out = self.layer3(out)
135
+ out4 = self.layer4(out)
136
+
137
+ if self.before_gap_out:
138
+ return out4
139
+
140
+ if self.layer5 is not None:
141
+ out5 = self.layer5(out4)
142
+ out6 = self.layer6(out5)
143
+ else:
+ # no refiner blocks were added: pool directly over the layer4 output
+ out6 = out4
+
144
+ # global average pooling (GAP)
145
+ n, c, _, _ = out6.size()
146
+ out = out6.view(n, c, -1).mean(-1)
147
+
148
+ if self.gap_output:
149
+ return out
150
+
151
+ out = self.linear(out)
152
+
153
+ if self.visualize:
154
+ # return the final output and activation maps at two different levels
155
+ return out, out4, out6
156
+ return out
157
+
158
+
159
+ class Encoder(nn.Module):
160
+ """Encoder branch in PHYSBOnet."""
161
+
162
+ def __init__(self, channels, n):
163
+ super().__init__()
164
+ self.in_planes = 64
165
+
166
+ self.conv1 = PHConv(n, channels, 64, kernel_size=3,
167
+ stride=1, padding=1)
168
+ self.bn1 = nn.BatchNorm2d(64)
169
+ self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1, n=n)
170
+ self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2, n=n)
171
+
172
+ def _make_layer(self, block, planes, num_blocks, stride, n):
173
+ strides = [stride] + [1]*(num_blocks-1)
174
+ layers = []
175
+ for stride in strides:
176
+ layers.append(block(self.in_planes, planes, stride, n))
177
+ self.in_planes = planes * block.expansion
178
+ return nn.Sequential(*layers)
179
+
180
+ def forward(self, x):
181
+ out = F.relu(self.bn1(self.conv1(x)))
182
+ out = self.layer1(out)
183
+ out = self.layer2(out)
184
+ return out
185
+
186
+
187
+ class SharedBottleneck(nn.Module):
188
+ """SharedBottleneck in PHYSBOnet."""
189
+
190
+ def __init__(self, n, in_planes):
191
+ super().__init__()
192
+ self.in_planes = in_planes
193
+
194
+ self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2, n=n)
195
+ self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2, n=n)
196
+ self.layer5 = self._make_layer(Bottleneck, 512, 2, stride=2, n=n)
197
+ self.layer6 = self._make_layer(Bottleneck, 512, 2, stride=2, n=n)
198
+
199
+ def _make_layer(self, block, planes, num_blocks, stride, n):
200
+ strides = [stride] + [1]*(num_blocks-1)
201
+ layers = []
202
+ for stride in strides:
203
+ layers.append(block(self.in_planes, planes, stride, n))
204
+ self.in_planes = planes * block.expansion
205
+ return nn.Sequential(*layers)
206
+
207
+ def forward(self, x):
208
+ out = self.layer3(x)
209
+ out = self.layer4(out)
210
+ out = self.layer5(out)
211
+ out = self.layer6(out)
212
+ n, c, _, _ = out.size()
213
+ out = out.view(n, c, -1).mean(-1)
214
+ return out
215
+
216
+
217
+ class Classifier(nn.Module):
218
+ """Classifier branch in PHYSEnet."""
219
+
220
+ def __init__(self, n, num_classes, in_planes=512, visualize=False):
221
+ super().__init__()
222
+ self.in_planes = in_planes
223
+ self.visualize = visualize
224
+
225
+ # Refiner blocks
226
+ self.layer5 = self._make_layer(Bottleneck, 512, 2, stride=2, n=n)
227
+ self.layer6 = self._make_layer(Bottleneck, 512, 2, stride=2, n=n)
228
+ self.linear = nn.Linear(1024, num_classes)
229
+
230
+ def _make_layer(self, block, planes, num_blocks, stride, n):
231
+ strides = [stride] + [1]*(num_blocks-1)
232
+ layers = []
233
+ for stride in strides:
234
+ layers.append(block(self.in_planes, planes, stride, n))
235
+ self.in_planes = planes * block.expansion
236
+ return nn.Sequential(*layers)
237
+
238
+ def forward(self, x):
239
+ out = self.layer5(x)
240
+ feature_maps = self.layer6(out)
241
+
242
+ n, c, _, _ = feature_maps.size()
243
+ out = feature_maps.view(n, c, -1).mean(-1)
244
+ out = self.linear(out)
245
+
246
+ if self.visualize:
247
+ return out, feature_maps
248
+
249
+ return out
250
+
251
+
252
+ class PHYSBOnet(nn.Module):
253
+ """PHYSBOnet.
254
+
255
+ Parameters:
256
+ - shared: True to share the Bottleneck between the two sides, False for the 'concat' version.
257
+ - weights: path to pretrained weights of patch classifier for Encoder branches
258
+ """
259
+
260
+ def __init__(self, n, shared=True, num_classes=1, weights=None):
261
+ super().__init__()
262
+
263
+ self.shared = shared
264
+
265
+ self.encoder_sx = Encoder(channels=2, n=2)
266
+ self.encoder_dx = Encoder(channels=2, n=2)
267
+
268
+ self.shared_resnet = SharedBottleneck(
269
+ n, in_planes=128 if shared else 256)
270
+
271
+ if weights:
272
+ load_weights(self.encoder_sx, weights)
273
+ load_weights(self.encoder_dx, weights)
274
+
275
+ self.classifier_sx = nn.Linear(1024, num_classes)
276
+ self.classifier_dx = nn.Linear(1024, num_classes)
277
+
278
+ def forward(self, x):
279
+ x_sx, x_dx = x
280
+
281
+ # Apply Encoder
282
+ out_sx = self.encoder_sx(x_sx)
283
+ out_dx = self.encoder_dx(x_dx)
284
+
285
+ # Shared layers
286
+ if self.shared:
287
+ out_sx = self.shared_resnet(out_sx)
288
+ out_dx = self.shared_resnet(out_dx)
289
+
290
+ out_sx = self.classifier_sx(out_sx)
291
+ out_dx = self.classifier_dx(out_dx)
292
+
293
+ else: # Concat version
294
+ out = torch.cat([out_sx, out_dx], dim=1)
295
+ out = self.shared_resnet(out)
296
+ out_sx = self.classifier_sx(out)
297
+ out_dx = self.classifier_dx(out)
298
+
299
+ out = torch.cat([out_sx, out_dx], dim=0)
300
+ return out
301
+
302
+
303
+ class PHYSEnet(nn.Module):
304
+ """PHYSEnet.
305
+
306
+ Parameters:
307
+ - weights: path to pretrained weights of patch classifier for PHCResNet18 encoder or path to whole-image classifier
308
+ - patch_weights: True if the weights correspond to patch classifier, False if they are whole-image.
309
+ In the latter case, the Classifier branches are also initialized.
310
+ """
311
+
312
+ def __init__(self, n=2, num_classes=1, weights=None, patch_weights=True, visualize=False):
313
+ super().__init__()
314
+ self.visualize = visualize
315
+ self.phcresnet18 = PHCResNet18(
316
+ n=2, num_classes=num_classes, channels=2, before_gap_output=True)
317
+
318
+ if weights:
319
+ print('Loading weights for phcresnet18 from ', weights)
320
+ load_weights(self.phcresnet18, weights)
321
+
322
+ self.classifier_sx = Classifier(n, num_classes, visualize=visualize)
323
+ self.classifier_dx = Classifier(n, num_classes, visualize=visualize)
324
+
325
+ if not patch_weights and weights:
326
+ print('Loading weights for classifiers from ', weights)
327
+ load_weights(self.classifier_sx, weights)
328
+ load_weights(self.classifier_dx, weights)
329
+
330
+ def forward(self, x):
331
+ x_sx, x_dx = x
332
+
333
+ # Apply Encoder
334
+ out_enc_sx = self.phcresnet18(x_sx)
335
+ out_enc_dx = self.phcresnet18(x_dx)
336
+
337
+ if self.visualize:
338
+ out_sx, act_sx = self.classifier_sx(out_enc_sx)
339
+ out_dx, act_dx = self.classifier_dx(out_enc_dx)
340
+ else:
341
+ # Apply refiner blocks + classifier
342
+ out_sx = self.classifier_sx(out_enc_sx)
343
+ out_dx = self.classifier_dx(out_enc_dx)
344
+
345
+ out = torch.cat([out_sx, out_dx], dim=0)
346
+
347
+ if self.visualize:
348
+ return out, out_enc_sx, out_enc_dx, act_sx, act_dx
349
+
350
+ return out
351
+
352
+
353
+ def PHCResNet18(channels=4, n=4, num_classes=10, before_gap_output=False, gap_output=False, visualize=False):
354
+ return PHCResNet(BasicBlock,
355
+ [2, 2, 2, 2],
356
+ channels=channels,
357
+ n=n,
358
+ num_classes=num_classes,
359
+ before_gap_output=before_gap_output,
360
+ gap_output=gap_output,
361
+ visualize=visualize)
362
+
363
+
364
+ def PHCResNet50(channels=4, n=4, num_classes=10):
365
+ return PHCResNet(Bottleneck, [3, 4, 6, 3], channels=channels, n=n, num_classes=num_classes)
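A quick smoke test of the models defined above (illustrative sketch, not part of this commit; the 224x224 input size and n=2 are assumptions chosen for the two-channel, two-view setting):

import torch
from models.phc_models import PHCResNet18, PHYSEnet

model = PHCResNet18(channels=2, n=2, num_classes=1)
model.add_top_blocks(num_classes=1)          # append the two refiner Bottleneck stages
x = torch.randn(1, 2, 224, 224)
print(model(x).shape)                        # torch.Size([1, 1])

physe = PHYSEnet(n=2, num_classes=1)         # no pretrained weights in this sketch
views = (torch.randn(1, 2, 224, 224), torch.randn(1, 2, 224, 224))
print(physe(views).shape)                    # torch.Size([2, 1]), one logit per side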
models/real_models.py ADDED
@@ -0,0 +1,333 @@
1
+ '''ResNet in PyTorch.
2
+ For Pre-activation ResNet, see 'preact_resnet.py'.
3
+ Reference:
4
+ [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5
+ Deep Residual Learning for Image Recognition. arXiv:1512.03385
6
+ '''
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from utils.utils import load_weights
12
+
13
+
14
+ class BasicBlock(nn.Module):
15
+ expansion = 1
16
+
17
+ def __init__(self, in_planes, planes, stride=1):
18
+ super().__init__()
19
+ self.conv1 = nn.Conv2d(
20
+ in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
21
+ self.bn1 = nn.BatchNorm2d(planes)
22
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
23
+ stride=1, padding=1, bias=False)
24
+ self.bn2 = nn.BatchNorm2d(planes)
25
+
26
+ self.shortcut = nn.Sequential()
27
+ if stride != 1 or in_planes != self.expansion*planes:
28
+ self.shortcut = nn.Sequential(
29
+ nn.Conv2d(in_planes, self.expansion*planes,
30
+ kernel_size=1, stride=stride, bias=False),
31
+ nn.BatchNorm2d(self.expansion*planes)
32
+ )
33
+
34
+ def forward(self, x):
35
+ out = F.relu(self.bn1(self.conv1(x)))
36
+ out = self.bn2(self.conv2(out))
37
+ out += self.shortcut(x)
38
+ out = F.relu(out)
39
+ return out
40
+
41
+
42
+ class Bottleneck(nn.Module):
43
+ expansion = 2
44
+
45
+ def __init__(self, in_planes, planes, stride=1):
46
+ super().__init__()
47
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
48
+ self.bn1 = nn.BatchNorm2d(planes)
49
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
50
+ stride=stride, padding=1, bias=False)
51
+ self.bn2 = nn.BatchNorm2d(planes)
52
+ self.conv3 = nn.Conv2d(planes, self.expansion *
53
+ planes, kernel_size=1, bias=False)
54
+ self.bn3 = nn.BatchNorm2d(self.expansion*planes)
55
+
56
+ self.shortcut = nn.Sequential()
57
+ if stride != 1 or in_planes != self.expansion*planes:
58
+ self.shortcut = nn.Sequential(
59
+ nn.Conv2d(in_planes, self.expansion*planes,
60
+ kernel_size=1, stride=stride, bias=False),
61
+ nn.BatchNorm2d(self.expansion*planes)
62
+ )
63
+
64
+ def forward(self, x):
65
+ out = F.relu(self.bn1(self.conv1(x)))
66
+ out = F.relu(self.bn2(self.conv2(out)))
67
+ out = self.bn3(self.conv3(out))
68
+ out += self.shortcut(x)
69
+ out = F.relu(out)
70
+ return out
71
+
72
+
73
+ class ResNet(nn.Module):
74
+ def __init__(self, block, num_blocks, channels=4, num_classes=10, gap_output=False, before_gap_output=False, visualize=False):
75
+ super().__init__()
76
+ self.block = block
77
+ self.num_blocks = num_blocks
78
+ self.in_planes = 64
79
+ self.gap_output = gap_output
80
+ self.before_gap_out = before_gap_output
81
+ self.visualize = visualize
82
+
83
+ self.conv1 = nn.Conv2d(channels, 64, kernel_size=3,
84
+ stride=1, padding=1, bias=False)
85
+ self.bn1 = nn.BatchNorm2d(64)
86
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
87
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
88
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
89
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
90
+ self.layer5 = None
91
+ self.layer6 = None
92
+ if not gap_output and not before_gap_output:
93
+ self.linear = nn.Linear(512*block.expansion, num_classes)
94
+
95
+ def add_top_blocks(self, num_classes=1):
96
+ self.layer5 = self._make_layer(Bottleneck, 512, 2, stride=2)
97
+ self.layer6 = self._make_layer(Bottleneck, 512, 2, stride=2)
98
+
99
+ if not self.gap_output and not self.before_gap_out:
100
+ self.linear = nn.Linear(1024, num_classes)
101
+
102
+ def _make_layer(self, block, planes, num_blocks, stride):
103
+ strides = [stride] + [1]*(num_blocks-1)
104
+ layers = []
105
+ for stride in strides:
106
+ layers.append(block(self.in_planes, planes, stride))
107
+ self.in_planes = planes * block.expansion
108
+ return nn.Sequential(*layers)
109
+
110
+ def forward(self, x):
111
+ out = F.relu(self.bn1(self.conv1(x)))
112
+ out = self.layer1(out)
113
+ out = self.layer2(out)
114
+ out = self.layer3(out)
115
+ out4 = self.layer4(out)
116
+
117
+ if self.before_gap_out:
118
+ return out4
119
+
120
+ if self.layer5 is not None:
121
+ out5 = self.layer5(out4)
122
+ out6 = self.layer6(out5)
123
+ else:
+ # no top blocks were added: pool directly over the layer4 output
+ out6 = out4
+
124
+ # global average pooling (GAP)
+ n, c, _, _ = out6.size()
125
+ out = out6.view(n, c, -1).mean(-1)
126
+
127
+ if self.gap_output:
128
+ return out
129
+
130
+ out = self.linear(out)
131
+ if self.visualize:
132
+ return out, out4, out6
133
+ return out
134
+
135
+
136
+ class Encoder(nn.Module):
137
+ def __init__(self, channels):
138
+ super().__init__()
139
+ self.in_planes = 64
140
+
141
+ self.conv1 = nn.Conv2d(channels, 64, kernel_size=3,
142
+ stride=1, padding=1, bias=False)
143
+ self.bn1 = nn.BatchNorm2d(64)
144
+ self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1)
145
+ self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
146
+
147
+ def _make_layer(self, block, planes, num_blocks, stride):
148
+ strides = [stride] + [1]*(num_blocks-1)
149
+ layers = []
150
+ for stride in strides:
151
+ layers.append(block(self.in_planes, planes, stride))
152
+ self.in_planes = planes * block.expansion
153
+ return nn.Sequential(*layers)
154
+
155
+ def forward(self, x):
156
+ out = F.relu(self.bn1(self.conv1(x)))
157
+ out = self.layer1(out)
158
+ out = self.layer2(out)
159
+ return out
160
+
161
+
162
+ class SharedBottleneck(nn.Module):
163
+ def __init__(self, in_planes):
164
+ super().__init__()
165
+ self.in_planes = in_planes
166
+
167
+ self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)
168
+ self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2)
169
+ self.layer5 = self._make_layer(Bottleneck, 512, 2, stride=2)
170
+ self.layer6 = self._make_layer(Bottleneck, 512, 2, stride=2)
171
+
172
+ def _make_layer(self, block, planes, num_blocks, stride):
173
+ strides = [stride] + [1]*(num_blocks-1)
174
+ layers = []
175
+ for stride in strides:
176
+ layers.append(block(self.in_planes, planes, stride))
177
+ self.in_planes = planes * block.expansion
178
+ return nn.Sequential(*layers)
179
+
180
+ def forward(self, x):
181
+ out = self.layer3(x)
182
+ out = self.layer4(out)
183
+ out = self.layer5(out)
184
+ out = self.layer6(out)
185
+ n, c, _, _ = out.size()
186
+ out = out.view(n, c, -1).mean(-1)
187
+ return out
188
+
189
+
190
+ class Classifier(nn.Module):
191
+ def __init__(self, num_classes, in_planes=512, visualize=False):
192
+ super().__init__()
193
+ self.in_planes = in_planes
194
+ self.visualize = visualize
195
+
196
+ self.layer5 = self._make_layer(Bottleneck, 512, 2, stride=2)
197
+ self.layer6 = self._make_layer(Bottleneck, 512, 2, stride=2)
198
+ self.linear = nn.Linear(1024, num_classes)
199
+
200
+ def _make_layer(self, block, planes, num_blocks, stride):
201
+ strides = [stride] + [1]*(num_blocks-1)
202
+ layers = []
203
+ for stride in strides:
204
+ layers.append(block(self.in_planes, planes, stride))
205
+ self.in_planes = planes * block.expansion
206
+ return nn.Sequential(*layers)
207
+
208
+ def forward(self, x):
209
+ out = self.layer5(x)
210
+ feature_maps = self.layer6(out)
211
+
212
+ n, c, _, _ = feature_maps.size()
213
+ out = feature_maps.view(n, c, -1).mean(-1)
214
+ out = self.linear(out)
215
+
216
+ if self.visualize:
217
+ return out, feature_maps
218
+
219
+ return out
220
+
221
+
222
+ class SBOnet(nn.Module):
223
+ """SBOnet.
224
+
225
+ Parameters:
226
+ - shared: True to share the Bottleneck between the two sides, False for the 'concat' version.
227
+ - weights: path to pretrained weights of patch classifier for Encoder branches
228
+ """
229
+
230
+ def __init__(self, shared=True, num_classes=1, weights=None):
231
+ super().__init__()
232
+
233
+ self.shared = shared
234
+
235
+ self.encoder_sx = Encoder(channels=2)
236
+ self.encoder_dx = Encoder(channels=2)
237
+
238
+ self.shared_resnet = SharedBottleneck(in_planes=128 if shared else 256)
239
+
240
+ if weights:
241
+ load_weights(self.encoder_sx, weights)
242
+ load_weights(self.encoder_dx, weights)
243
+
244
+ self.classifier_sx = nn.Linear(1024, num_classes)
245
+ self.classifier_dx = nn.Linear(1024, num_classes)
246
+
247
+ def forward(self, x):
248
+ x_sx, x_dx = x
249
+
250
+ # Apply Encoder
251
+ out_sx = self.encoder_sx(x_sx)
252
+ out_dx = self.encoder_dx(x_dx)
253
+
254
+ # Shared layers
255
+ if self.shared:
256
+ out_sx = self.shared_resnet(out_sx)
257
+ out_dx = self.shared_resnet(out_dx)
258
+
259
+ out_sx = self.classifier_sx(out_sx)
260
+ out_dx = self.classifier_dx(out_dx)
261
+
262
+ else: # Concat version
263
+ out = torch.cat([out_sx, out_dx], dim=1)
264
+ out = self.shared_resnet(out)
265
+ out_sx = self.classifier_sx(out)
266
+ out_dx = self.classifier_dx(out)
267
+
268
+ out = torch.cat([out_sx, out_dx], dim=0)
269
+ return out
270
+
271
+
272
+ class SEnet(nn.Module):
273
+ """SEnet.
274
+
275
+ Parameters:
276
+ - weights: path to pretrained weights of patch classifier for PHCResNet18 encoder or path to whole-image classifier
277
+ - patch_weights: True if the weights correspond to patch classifier, False if they are whole-image.
278
+ In the latter case, the Classifier branches are also initialized.
279
+ """
280
+
281
+ def __init__(self, num_classes=1, weights=None, patch_weights=True, visualize=False):
282
+ super().__init__()
283
+ self.visualize = visualize
284
+ self.resnet18 = ResNet18(
285
+ num_classes=num_classes, channels=2, before_gap_output=True)
286
+
287
+ if weights:
288
+ print('Loading weights for resnet18 from ', weights)
289
+ load_weights(self.resnet18, weights)
290
+
291
+ self.classifier_sx = Classifier(num_classes, visualize=visualize)
292
+ self.classifier_dx = Classifier(num_classes, visualize=visualize)
293
+
294
+ if not patch_weights and weights:
295
+ print('Loading weights for classifiers from ', weights)
296
+ load_weights(self.classifier_sx, weights)
297
+ load_weights(self.classifier_dx, weights)
298
+
299
+ def forward(self, x):
300
+ x_sx, x_dx = x
301
+
302
+ # Apply Encoder
303
+ out_enc_sx = self.resnet18(x_sx)
304
+ out_enc_dx = self.resnet18(x_dx)
305
+
306
+ if self.visualize:
307
+ out_sx, act_sx = self.classifier_sx(out_enc_sx)
308
+ out_dx, act_dx = self.classifier_dx(out_enc_dx)
309
+ else:
310
+ # Apply refiner blocks + classifier
311
+ out_sx = self.classifier_sx(out_enc_sx)
312
+ out_dx = self.classifier_dx(out_enc_dx)
313
+
314
+ out = torch.cat([out_sx, out_dx], dim=0)
315
+
316
+ if self.visualize:
317
+ return out, out_enc_sx, out_enc_dx, act_sx, act_dx
318
+
319
+ return out
320
+
321
+
322
+ def ResNet18(num_classes=10, channels=4, gap_output=False, before_gap_output=False, visualize=False):
323
+ return ResNet(BasicBlock,
324
+ [2, 2, 2, 2],
325
+ num_classes=num_classes,
326
+ channels=channels,
327
+ gap_output=gap_output,
328
+ before_gap_output=before_gap_output,
329
+ visualize=visualize)
330
+
331
+
332
+ def ResNet50(num_classes=10, channels=4):
333
+ return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, channels=channels)
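The real-valued baselines can be exercised the same way (illustrative sketch, not part of this commit; input sizes are placeholders):

import torch
from models.real_models import ResNet18, SBOnet

backbone = ResNet18(num_classes=1, channels=2, before_gap_output=True)
feats = backbone(torch.randn(1, 2, 224, 224))
print(feats.shape)        # torch.Size([1, 512, 28, 28]), features before global pooling

sbo = SBOnet(shared=True, num_classes=1)
pair = (torch.randn(1, 2, 224, 224), torch.randn(1, 2, 224, 224))
print(sbo(pair).shape)    # torch.Size([2, 1])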
utils/__init__.py ADDED
File without changes
utils/utils.py ADDED
@@ -0,0 +1,17 @@
1
+ import torch
2
+
3
+
4
+ def mean_activations(tensor):
5
+ """Computes mean of activation maps tensor."""
6
+ # average over the channel dimension, then squeeze out the batch dimension
7
+ return torch.mean(tensor.detach().cpu(), dim=1).squeeze(dim=0)
8
+
9
+
10
+ def load_weights(model, weights):
11
+ """Loads the weights of only the layers present in the given model."""
12
+ pretrained_dict = torch.load(weights, map_location='cpu')
13
+ model_dict = model.state_dict()
14
+ pretrained_dict = {k: v for k,
15
+ v in pretrained_dict.items() if k in model_dict}
16
+ model_dict.update(pretrained_dict)
17
+ model.load_state_dict(model_dict)
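load_weights deliberately copies only the checkpoint entries whose keys also exist in the target model, so a patch-classifier checkpoint can initialise just the encoder layers of a larger network. A hypothetical usage (the checkpoint path below is a placeholder, not a file shipped with this repo):

from models.real_models import Encoder
from utils.utils import load_weights

encoder = Encoder(channels=2)
# keys present in the checkpoint but missing from encoder.state_dict()
# (e.g. classifier weights) are silently dropped
load_weights(encoder, 'checkpoints/patch_classifier.pt')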