File size: 22,632 Bytes
861e32a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 |
# These layers are borrowed from https://github.com/eleGAN23/HyperNets
# by Eleonora Grassucci.
# Please check the original repository for further explanations.
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy.random import RandomState
from torch.nn import Module, init
from torch.nn.parameter import Parameter
from models import hypercomplex_ops as hp_ops
########################
## STANDARD PHM LAYER ##
########################
class PHMLinear(nn.Module):
    """Parameterized hypercomplex multiplication (PHM) linear layer.

    The dense weight of shape ``(out_features, in_features)`` is not itself a
    parameter; it is rebuilt on every forward pass as ``sum_i kron(A[i], S[i])``
    with ``A`` of shape ``(n, n, n)`` and ``S`` of shape
    ``(n, out_features // n, in_features // n)``.

    Args:
        n: Number of Kronecker terms (hypercomplex dimension).
        in_features: Size of each input sample; should be divisible by ``n``.
        out_features: Size of each output sample; should be divisible by ``n``.
        cuda: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self, n, in_features, out_features, cuda=True):
        super().__init__()
        self.n = n
        self.in_features = in_features
        self.out_features = out_features
        # NOTE(review): this bool shadows ``nn.Module.cuda()`` on the instance, so
        # ``layer.cuda()`` is not callable; kept for backward compatibility.
        self.cuda = cuda
        self.bias = nn.Parameter(torch.Tensor(out_features))
        self.A = nn.Parameter(
            torch.nn.init.xavier_uniform_(torch.zeros((n, n, n))))
        self.S = nn.Parameter(torch.nn.init.xavier_uniform_(
            torch.zeros((n, self.out_features // n, self.in_features // n))))
        # Plain tensor attribute (not a Parameter or buffer); it is replaced by
        # the assembled weight on every forward pass.
        self.weight = torch.zeros((self.out_features, self.in_features))
        self._reset_bias()

    def _reset_bias(self):
        """Draw the bias from U(-1/sqrt(in_features), 1/sqrt(in_features))."""
        # For the 2-D (out_features, in_features) weight the fan-in is in_features.
        bound = 1 / math.sqrt(self.in_features)
        init.uniform_(self.bias, -bound, bound)

    # adapted from Bayer Research's implementation
    def kronecker_product1(self, a, b):
        """Batched Kronecker product over the last two axes.

        For ``a`` of shape ``(..., p, q)`` and ``b`` of shape ``(..., r, s)``
        returns a tensor of shape ``(..., p * r, q * s)`` whose entry
        ``[..., i * r + k, j * s + l]`` equals ``a[..., i, j] * b[..., k, l]``.
        """
        rows = a.shape[-2] * b.shape[-2]
        cols = a.shape[-1] * b.shape[-1]
        # (..., p, 1, q, 1) * (..., 1, r, 1, s) -> (..., p, r, q, s)
        res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
        return res.reshape(res.shape[:-4] + (rows, cols))

    def kronecker_product2(self):
        """Loop-based equivalent of ``sum(kronecker_product1(A, S), dim=0)``.

        Returns the ``(out_features, in_features)`` weight, allocated on the
        same device and with the same dtype as ``A``.
        """
        H = torch.zeros((self.out_features, self.in_features),
                        device=self.A.device, dtype=self.A.dtype)
        for i in range(self.n):
            H = H + torch.kron(self.A[i], self.S[i])
        return H

    def forward(self, input):
        """Assemble the PHM weight and apply ``F.linear`` with the bias."""
        # kronecker_product2() yields the same weight through a Python loop but is slower.
        self.weight = torch.sum(self.kronecker_product1(self.A, self.S), dim=0)
        # Cast the input to the weight's tensor type (dtype and device kind).
        input = input.type(dtype=self.weight.type())
        return F.linear(input, weight=self.weight, bias=self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)

    def reset_parameters(self) -> None:
        """Re-initialise ``A`` and ``S`` (Kaiming uniform) and the bias."""
        init.kaiming_uniform_(self.A, a=math.sqrt(5))
        init.kaiming_uniform_(self.S, a=math.sqrt(5))
        self._reset_bias()
#############################
## CONVOLUTIONAL PH LAYER ##
#############################
class PHConv(Module):
    """Parameterized hypercomplex 2-D convolution (PHC) layer.

    The kernel of shape ``(out_features, in_features, kernel_size, kernel_size)``
    is rebuilt on every forward pass as ``sum_i kron(A[i], F[i])`` taken over the
    two channel axes, with ``A`` of shape ``(n, n, n)`` and ``F`` of shape
    ``(n, out_features // n, in_features // n, kernel_size, kernel_size)``.

    Args:
        n: Number of Kronecker terms (hypercomplex dimension).
        in_features: Input channels; should be divisible by ``n``.
        out_features: Output channels; should be divisible by ``n``.
        kernel_size: Side length of the square kernel.
        padding: Passed to ``F.conv2d``.
        stride: Passed to ``F.conv2d``.
        cuda: Stored on the instance for backward compatibility; tensors built
            here follow the device of the parameters instead.
    """

    def __init__(self, n, in_features, out_features, kernel_size, padding=0, stride=1, cuda=True):
        super().__init__()
        self.n = n
        self.in_features = in_features
        self.out_features = out_features
        self.padding = padding
        self.stride = stride
        # NOTE(review): this bool shadows ``Module.cuda()`` on the instance, so
        # ``layer.cuda()`` is not callable; kept for backward compatibility.
        self.cuda = cuda
        self.bias = nn.Parameter(torch.Tensor(out_features))
        self.A = nn.Parameter(
            torch.nn.init.xavier_uniform_(torch.zeros((n, n, n))))
        self.F = nn.Parameter(torch.nn.init.xavier_uniform_(
            torch.zeros((n, self.out_features // n, self.in_features // n,
                         kernel_size, kernel_size))))
        # Plain tensor attribute (not a Parameter or buffer); replaced by the
        # assembled kernel on every forward pass.
        self.weight = torch.zeros((self.out_features, self.in_features))
        self.kernel_size = kernel_size
        self._reset_bias()

    def _reset_bias(self):
        """Draw the bias from U(-1/sqrt(in_features), 1/sqrt(in_features))."""
        # Mirrors the original init, whose fan-in came from a 2-D
        # (out_features, in_features) placeholder, i.e. in_features (not
        # in_features * kernel_size ** 2 as nn.Conv2d would use).
        bound = 1 / math.sqrt(self.in_features)
        init.uniform_(self.bias, -bound, bound)

    def kronecker_product1(self, A, F):
        """Batched Kronecker product over the channel axes.

        For ``A`` of shape ``(b, p, q)`` and ``F`` of shape ``(b, r, s, kh, kw)``
        returns shape ``(b, p * r, q * s, kh, kw)`` whose entry
        ``[t, i * r + k, j * s + l, x, y]`` equals ``A[t, i, j] * F[t, k, l, x, y]``.
        """
        rows = A.shape[-2] * F.shape[-4]
        cols = A.shape[-1] * F.shape[-3]
        # (b, p, 1, q, 1, 1, 1) * (b, 1, r, 1, s, kh, kw) -> (b, p, r, q, s, kh, kw)
        res = A.unsqueeze(-1).unsqueeze(-3).unsqueeze(-1).unsqueeze(-1) * \
            F.unsqueeze(-4).unsqueeze(-6)
        return res.reshape(res.shape[:1] + (rows, cols) + F.shape[-2:])

    def kronecker_product2(self):
        """Loop-based equivalent of ``sum(kronecker_product1(A, F), dim=0)``.

        Returns the ``(out_features, in_features, kernel_size, kernel_size)``
        kernel on the same device and with the same dtype as ``A``.
        """
        H = torch.zeros((self.out_features, self.in_features,
                         self.kernel_size, self.kernel_size),
                        device=self.A.device, dtype=self.A.dtype)
        for i in range(self.n):
            # Lift A[i] to (n, n, 1, 1) so torch.kron expands only the channel
            # axes; kron of the bare (n, n) matrix would also scale the spatial axes.
            H = H + torch.kron(self.A[i][:, :, None, None], self.F[i])
        return H

    def forward(self, input):
        """Assemble the PHC kernel and apply ``F.conv2d``."""
        self.weight = torch.sum(self.kronecker_product1(self.A, self.F), dim=0)
        # Cast the input to the weight's tensor type (dtype and device kind).
        input = input.type(dtype=self.weight.type())
        # NOTE(review): ``self.bias`` is created and initialised but is not passed
        # to conv2d, so it never receives a gradient. Left unchanged on purpose:
        # applying it now would alter outputs of already-trained checkpoints.
        return F.conv2d(input, weight=self.weight, stride=self.stride, padding=self.padding)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)

    def reset_parameters(self) -> None:
        """Re-initialise ``A`` and ``F`` (Kaiming uniform) and the bias."""
        init.kaiming_uniform_(self.A, a=math.sqrt(5))
        init.kaiming_uniform_(self.F, a=math.sqrt(5))
        self._reset_bias()
class KroneckerConv(Module):
    r"""Convolution built from four weight components (r, i, j, k).

    The forward pass is delegated to ``hp_ops.kronecker_conv``; ``learn_A``,
    ``cuda`` and ``first_layer`` are forwarded to it unchanged. Each component
    is sized from ``in_channels // 4`` and ``out_channels // 4``. The
    ``rotation=True`` variant is not implemented and raises in ``forward``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 dilatation=1, padding=0, groups=1, bias=True, init_criterion='glorot',
                 weight_init='quaternion', seed=None, operation='convolution2d', rotation=False,
                 quaternion_format=True, scale=False, learn_A=False, cuda=True, first_layer=False):
        super().__init__()
        self.in_channels = in_channels // 4
        self.out_channels = out_channels // 4
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilatation = dilatation
        self.init_criterion = init_criterion
        self.weight_init = weight_init
        self.seed = seed if seed is not None else np.random.randint(0, 1234)
        self.rng = RandomState(self.seed)
        self.operation = operation
        self.rotation = rotation
        self.quaternion_format = quaternion_format
        self.winit = {'quaternion': hp_ops.quaternion_init,
                      'unitary': hp_ops.unitary_init,
                      'random': hp_ops.random_init}[self.weight_init]
        self.scale = scale
        self.learn_A = learn_A
        # NOTE(review): this bool shadows ``Module.cuda()`` on the instance; it is
        # forwarded to ``hp_ops.kronecker_conv``, so the name is kept.
        self.cuda = cuda
        self.first_layer = first_layer
        (self.kernel_size, self.w_shape) = hp_ops.get_kernel_and_weight_shape(self.operation,
                                                                              self.in_channels, self.out_channels, kernel_size)
        self.r_weight = Parameter(torch.Tensor(*self.w_shape))
        self.i_weight = Parameter(torch.Tensor(*self.w_shape))
        self.j_weight = Parameter(torch.Tensor(*self.w_shape))
        self.k_weight = Parameter(torch.Tensor(*self.w_shape))
        if self.scale:
            # NOTE(review): scale_param is initialised but not passed to
            # hp_ops.kronecker_conv in forward — confirm whether it is needed.
            self.scale_param = Parameter(torch.Tensor(self.r_weight.shape))
        else:
            self.scale_param = None
        if self.rotation:
            self.zero_kernel = Parameter(torch.zeros(
                self.r_weight.shape), requires_grad=False)
        if bias:
            # Bias covers the full (undivided) out_channels.
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the four components via hp_ops, the scale with Xavier, the bias with zeros."""
        hp_ops.affect_init_conv(self.r_weight, self.i_weight, self.j_weight, self.k_weight,
                                self.kernel_size, self.winit, self.rng, self.init_criterion)
        if self.scale_param is not None:
            torch.nn.init.xavier_uniform_(self.scale_param.data)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, input):
        """Apply ``hp_ops.kronecker_conv``.

        Raises:
            NotImplementedError: if the layer was built with ``rotation=True``.
        """
        if self.rotation:
            # The rotation variant was never wired up; fail loudly instead of
            # silently returning None.
            raise NotImplementedError(
                'KroneckerConv does not support rotation=True')
        return hp_ops.kronecker_conv(input, self.r_weight, self.i_weight, self.j_weight,
                                     self.k_weight, self.bias, self.stride, self.padding, self.groups,
                                     self.dilatation, self.learn_A, self.cuda, self.first_layer)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'in_channels=' + str(self.in_channels) \
            + ', out_channels=' + str(self.out_channels) \
            + ', bias=' + str(self.bias is not None) \
            + ', kernel_size=' + str(self.kernel_size) \
            + ', stride=' + str(self.stride) \
            + ', padding=' + str(self.padding) \
            + ', init_criterion=' + str(self.init_criterion) \
            + ', weight_init=' + str(self.weight_init) \
            + ', seed=' + str(self.seed) \
            + ', rotation=' + str(self.rotation) \
            + ', q_format=' + str(self.quaternion_format) \
            + ', operation=' + str(self.operation) + ')'
class QuaternionTransposeConv(Module):
    r"""Applies a Quaternion Transposed Convolution (or Deconvolution) to the incoming data.

    The four weight components (r, i, j, k) are each sized from
    ``in_channels // 4`` and ``out_channels // 4``; the forward pass is delegated to
    ``hp_ops.quaternion_transpose_conv``, or to
    ``hp_ops.quaternion_tranpose_conv_rotation`` when ``rotation=True``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 dilatation=1, padding=0, output_padding=0, groups=1, bias=True, init_criterion='he',
                 weight_init='quaternion', seed=None, operation='convolution2d', rotation=False,
                 quaternion_format=False):
        super().__init__()
        self.in_channels = in_channels // 4
        self.out_channels = out_channels // 4
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups
        self.dilatation = dilatation
        self.init_criterion = init_criterion
        self.weight_init = weight_init
        self.seed = seed if seed is not None else np.random.randint(0, 1234)
        self.rng = RandomState(self.seed)
        self.operation = operation
        self.rotation = rotation
        self.quaternion_format = quaternion_format
        self.winit = {'quaternion': hp_ops.quaternion_init,
                      'unitary': hp_ops.unitary_init,
                      'random': hp_ops.random_init}[self.weight_init]
        # Channel arguments are passed as (out, in) here — swapped relative to
        # QuaternionConv — presumably for the transposed weight layout; confirm.
        (self.kernel_size, self.w_shape) = hp_ops.get_kernel_and_weight_shape(self.operation,
                                                                              self.out_channels, self.in_channels, kernel_size)
        self.r_weight = Parameter(torch.Tensor(*self.w_shape))
        self.i_weight = Parameter(torch.Tensor(*self.w_shape))
        self.j_weight = Parameter(torch.Tensor(*self.w_shape))
        self.k_weight = Parameter(torch.Tensor(*self.w_shape))
        if bias:
            # Bias covers the full (undivided) out_channels.
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the four components via hp_ops and zero the bias."""
        hp_ops.affect_init_conv(self.r_weight, self.i_weight, self.j_weight, self.k_weight,
                                self.kernel_size, self.winit, self.rng, self.init_criterion)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, input):
        if self.rotation:
            return hp_ops.quaternion_tranpose_conv_rotation(input, self.r_weight, self.i_weight,
                                                            self.j_weight, self.k_weight, self.bias, self.stride, self.padding,
                                                            self.output_padding, self.groups, self.dilatation, self.quaternion_format)
        else:
            return hp_ops.quaternion_transpose_conv(input, self.r_weight, self.i_weight, self.j_weight,
                                                    self.k_weight, self.bias, self.stride, self.padding, self.output_padding,
                                                    self.groups, self.dilatation)

    def __repr__(self):
        # The attribute is spelled ``dilatation``; reading ``self.dilation`` made
        # repr()/print(model) raise AttributeError.
        return self.__class__.__name__ + '(' \
            + 'in_channels=' + str(self.in_channels) \
            + ', out_channels=' + str(self.out_channels) \
            + ', bias=' + str(self.bias is not None) \
            + ', kernel_size=' + str(self.kernel_size) \
            + ', stride=' + str(self.stride) \
            + ', padding=' + str(self.padding) \
            + ', dilation=' + str(self.dilatation) \
            + ', init_criterion=' + str(self.init_criterion) \
            + ', weight_init=' + str(self.weight_init) \
            + ', seed=' + str(self.seed) \
            + ', operation=' + str(self.operation) + ')'
class QuaternionConv(Module):
    r"""Quaternion convolution over four weight components (r, i, j, k).

    Each component is sized from ``in_channels // 4`` and ``out_channels // 4``.
    The computation itself lives in ``hp_ops``: ``quaternion_conv`` normally, or
    ``quaternion_conv_rotation`` when ``rotation=True`` (which also receives a
    frozen all-zero kernel, the ``quaternion_format`` flag and the optional scale).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 dilatation=1, padding=0, groups=1, bias=True, init_criterion='glorot',
                 weight_init='quaternion', seed=None, operation='convolution2d', rotation=False,
                 quaternion_format=True, scale=False):
        super().__init__()
        # One quarter of the channels per quaternion component.
        self.in_channels = in_channels // 4
        self.out_channels = out_channels // 4
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilatation = dilatation
        self.init_criterion = init_criterion
        self.weight_init = weight_init
        self.seed = np.random.randint(0, 1234) if seed is None else seed
        self.rng = RandomState(self.seed)
        self.operation = operation
        self.rotation = rotation
        self.quaternion_format = quaternion_format
        init_schemes = {
            'quaternion': hp_ops.quaternion_init,
            'unitary': hp_ops.unitary_init,
            'random': hp_ops.random_init,
        }
        self.winit = init_schemes[self.weight_init]
        self.scale = scale
        self.kernel_size, self.w_shape = hp_ops.get_kernel_and_weight_shape(
            self.operation, self.in_channels, self.out_channels, kernel_size)
        # Registration order r, i, j, k is preserved for state_dict / parameters().
        for attr_name in ('r_weight', 'i_weight', 'j_weight', 'k_weight'):
            setattr(self, attr_name, Parameter(torch.Tensor(*self.w_shape)))
        self.scale_param = Parameter(torch.Tensor(self.r_weight.shape)) if self.scale else None
        if self.rotation:
            self.zero_kernel = Parameter(torch.zeros(self.r_weight.shape),
                                         requires_grad=False)
        if not bias:
            self.register_parameter('bias', None)
        else:
            # Bias covers the full (undivided) out_channels.
            self.bias = Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise components via hp_ops, the optional scale with Xavier, and zero the bias."""
        components = (self.r_weight, self.i_weight, self.j_weight, self.k_weight)
        hp_ops.affect_init_conv(*components, self.kernel_size, self.winit,
                                self.rng, self.init_criterion)
        if self.scale_param is not None:
            torch.nn.init.xavier_uniform_(self.scale_param.data)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, input):
        components = (self.r_weight, self.i_weight, self.j_weight, self.k_weight)
        if not self.rotation:
            return hp_ops.quaternion_conv(input, *components, self.bias, self.stride,
                                          self.padding, self.groups, self.dilatation)
        return hp_ops.quaternion_conv_rotation(input, self.zero_kernel, *components, self.bias,
                                               self.stride, self.padding, self.groups,
                                               self.dilatation, self.quaternion_format,
                                               self.scale_param)

    def __repr__(self):
        fields = (
            ('in_channels', self.in_channels),
            ('out_channels', self.out_channels),
            ('bias', self.bias is not None),
            ('kernel_size', self.kernel_size),
            ('stride', self.stride),
            ('padding', self.padding),
            ('init_criterion', self.init_criterion),
            ('weight_init', self.weight_init),
            ('seed', self.seed),
            ('rotation', self.rotation),
            ('q_format', self.quaternion_format),
            ('operation', self.operation),
        )
        inner = ', '.join(key + '=' + str(value) for key, value in fields)
        return self.__class__.__name__ + '(' + inner + ')'
class QuaternionLinearAutograd(Module):
    r"""Quaternion linear layer over four weight components (r, i, j, k).

    Each component has shape ``(in_features // 4, out_features // 4)``. The
    forward pass calls ``hp_ops.quaternion_linear``, or
    ``hp_ops.quaternion_linear_rotation`` when ``rotation=True``.

    NOTE(review): an earlier docstring said this class uses a custom autograd
    function to cut VRAM use at the cost of speed, yet ``QuaternionLinear`` is
    the class that calls ``hp_ops.QuaternionLinearFunction`` — confirm which
    description applies.
    """

    def __init__(self, in_features, out_features, bias=True,
                 init_criterion='glorot', weight_init='quaternion',
                 seed=None, rotation=False, quaternion_format=True, scale=False):
        super().__init__()
        quarter_in = in_features // 4
        quarter_out = out_features // 4
        self.in_features = quarter_in
        self.out_features = quarter_out
        self.rotation = rotation
        self.quaternion_format = quaternion_format
        component_shape = (quarter_in, quarter_out)
        self.r_weight = Parameter(torch.Tensor(*component_shape))
        self.i_weight = Parameter(torch.Tensor(*component_shape))
        self.j_weight = Parameter(torch.Tensor(*component_shape))
        self.k_weight = Parameter(torch.Tensor(*component_shape))
        self.scale = scale
        self.scale_param = (Parameter(torch.Tensor(*component_shape))
                            if self.scale else None)
        if self.rotation:
            self.zero_kernel = Parameter(torch.zeros(self.r_weight.shape),
                                         requires_grad=False)
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(quarter_out * 4))
        self.init_criterion = init_criterion
        self.weight_init = weight_init
        self.seed = np.random.randint(0, 1234) if seed is None else seed
        self.rng = RandomState(self.seed)
        self.reset_parameters()

    def reset_parameters(self):
        """Pick the init scheme, Xavier-init the optional scale, zero the bias, init components."""
        schemes = {
            'quaternion': hp_ops.quaternion_init,
            'unitary': hp_ops.unitary_init,
            'random': hp_ops.random_init,
        }
        chosen = schemes[self.weight_init]
        if self.scale_param is not None:
            torch.nn.init.xavier_uniform_(self.scale_param.data)
        if self.bias is not None:
            self.bias.data.fill_(0)
        components = (self.r_weight, self.i_weight, self.j_weight, self.k_weight)
        hp_ops.affect_init(*components, chosen, self.rng, self.init_criterion)

    def forward(self, input):
        components = (self.r_weight, self.i_weight, self.j_weight, self.k_weight)
        if not self.rotation:
            return hp_ops.quaternion_linear(input, *components, self.bias)
        return hp_ops.quaternion_linear_rotation(input, self.zero_kernel, *components, self.bias,
                                                 self.quaternion_format, self.scale_param)

    def __repr__(self):
        fields = (
            ('in_features', self.in_features),
            ('out_features', self.out_features),
            ('bias', self.bias is not None),
            ('init_criterion', self.init_criterion),
            ('weight_init', self.weight_init),
            ('rotation', self.rotation),
            ('seed', self.seed),
        )
        inner = ', '.join(key + '=' + str(value) for key, value in fields)
        return self.__class__.__name__ + '(' + inner + ')'
class QuaternionLinear(Module):
    r"""Applies a quaternion linear transformation to the incoming data.

    Four weight components (r, i, j, k) of shape
    ``(in_features // 4, out_features // 4)`` are combined by
    ``hp_ops.QuaternionLinearFunction``.

    Args:
        in_features: Total input size (split across the four components).
        out_features: Total output size (split across the four components).
        bias: If True, learn a bias of size ``(out_features // 4) * 4``.
        init_criterion: Forwarded to ``hp_ops.affect_init``.
        weight_init: ``'quaternion'``, ``'unitary'`` or ``'random'``.
        seed: Seed for the numpy ``RandomState`` used at init; random if None.
    """

    def __init__(self, in_features, out_features, bias=True,
                 init_criterion='he', weight_init='quaternion',
                 seed=None):
        super().__init__()
        self.in_features = in_features // 4
        self.out_features = out_features // 4
        self.r_weight = Parameter(torch.Tensor(
            self.in_features, self.out_features))
        self.i_weight = Parameter(torch.Tensor(
            self.in_features, self.out_features))
        self.j_weight = Parameter(torch.Tensor(
            self.in_features, self.out_features))
        self.k_weight = Parameter(torch.Tensor(
            self.in_features, self.out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(self.out_features * 4))
        else:
            self.register_parameter('bias', None)
        self.init_criterion = init_criterion
        self.weight_init = weight_init
        self.seed = seed if seed is not None else np.random.randint(0, 1234)
        self.rng = RandomState(self.seed)
        self.reset_parameters()

    def reset_parameters(self):
        """Zero the bias and initialise the four components via ``hp_ops.affect_init``."""
        # 'random' is accepted for parity with the other quaternion layers.
        winit = {'quaternion': hp_ops.quaternion_init,
                 'unitary': hp_ops.unitary_init,
                 'random': hp_ops.random_init}[self.weight_init]
        if self.bias is not None:
            self.bias.data.fill_(0)
        # affect_init lives in hp_ops; the bare name is not defined in this module.
        hp_ops.affect_init(self.r_weight, self.i_weight, self.j_weight, self.k_weight, winit,
                           self.rng, self.init_criterion)

    def forward(self, input):
        """Apply the layer to a 2-D ``(batch, C)`` or 3-D ``(T, N, C)`` input.

        Raises:
            NotImplementedError: for inputs of any other rank.
        """
        if input.dim() == 3:
            T, N, C = input.size()
            # Flatten the two leading axes, apply, then restore them; reshape
            # (unlike view) also accepts non-contiguous inputs.
            output = hp_ops.QuaternionLinearFunction.apply(
                input.reshape(T * N, C), self.r_weight, self.i_weight,
                self.j_weight, self.k_weight, self.bias)
            return output.reshape(T, N, output.size(1))
        if input.dim() == 2:
            return hp_ops.QuaternionLinearFunction.apply(
                input, self.r_weight, self.i_weight, self.j_weight, self.k_weight, self.bias)
        raise NotImplementedError

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'in_features=' + str(self.in_features) \
            + ', out_features=' + str(self.out_features) \
            + ', bias=' + str(self.bias is not None) \
            + ', init_criterion=' + str(self.init_criterion) \
            + ', weight_init=' + str(self.weight_init) \
            + ', seed=' + str(self.seed) + ')'
|