# These layers are borrowed from: https://github.com/eleGAN23/HyperNets
# by Eleonora Grassucci.
# Please check the original repository for further explanations.
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy.random import RandomState
from torch.nn import Module, init
from torch.nn.parameter import Parameter
from models import hypercomplex_ops as hp_ops
########################
## STANDARD PHM LAYER ##
########################
class PHMLinear(nn.Module):
def __init__(self, n, in_features, out_features, cuda=True):
super().__init__()
self.n = n
self.in_features = in_features
self.out_features = out_features
self.cuda = cuda
self.bias = nn.Parameter(torch.Tensor(out_features))
self.A = nn.Parameter(
torch.nn.init.xavier_uniform_(torch.zeros((n, n, n))))
self.S = nn.Parameter(torch.nn.init.xavier_uniform_(
torch.zeros((n, self.out_features//n, self.in_features//n))))
        # The full weight is materialized in forward() as sum_i A_i ⊗ S_i; this
        # zero placeholder is only used to compute the fan-in for the bias init.
        self.weight = torch.zeros((self.out_features, self.in_features))
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)
    # Adapted from Bayer Research's implementation.
    def kronecker_product1(self, a, b):
        # Batched Kronecker product: (n, n, n) x (n, out/n, in/n) -> (n, out, in).
        siz1 = torch.Size(torch.tensor(
            a.shape[-2:]) * torch.tensor(b.shape[-2:]))
        res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
        siz0 = res.shape[:-4]
        out = res.reshape(siz0 + siz1)
        return out
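    # Worked sketch (illustrative shapes, not from the original file): with
    # n=2, A of shape (2, 2, 2) and S of shape (2, 3, 4),
    # kronecker_product1(A, S) has shape (2, 6, 8); summing over dim 0 in
    # forward() gives the (6, 8) weight W = A[0] ⊗ S[0] + A[1] ⊗ S[1].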
    def kronecker_product2(self):
        # Equivalent loop-based construction with torch.kron (slower; kept for
        # reference). H must live on the same device as A and S.
        H = torch.zeros((self.out_features, self.in_features))
        if self.cuda:
            H = H.cuda()
        for i in range(self.n):
            H = H + torch.kron(self.A[i], self.S[i])
        return H
    def forward(self, input):
        # Assemble the full weight on the fly: W = sum_i A_i ⊗ S_i.
        self.weight = torch.sum(self.kronecker_product1(self.A, self.S), dim=0)
        # self.weight = self.kronecker_product2()  # <- slower
        input = input.type(dtype=self.weight.type())
        return F.linear(input, weight=self.weight, bias=self.bias)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None)
    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.A, a=math.sqrt(5))
        init.kaiming_uniform_(self.S, a=math.sqrt(5))
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)
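# Minimal usage sketch (not from the original file; shapes are illustrative).
# n must divide both in_features and out_features:
#   phm = PHMLinear(n=4, in_features=64, out_features=128)
#   y = phm(torch.randn(8, 64))   # -> (8, 128)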
#############################
## CONVOLUTIONAL PH LAYER ##
#############################
class PHConv(Module):
def __init__(self, n, in_features, out_features, kernel_size, padding=0, stride=1, cuda=True):
super().__init__()
self.n = n
self.in_features = in_features
self.out_features = out_features
self.padding = padding
self.stride = stride
self.cuda = cuda
self.bias = nn.Parameter(torch.Tensor(out_features))
self.A = nn.Parameter(
torch.nn.init.xavier_uniform_(torch.zeros((n, n, n))))
        # self.F holds the per-component conv filters; note it is distinct from
        # the module-level alias F (torch.nn.functional).
        self.F = nn.Parameter(torch.nn.init.xavier_uniform_(
            torch.zeros((n, self.out_features//n, self.in_features//n, kernel_size, kernel_size))))
        # Zero placeholder; the real conv weight is assembled in forward().
        self.weight = torch.zeros((self.out_features, self.in_features))
        self.kernel_size = kernel_size
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)
    def kronecker_product1(self, A, F):
        # Batched Kronecker product over the block dims, leaving the trailing
        # (kernel, kernel) dims of F intact:
        # (n, n, n) x (n, out/n, in/n, k, k) -> (n, out, in, k, k).
        siz1 = torch.Size(torch.tensor(
            A.shape[-2:]) * torch.tensor(F.shape[-4:-2]))
        siz2 = torch.Size(torch.tensor(F.shape[-2:]))
        res = A.unsqueeze(-1).unsqueeze(-3).unsqueeze(-1).unsqueeze(-1) * \
            F.unsqueeze(-4).unsqueeze(-6)
        siz0 = res.shape[:1]
        out = res.reshape(siz0 + siz1 + siz2)
        return out
def kronecker_product2(self):
H = torch.zeros((self.out_features, self.in_features,
self.kernel_size, self.kernel_size))
if self.cuda:
H = H.cuda()
for i in range(self.n):
kron_prod = torch.kron(self.A[i], self.F[i]).view(
self.out_features, self.in_features, self.kernel_size, self.kernel_size)
H = H + kron_prod
return H
    def forward(self, input):
        # Assemble the full conv weight: W = sum_i A_i ⊗ F_i.
        self.weight = torch.sum(self.kronecker_product1(self.A, self.F), dim=0)
        # self.weight = self.kronecker_product2()  # <- slower
        # if self.cuda:
        #     self.weight = self.weight.cuda()
        input = input.type(dtype=self.weight.type())
        return F.conv2d(input, weight=self.weight, bias=self.bias,
                        stride=self.stride, padding=self.padding)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None)
    def reset_parameters(self) -> None:
        init.kaiming_uniform_(self.A, a=math.sqrt(5))
        init.kaiming_uniform_(self.F, a=math.sqrt(5))
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)
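# Minimal usage sketch (not from the original file; shapes are illustrative).
# in_features/out_features are channel counts and must be divisible by n:
#   conv = PHConv(n=4, in_features=16, out_features=32, kernel_size=3, padding=1)
#   y = conv(torch.randn(2, 16, 28, 28))   # -> (2, 32, 28, 28)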
class KroneckerConv(Module):
r"""Applies a Quaternion Convolution to the incoming data."""
def __init__(self, in_channels, out_channels, kernel_size, stride,
dilatation=1, padding=0, groups=1, bias=True, init_criterion='glorot',
weight_init='quaternion', seed=None, operation='convolution2d', rotation=False,
quaternion_format=True, scale=False, learn_A=False, cuda=True, first_layer=False):
super().__init__()
self.in_channels = in_channels // 4
self.out_channels = out_channels // 4
self.stride = stride
self.padding = padding
self.groups = groups
self.dilatation = dilatation
self.init_criterion = init_criterion
self.weight_init = weight_init
self.seed = seed if seed is not None else np.random.randint(0, 1234)
self.rng = RandomState(self.seed)
self.operation = operation
self.rotation = rotation
self.quaternion_format = quaternion_format
self.winit = {'quaternion': hp_ops.quaternion_init,
'unitary': hp_ops.unitary_init,
'random': hp_ops.random_init}[self.weight_init]
self.scale = scale
self.learn_A = learn_A
self.cuda = cuda
self.first_layer = first_layer
(self.kernel_size, self.w_shape) = hp_ops.get_kernel_and_weight_shape(self.operation,
self.in_channels, self.out_channels, kernel_size)
self.r_weight = Parameter(torch.Tensor(*self.w_shape))
self.i_weight = Parameter(torch.Tensor(*self.w_shape))
self.j_weight = Parameter(torch.Tensor(*self.w_shape))
self.k_weight = Parameter(torch.Tensor(*self.w_shape))
if self.scale:
self.scale_param = Parameter(torch.Tensor(self.r_weight.shape))
else:
self.scale_param = None
if self.rotation:
self.zero_kernel = Parameter(torch.zeros(
self.r_weight.shape), requires_grad=False)
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
hp_ops.affect_init_conv(self.r_weight, self.i_weight, self.j_weight, self.k_weight,
self.kernel_size, self.winit, self.rng, self.init_criterion)
if self.scale_param is not None:
torch.nn.init.xavier_uniform_(self.scale_param.data)
if self.bias is not None:
self.bias.data.zero_()
    def forward(self, input):
        if self.rotation:
            # The rotation path is not implemented for the Kronecker variant;
            # the original quaternion call is kept below for reference.
            # return hp_ops.quaternion_conv_rotation(input, self.zero_kernel, self.r_weight, self.i_weight, self.j_weight,
            #                                        self.k_weight, self.bias, self.stride, self.padding, self.groups, self.dilatation,
            #                                        self.quaternion_format, self.scale_param)
            raise NotImplementedError(
                'rotation is not supported by KroneckerConv')
else:
return hp_ops.kronecker_conv(input, self.r_weight, self.i_weight, self.j_weight,
self.k_weight, self.bias, self.stride, self.padding, self.groups, self.dilatation, self.learn_A, self.cuda, self.first_layer)
def __repr__(self):
return self.__class__.__name__ + '(' \
+ 'in_channels=' + str(self.in_channels) \
+ ', out_channels=' + str(self.out_channels) \
+ ', bias=' + str(self.bias is not None) \
+ ', kernel_size=' + str(self.kernel_size) \
+ ', stride=' + str(self.stride) \
+ ', padding=' + str(self.padding) \
+ ', init_criterion=' + str(self.init_criterion) \
+ ', weight_init=' + str(self.weight_init) \
+ ', seed=' + str(self.seed) \
+ ', rotation=' + str(self.rotation) \
+ ', q_format=' + str(self.quaternion_format) \
+ ', operation=' + str(self.operation) + ')'
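# Minimal usage sketch (not from the original file; shapes are illustrative).
# Channel counts are split into quaternion components internally, so they must
# be multiples of 4:
#   conv = KroneckerConv(in_channels=16, out_channels=32, kernel_size=3,
#                        stride=1, padding=1)
#   y = conv(torch.randn(2, 16, 28, 28))   # expected -> (2, 32, 28, 28)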
class QuaternionTransposeConv(Module):
r"""Applies a Quaternion Transposed Convolution (or Deconvolution) to the incoming data."""
def __init__(self, in_channels, out_channels, kernel_size, stride,
dilatation=1, padding=0, output_padding=0, groups=1, bias=True, init_criterion='he',
weight_init='quaternion', seed=None, operation='convolution2d', rotation=False,
quaternion_format=False):
super().__init__()
self.in_channels = in_channels // 4
self.out_channels = out_channels // 4
self.stride = stride
self.padding = padding
self.output_padding = output_padding
self.groups = groups
self.dilatation = dilatation
self.init_criterion = init_criterion
self.weight_init = weight_init
self.seed = seed if seed is not None else np.random.randint(0, 1234)
self.rng = RandomState(self.seed)
self.operation = operation
self.rotation = rotation
self.quaternion_format = quaternion_format
self.winit = {'quaternion': hp_ops.quaternion_init,
'unitary': hp_ops.unitary_init,
'random': hp_ops.random_init}[self.weight_init]
(self.kernel_size, self.w_shape) = hp_ops.get_kernel_and_weight_shape(self.operation,
self.out_channels, self.in_channels, kernel_size)
self.r_weight = Parameter(torch.Tensor(*self.w_shape))
self.i_weight = Parameter(torch.Tensor(*self.w_shape))
self.j_weight = Parameter(torch.Tensor(*self.w_shape))
self.k_weight = Parameter(torch.Tensor(*self.w_shape))
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
hp_ops.affect_init_conv(self.r_weight, self.i_weight, self.j_weight, self.k_weight,
self.kernel_size, self.winit, self.rng, self.init_criterion)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, input):
if self.rotation:
return hp_ops.quaternion_tranpose_conv_rotation(input, self.r_weight, self.i_weight,
self.j_weight, self.k_weight, self.bias, self.stride, self.padding,
self.output_padding, self.groups, self.dilatation, self.quaternion_format)
else:
return hp_ops.quaternion_transpose_conv(input, self.r_weight, self.i_weight, self.j_weight,
self.k_weight, self.bias, self.stride, self.padding, self.output_padding,
self.groups, self.dilatation)
def __repr__(self):
return self.__class__.__name__ + '(' \
+ 'in_channels=' + str(self.in_channels) \
+ ', out_channels=' + str(self.out_channels) \
+ ', bias=' + str(self.bias is not None) \
+ ', kernel_size=' + str(self.kernel_size) \
+ ', stride=' + str(self.stride) \
+ ', padding=' + str(self.padding) \
            + ', dilation=' + str(self.dilatation) \
+ ', init_criterion=' + str(self.init_criterion) \
+ ', weight_init=' + str(self.weight_init) \
+ ', seed=' + str(self.seed) \
+ ', operation=' + str(self.operation) + ')'
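# Minimal usage sketch (not from the original file; shapes are illustrative).
# As with the other quaternion layers, channel counts must be multiples of 4:
#   deconv = QuaternionTransposeConv(in_channels=32, out_channels=16,
#                                    kernel_size=2, stride=2)
#   y = deconv(torch.randn(2, 32, 14, 14))   # expected -> (2, 16, 28, 28)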
class QuaternionConv(Module):
r"""Applies a Quaternion Convolution to the incoming data."""
def __init__(self, in_channels, out_channels, kernel_size, stride,
dilatation=1, padding=0, groups=1, bias=True, init_criterion='glorot',
weight_init='quaternion', seed=None, operation='convolution2d', rotation=False, quaternion_format=True, scale=False):
super().__init__()
self.in_channels = in_channels // 4
self.out_channels = out_channels // 4
self.stride = stride
self.padding = padding
self.groups = groups
self.dilatation = dilatation
self.init_criterion = init_criterion
self.weight_init = weight_init
self.seed = seed if seed is not None else np.random.randint(0, 1234)
self.rng = RandomState(self.seed)
self.operation = operation
self.rotation = rotation
self.quaternion_format = quaternion_format
self.winit = {'quaternion': hp_ops.quaternion_init,
'unitary': hp_ops.unitary_init,
'random': hp_ops.random_init}[self.weight_init]
self.scale = scale
(self.kernel_size, self.w_shape) = hp_ops.get_kernel_and_weight_shape(self.operation,
self.in_channels, self.out_channels, kernel_size)
self.r_weight = Parameter(torch.Tensor(*self.w_shape))
self.i_weight = Parameter(torch.Tensor(*self.w_shape))
self.j_weight = Parameter(torch.Tensor(*self.w_shape))
self.k_weight = Parameter(torch.Tensor(*self.w_shape))
if self.scale:
self.scale_param = Parameter(torch.Tensor(self.r_weight.shape))
else:
self.scale_param = None
if self.rotation:
self.zero_kernel = Parameter(torch.zeros(
self.r_weight.shape), requires_grad=False)
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
hp_ops.affect_init_conv(self.r_weight, self.i_weight, self.j_weight, self.k_weight,
self.kernel_size, self.winit, self.rng, self.init_criterion)
if self.scale_param is not None:
torch.nn.init.xavier_uniform_(self.scale_param.data)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, input):
if self.rotation:
return hp_ops.quaternion_conv_rotation(input, self.zero_kernel, self.r_weight, self.i_weight, self.j_weight,
self.k_weight, self.bias, self.stride, self.padding, self.groups, self.dilatation,
self.quaternion_format, self.scale_param)
else:
return hp_ops.quaternion_conv(input, self.r_weight, self.i_weight, self.j_weight,
self.k_weight, self.bias, self.stride, self.padding, self.groups, self.dilatation)
def __repr__(self):
return self.__class__.__name__ + '(' \
+ 'in_channels=' + str(self.in_channels) \
+ ', out_channels=' + str(self.out_channels) \
+ ', bias=' + str(self.bias is not None) \
+ ', kernel_size=' + str(self.kernel_size) \
+ ', stride=' + str(self.stride) \
+ ', padding=' + str(self.padding) \
+ ', init_criterion=' + str(self.init_criterion) \
+ ', weight_init=' + str(self.weight_init) \
+ ', seed=' + str(self.seed) \
+ ', rotation=' + str(self.rotation) \
+ ', q_format=' + str(self.quaternion_format) \
+ ', operation=' + str(self.operation) + ')'
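# Minimal usage sketch (not from the original file; shapes are illustrative):
#   conv = QuaternionConv(in_channels=16, out_channels=32, kernel_size=3,
#                         stride=1, padding=1)
#   y = conv(torch.randn(2, 16, 28, 28))   # expected -> (2, 32, 28, 28)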
class QuaternionLinearAutograd(Module):
r"""Applies a quaternion linear transformation to the incoming data.
A custom Autograd function is call to drastically reduce the VRAM consumption. Nonetheless, computing time
is also slower compared to QuaternionLinear().
"""
def __init__(self, in_features, out_features, bias=True,
init_criterion='glorot', weight_init='quaternion',
seed=None, rotation=False, quaternion_format=True, scale=False):
super().__init__()
self.in_features = in_features//4
self.out_features = out_features//4
self.rotation = rotation
self.quaternion_format = quaternion_format
self.r_weight = Parameter(torch.Tensor(
self.in_features, self.out_features))
self.i_weight = Parameter(torch.Tensor(
self.in_features, self.out_features))
self.j_weight = Parameter(torch.Tensor(
self.in_features, self.out_features))
self.k_weight = Parameter(torch.Tensor(
self.in_features, self.out_features))
self.scale = scale
if self.scale:
self.scale_param = Parameter(torch.Tensor(
self.in_features, self.out_features))
else:
self.scale_param = None
if self.rotation:
self.zero_kernel = Parameter(torch.zeros(
self.r_weight.shape), requires_grad=False)
if bias:
self.bias = Parameter(torch.Tensor(self.out_features*4))
else:
self.register_parameter('bias', None)
self.init_criterion = init_criterion
self.weight_init = weight_init
self.seed = seed if seed is not None else np.random.randint(0, 1234)
self.rng = RandomState(self.seed)
self.reset_parameters()
def reset_parameters(self):
winit = {'quaternion': hp_ops.quaternion_init, 'unitary': hp_ops.unitary_init,
'random': hp_ops.random_init}[self.weight_init]
if self.scale_param is not None:
torch.nn.init.xavier_uniform_(self.scale_param.data)
if self.bias is not None:
self.bias.data.fill_(0)
hp_ops.affect_init(self.r_weight, self.i_weight, self.j_weight, self.k_weight, winit,
self.rng, self.init_criterion)
def forward(self, input):
# See the autograd section for explanation of what happens here.
if self.rotation:
return hp_ops.quaternion_linear_rotation(input, self.zero_kernel, self.r_weight, self.i_weight, self.j_weight, self.k_weight, self.bias, self.quaternion_format, self.scale_param)
else:
return hp_ops.quaternion_linear(input, self.r_weight, self.i_weight, self.j_weight, self.k_weight, self.bias)
def __repr__(self):
return self.__class__.__name__ + '(' \
+ 'in_features=' + str(self.in_features) \
+ ', out_features=' + str(self.out_features) \
+ ', bias=' + str(self.bias is not None) \
+ ', init_criterion=' + str(self.init_criterion) \
+ ', weight_init=' + str(self.weight_init) \
+ ', rotation=' + str(self.rotation) \
+ ', seed=' + str(self.seed) + ')'
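# Minimal usage sketch (not from the original file; shapes are illustrative).
# Feature counts must be multiples of 4 (one quaternion = four real channels):
#   fc = QuaternionLinearAutograd(in_features=64, out_features=128)
#   y = fc(torch.randn(8, 64))   # -> (8, 128)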
class QuaternionLinear(Module):
r"""Applies a quaternion linear transformation to the incoming data."""
def __init__(self, in_features, out_features, bias=True,
init_criterion='he', weight_init='quaternion',
seed=None):
super().__init__()
self.in_features = in_features//4
self.out_features = out_features//4
self.r_weight = Parameter(torch.Tensor(
self.in_features, self.out_features))
self.i_weight = Parameter(torch.Tensor(
self.in_features, self.out_features))
self.j_weight = Parameter(torch.Tensor(
self.in_features, self.out_features))
self.k_weight = Parameter(torch.Tensor(
self.in_features, self.out_features))
if bias:
self.bias = Parameter(torch.Tensor(self.out_features*4))
else:
self.register_parameter('bias', None)
self.init_criterion = init_criterion
self.weight_init = weight_init
self.seed = seed if seed is not None else np.random.randint(0, 1234)
self.rng = RandomState(self.seed)
self.reset_parameters()
def reset_parameters(self):
winit = {'quaternion': hp_ops.quaternion_init,
'unitary': hp_ops.unitary_init}[self.weight_init]
if self.bias is not None:
self.bias.data.fill_(0)
        hp_ops.affect_init(self.r_weight, self.i_weight, self.j_weight, self.k_weight, winit,
                           self.rng, self.init_criterion)
def forward(self, input):
# See the autograd section for explanation of what happens here.
if input.dim() == 3:
T, N, C = input.size()
input = input.view(T * N, C)
output = hp_ops.QuaternionLinearFunction.apply(
input, self.r_weight, self.i_weight, self.j_weight, self.k_weight, self.bias)
output = output.view(T, N, output.size(1))
elif input.dim() == 2:
output = hp_ops.QuaternionLinearFunction.apply(
input, self.r_weight, self.i_weight, self.j_weight, self.k_weight, self.bias)
else:
raise NotImplementedError
return output
def __repr__(self):
return self.__class__.__name__ + '(' \
+ 'in_features=' + str(self.in_features) \
+ ', out_features=' + str(self.out_features) \
+ ', bias=' + str(self.bias is not None) \
+ ', init_criterion=' + str(self.init_criterion) \
+ ', weight_init=' + str(self.weight_init) \
+ ', seed=' + str(self.seed) + ')'
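# Minimal usage sketch (not from the original file; shapes are illustrative).
# QuaternionLinear routes through the custom hp_ops.QuaternionLinearFunction
# and accepts 2D (N, C) or 3D (T, N, C) inputs:
#   fc = QuaternionLinear(in_features=64, out_features=128)
#   y = fc(torch.randn(8, 64))   # -> (8, 128)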