import math |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from numpy.random import RandomState |
|
from torch.nn import Module, init |
|
from torch.nn.parameter import Parameter |
|
|
|
from models import hypercomplex_ops as hp_ops |
|
|
|
|
class PHMLinear(nn.Module):
    r"""Applies a parameterized hypercomplex multiplication (PHM) linear layer.

    The weight matrix is built as a sum of Kronecker products,
    W = \sum_{i=1}^{n} A_i \otimes S_i, so the layer needs roughly 1/n of the
    parameters of a comparable nn.Linear. Both in_features and out_features
    must be divisible by n.
    """

    def __init__(self, n, in_features, out_features, cuda=True):
        super().__init__()
        self.n = n
        self.in_features = in_features
        self.out_features = out_features
        # Keep the flag under a distinct name: assigning to `self.cuda`
        # would shadow the inherited nn.Module.cuda() method.
        self.use_cuda = cuda
|
|
|
self.bias = nn.Parameter(torch.Tensor(out_features)) |
|
|
|
self.A = nn.Parameter( |
|
torch.nn.init.xavier_uniform_(torch.zeros((n, n, n)))) |
|
|
|
self.S = nn.Parameter(torch.nn.init.xavier_uniform_( |
|
torch.zeros((n, self.out_features//n, self.in_features//n)))) |
|
|
|
        # Non-parameter stand-in used only to compute fan-in for the bias
        # initialization; the effective weight is rebuilt from A and S in
        # forward().
        self.weight = torch.zeros((self.out_features, self.in_features))

        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)
|
|
|
|
|
    def kronecker_product1(self, a, b):
        """Batched Kronecker product: for each leading index i, computes
        a[i] kron b[i] over the last two dimensions."""
        siz1 = torch.Size(torch.tensor(
            a.shape[-2:]) * torch.tensor(b.shape[-2:]))
        res = a.unsqueeze(-1).unsqueeze(-3) * b.unsqueeze(-2).unsqueeze(-4)
        siz0 = res.shape[:-4]
        out = res.reshape(siz0 + siz1)
        return out
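
    # Equivalence sketch (illustrative only): kronecker_product1 matches a
    # per-index torch.kron, e.g.
    #
    #     layer = PHMLinear(n=4, in_features=64, out_features=32)
    #     a, b = torch.randn(4, 4, 4), torch.randn(4, 8, 16)
    #     ref = torch.stack([torch.kron(a[i], b[i]) for i in range(4)])
    #     assert torch.allclose(layer.kronecker_product1(a, b), ref)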
|
|
|
    def kronecker_product2(self):
        """Reference implementation via torch.kron; returns the summed weight
        (kronecker_product1 returns the n individual products instead)."""
        H = torch.zeros((self.out_features, self.in_features))
        for i in range(self.n):
            H = H + torch.kron(self.A[i], self.S[i])
        return H
|
|
|
    def forward(self, input):
        # Rebuild the full weight from A and S on every call so gradients
        # flow through the Kronecker factorization.
        self.weight = torch.sum(self.kronecker_product1(self.A, self.S), dim=0)

        input = input.to(dtype=self.weight.dtype, device=self.weight.device)
        return F.linear(input, weight=self.weight, bias=self.bias)
|
|
|
def extra_repr(self) -> str: |
|
return 'in_features={}, out_features={}, bias={}'.format( |
|
self.in_features, self.out_features, self.bias is not None) |
|
|
|
def reset_parameters(self) -> None: |
|
init.kaiming_uniform_(self.A, a=math.sqrt(5)) |
|
init.kaiming_uniform_(self.S, a=math.sqrt(5)) |
|
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
|
bound = 1 / math.sqrt(fan_in) |
|
init.uniform_(self.bias, -bound, bound) |
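
# Usage sketch for PHMLinear (in_features and out_features must be divisible
# by n):
#
#     layer = PHMLinear(n=4, in_features=64, out_features=32)
#     x = torch.randn(8, 64)
#     y = layer(x)  # -> torch.Size([8, 32])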
|
|
class PHConv(Module):
    r"""Applies a parameterized hypercomplex (PHM-style) 2D convolution.

    The convolution kernel is built as W = \sum_{i=1}^{n} A_i \otimes F_i,
    giving roughly a 1/n parameter reduction over a comparable Conv2d. Both
    in_features and out_features must be divisible by n.
    """

    def __init__(self, n, in_features, out_features, kernel_size, padding=0, stride=1, cuda=True):
        super().__init__()
        self.n = n
        self.in_features = in_features
        self.out_features = out_features
        self.padding = padding
        self.stride = stride
        # Renamed from `self.cuda` to avoid shadowing nn.Module.cuda().
        self.use_cuda = cuda
|
|
|
self.bias = nn.Parameter(torch.Tensor(out_features)) |
|
self.A = nn.Parameter( |
|
torch.nn.init.xavier_uniform_(torch.zeros((n, n, n)))) |
|
self.F = nn.Parameter(torch.nn.init.xavier_uniform_( |
|
torch.zeros((n, self.out_features//n, self.in_features//n, kernel_size, kernel_size)))) |
|
        # 2D stand-in used only for the fan-in computation below; forward()
        # replaces it with the full 4D kernel.
        self.weight = torch.zeros((self.out_features, self.in_features))
|
self.kernel_size = kernel_size |
|
|
|
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) |
|
bound = 1 / math.sqrt(fan_in) |
|
init.uniform_(self.bias, -bound, bound) |
|
|
|
    def kronecker_product1(self, A, F):
        """Batched Kronecker product between the n x n matrices A[i] and the
        conv kernels F[i]; the product acts on the channel dimensions while
        the spatial dimensions are carried along. (The argument F shadows the
        torch.nn.functional alias inside this method only.)"""
        siz1 = torch.Size(torch.tensor(
            A.shape[-2:]) * torch.tensor(F.shape[-4:-2]))
        siz2 = torch.Size(torch.tensor(F.shape[-2:]))
        res = A.unsqueeze(-1).unsqueeze(-3).unsqueeze(-1).unsqueeze(-1) * \
            F.unsqueeze(-4).unsqueeze(-6)
        siz0 = res.shape[:1]
        out = res.reshape(siz0 + siz1 + siz2)
        return out
|
|
|
    def kronecker_product2(self):
        """Reference implementation via torch.kron; returns the summed kernel."""
        H = torch.zeros((self.out_features, self.in_features,
                         self.kernel_size, self.kernel_size))
        if self.use_cuda:
            H = H.cuda()
        for i in range(self.n):
            kron_prod = torch.kron(self.A[i], self.F[i]).view(
                self.out_features, self.in_features, self.kernel_size, self.kernel_size)
            H = H + kron_prod
        return H
|
|
|
    def forward(self, input):
        # Rebuild the full kernel from A and F on every call so gradients
        # flow through the Kronecker factorization.
        self.weight = torch.sum(self.kronecker_product1(self.A, self.F), dim=0)

        input = input.to(dtype=self.weight.dtype, device=self.weight.device)
        return F.conv2d(input, weight=self.weight, bias=self.bias,
                        stride=self.stride, padding=self.padding)
|
|
|
def extra_repr(self) -> str: |
|
return 'in_features={}, out_features={}, bias={}'.format( |
|
self.in_features, self.out_features, self.bias is not None) |
|
|
|
def reset_parameters(self) -> None: |
|
init.kaiming_uniform_(self.A, a=math.sqrt(5)) |
|
init.kaiming_uniform_(self.F, a=math.sqrt(5)) |
|
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
|
bound = 1 / math.sqrt(fan_in) |
|
init.uniform_(self.bias, -bound, bound) |
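
# Usage sketch for PHConv (channel counts must be divisible by n):
#
#     conv = PHConv(n=4, in_features=16, out_features=32, kernel_size=3,
#                   padding=1)
#     x = torch.randn(2, 16, 32, 32)
#     y = conv(x)  # -> torch.Size([2, 32, 32, 32])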
|
|
|
|
|
class KroneckerConv(Module):
    r"""Applies a Kronecker-product (hypercomplex) convolution to the incoming data."""
|
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride, |
|
dilatation=1, padding=0, groups=1, bias=True, init_criterion='glorot', |
|
weight_init='quaternion', seed=None, operation='convolution2d', rotation=False, |
|
quaternion_format=True, scale=False, learn_A=False, cuda=True, first_layer=False): |
|
|
|
super().__init__() |
|
|
|
self.in_channels = in_channels // 4 |
|
self.out_channels = out_channels // 4 |
|
self.stride = stride |
|
self.padding = padding |
|
self.groups = groups |
|
self.dilatation = dilatation |
|
self.init_criterion = init_criterion |
|
self.weight_init = weight_init |
|
self.seed = seed if seed is not None else np.random.randint(0, 1234) |
|
self.rng = RandomState(self.seed) |
|
self.operation = operation |
|
self.rotation = rotation |
|
self.quaternion_format = quaternion_format |
|
self.winit = {'quaternion': hp_ops.quaternion_init, |
|
'unitary': hp_ops.unitary_init, |
|
'random': hp_ops.random_init}[self.weight_init] |
|
self.scale = scale |
|
self.learn_A = learn_A |
|
        # Renamed from `self.cuda` to avoid shadowing nn.Module.cuda().
        self.use_cuda = cuda
|
self.first_layer = first_layer |
|
|
|
(self.kernel_size, self.w_shape) = hp_ops.get_kernel_and_weight_shape(self.operation, |
|
self.in_channels, self.out_channels, kernel_size) |
|
|
|
self.r_weight = Parameter(torch.Tensor(*self.w_shape)) |
|
self.i_weight = Parameter(torch.Tensor(*self.w_shape)) |
|
self.j_weight = Parameter(torch.Tensor(*self.w_shape)) |
|
self.k_weight = Parameter(torch.Tensor(*self.w_shape)) |
|
|
|
if self.scale: |
|
self.scale_param = Parameter(torch.Tensor(self.r_weight.shape)) |
|
else: |
|
self.scale_param = None |
|
|
|
if self.rotation: |
|
self.zero_kernel = Parameter(torch.zeros( |
|
self.r_weight.shape), requires_grad=False) |
|
if bias: |
|
self.bias = Parameter(torch.Tensor(out_channels)) |
|
else: |
|
self.register_parameter('bias', None) |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
hp_ops.affect_init_conv(self.r_weight, self.i_weight, self.j_weight, self.k_weight, |
|
self.kernel_size, self.winit, self.rng, self.init_criterion) |
|
if self.scale_param is not None: |
|
torch.nn.init.xavier_uniform_(self.scale_param.data) |
|
if self.bias is not None: |
|
self.bias.data.zero_() |
|
|
|
    def forward(self, input):
        if self.rotation:
            # The rotation variant is not implemented for this layer; raise
            # instead of silently returning None.
            raise NotImplementedError(
                'rotation is not supported by KroneckerConv')
        else:
            return hp_ops.kronecker_conv(input, self.r_weight, self.i_weight, self.j_weight,
                                         self.k_weight, self.bias, self.stride, self.padding, self.groups, self.dilatation, self.learn_A, self.use_cuda, self.first_layer)
|
|
|
def __repr__(self): |
|
return self.__class__.__name__ + '(' \ |
|
+ 'in_channels=' + str(self.in_channels) \ |
|
+ ', out_channels=' + str(self.out_channels) \ |
|
+ ', bias=' + str(self.bias is not None) \ |
|
+ ', kernel_size=' + str(self.kernel_size) \ |
|
+ ', stride=' + str(self.stride) \ |
|
+ ', padding=' + str(self.padding) \ |
|
+ ', init_criterion=' + str(self.init_criterion) \ |
|
+ ', weight_init=' + str(self.weight_init) \ |
|
+ ', seed=' + str(self.seed) \ |
|
+ ', rotation=' + str(self.rotation) \ |
|
+ ', q_format=' + str(self.quaternion_format) \ |
|
+ ', operation=' + str(self.operation) + ')' |
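
# Usage sketch for KroneckerConv (channel counts must be divisible by 4;
# assumes models.hypercomplex_ops is importable):
#
#     conv = KroneckerConv(in_channels=16, out_channels=32, kernel_size=3,
#                          stride=1, padding=1)
#     x = torch.randn(2, 16, 32, 32)
#     y = conv(x)  # expected -> torch.Size([2, 32, 32, 32])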
|
|
|
|
|
class QuaternionTransposeConv(Module): |
|
r"""Applies a Quaternion Transposed Convolution (or Deconvolution) to the incoming data.""" |
|
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride, |
|
dilatation=1, padding=0, output_padding=0, groups=1, bias=True, init_criterion='he', |
|
weight_init='quaternion', seed=None, operation='convolution2d', rotation=False, |
|
quaternion_format=False): |
|
|
|
super().__init__() |
|
|
|
self.in_channels = in_channels // 4 |
|
self.out_channels = out_channels // 4 |
|
self.stride = stride |
|
self.padding = padding |
|
self.output_padding = output_padding |
|
self.groups = groups |
|
self.dilatation = dilatation |
|
self.init_criterion = init_criterion |
|
self.weight_init = weight_init |
|
self.seed = seed if seed is not None else np.random.randint(0, 1234) |
|
self.rng = RandomState(self.seed) |
|
self.operation = operation |
|
self.rotation = rotation |
|
self.quaternion_format = quaternion_format |
|
self.winit = {'quaternion': hp_ops.quaternion_init, |
|
'unitary': hp_ops.unitary_init, |
|
'random': hp_ops.random_init}[self.weight_init] |
|
|
|
(self.kernel_size, self.w_shape) = hp_ops.get_kernel_and_weight_shape(self.operation, |
|
self.out_channels, self.in_channels, kernel_size) |
|
|
|
self.r_weight = Parameter(torch.Tensor(*self.w_shape)) |
|
self.i_weight = Parameter(torch.Tensor(*self.w_shape)) |
|
self.j_weight = Parameter(torch.Tensor(*self.w_shape)) |
|
self.k_weight = Parameter(torch.Tensor(*self.w_shape)) |
|
|
|
if bias: |
|
self.bias = Parameter(torch.Tensor(out_channels)) |
|
else: |
|
self.register_parameter('bias', None) |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
hp_ops.affect_init_conv(self.r_weight, self.i_weight, self.j_weight, self.k_weight, |
|
self.kernel_size, self.winit, self.rng, self.init_criterion) |
|
if self.bias is not None: |
|
self.bias.data.zero_() |
|
|
|
def forward(self, input): |
|
if self.rotation: |
|
return hp_ops.quaternion_tranpose_conv_rotation(input, self.r_weight, self.i_weight, |
|
self.j_weight, self.k_weight, self.bias, self.stride, self.padding, |
|
self.output_padding, self.groups, self.dilatation, self.quaternion_format) |
|
else: |
|
return hp_ops.quaternion_transpose_conv(input, self.r_weight, self.i_weight, self.j_weight, |
|
self.k_weight, self.bias, self.stride, self.padding, self.output_padding, |
|
self.groups, self.dilatation) |
|
|
|
def __repr__(self): |
|
return self.__class__.__name__ + '(' \ |
|
+ 'in_channels=' + str(self.in_channels) \ |
|
+ ', out_channels=' + str(self.out_channels) \ |
|
+ ', bias=' + str(self.bias is not None) \ |
|
+ ', kernel_size=' + str(self.kernel_size) \ |
|
+ ', stride=' + str(self.stride) \ |
|
+ ', padding=' + str(self.padding) \ |
|
            + ', dilatation=' + str(self.dilatation) \
|
+ ', init_criterion=' + str(self.init_criterion) \ |
|
+ ', weight_init=' + str(self.weight_init) \ |
|
+ ', seed=' + str(self.seed) \ |
|
+ ', operation=' + str(self.operation) + ')' |
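
# Usage sketch for QuaternionTransposeConv (channel counts must be divisible
# by 4):
#
#     deconv = QuaternionTransposeConv(in_channels=32, out_channels=16,
#                                      kernel_size=2, stride=2)
#     x = torch.randn(2, 32, 16, 16)
#     y = deconv(x)  # spatial upsampling, e.g. -> torch.Size([2, 16, 32, 32])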
|
|
|
|
|
class QuaternionConv(Module): |
|
r"""Applies a Quaternion Convolution to the incoming data.""" |
|
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride, |
|
dilatation=1, padding=0, groups=1, bias=True, init_criterion='glorot', |
|
weight_init='quaternion', seed=None, operation='convolution2d', rotation=False, quaternion_format=True, scale=False): |
|
|
|
super().__init__() |
|
|
|
self.in_channels = in_channels // 4 |
|
self.out_channels = out_channels // 4 |
|
self.stride = stride |
|
self.padding = padding |
|
self.groups = groups |
|
self.dilatation = dilatation |
|
self.init_criterion = init_criterion |
|
self.weight_init = weight_init |
|
self.seed = seed if seed is not None else np.random.randint(0, 1234) |
|
self.rng = RandomState(self.seed) |
|
self.operation = operation |
|
self.rotation = rotation |
|
self.quaternion_format = quaternion_format |
|
self.winit = {'quaternion': hp_ops.quaternion_init, |
|
'unitary': hp_ops.unitary_init, |
|
'random': hp_ops.random_init}[self.weight_init] |
|
self.scale = scale |
|
|
|
(self.kernel_size, self.w_shape) = hp_ops.get_kernel_and_weight_shape(self.operation, |
|
self.in_channels, self.out_channels, kernel_size) |
|
|
|
self.r_weight = Parameter(torch.Tensor(*self.w_shape)) |
|
self.i_weight = Parameter(torch.Tensor(*self.w_shape)) |
|
self.j_weight = Parameter(torch.Tensor(*self.w_shape)) |
|
self.k_weight = Parameter(torch.Tensor(*self.w_shape)) |
|
|
|
if self.scale: |
|
self.scale_param = Parameter(torch.Tensor(self.r_weight.shape)) |
|
else: |
|
self.scale_param = None |
|
|
|
if self.rotation: |
|
self.zero_kernel = Parameter(torch.zeros( |
|
self.r_weight.shape), requires_grad=False) |
|
if bias: |
|
self.bias = Parameter(torch.Tensor(out_channels)) |
|
else: |
|
self.register_parameter('bias', None) |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
hp_ops.affect_init_conv(self.r_weight, self.i_weight, self.j_weight, self.k_weight, |
|
self.kernel_size, self.winit, self.rng, self.init_criterion) |
|
if self.scale_param is not None: |
|
torch.nn.init.xavier_uniform_(self.scale_param.data) |
|
if self.bias is not None: |
|
self.bias.data.zero_() |
|
|
|
def forward(self, input): |
|
if self.rotation: |
|
return hp_ops.quaternion_conv_rotation(input, self.zero_kernel, self.r_weight, self.i_weight, self.j_weight, |
|
self.k_weight, self.bias, self.stride, self.padding, self.groups, self.dilatation, |
|
self.quaternion_format, self.scale_param) |
|
else: |
|
return hp_ops.quaternion_conv(input, self.r_weight, self.i_weight, self.j_weight, |
|
self.k_weight, self.bias, self.stride, self.padding, self.groups, self.dilatation) |
|
|
|
def __repr__(self): |
|
return self.__class__.__name__ + '(' \ |
|
+ 'in_channels=' + str(self.in_channels) \ |
|
+ ', out_channels=' + str(self.out_channels) \ |
|
+ ', bias=' + str(self.bias is not None) \ |
|
+ ', kernel_size=' + str(self.kernel_size) \ |
|
+ ', stride=' + str(self.stride) \ |
|
+ ', padding=' + str(self.padding) \ |
|
+ ', init_criterion=' + str(self.init_criterion) \ |
|
+ ', weight_init=' + str(self.weight_init) \ |
|
+ ', seed=' + str(self.seed) \ |
|
+ ', rotation=' + str(self.rotation) \ |
|
+ ', q_format=' + str(self.quaternion_format) \ |
|
+ ', operation=' + str(self.operation) + ')' |
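
# Usage sketch for QuaternionConv (channel counts must be divisible by 4):
#
#     conv = QuaternionConv(in_channels=16, out_channels=32, kernel_size=3,
#                           stride=1, padding=1)
#     x = torch.randn(2, 16, 32, 32)
#     y = conv(x)  # -> torch.Size([2, 32, 32, 32])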
|
|
|
|
|
class QuaternionLinearAutograd(Module):
    r"""Applies a quaternion linear transformation to the incoming data.

    This variant relies on PyTorch's standard autograd. Compared to
    QuaternionLinear(), which uses a custom autograd function to reduce VRAM
    consumption, it is faster but uses more memory.
    """
|
|
|
def __init__(self, in_features, out_features, bias=True, |
|
init_criterion='glorot', weight_init='quaternion', |
|
seed=None, rotation=False, quaternion_format=True, scale=False): |
|
|
|
super().__init__() |
|
self.in_features = in_features//4 |
|
self.out_features = out_features//4 |
|
self.rotation = rotation |
|
self.quaternion_format = quaternion_format |
|
self.r_weight = Parameter(torch.Tensor( |
|
self.in_features, self.out_features)) |
|
self.i_weight = Parameter(torch.Tensor( |
|
self.in_features, self.out_features)) |
|
self.j_weight = Parameter(torch.Tensor( |
|
self.in_features, self.out_features)) |
|
self.k_weight = Parameter(torch.Tensor( |
|
self.in_features, self.out_features)) |
|
self.scale = scale |
|
|
|
if self.scale: |
|
self.scale_param = Parameter(torch.Tensor( |
|
self.in_features, self.out_features)) |
|
else: |
|
self.scale_param = None |
|
|
|
if self.rotation: |
|
self.zero_kernel = Parameter(torch.zeros( |
|
self.r_weight.shape), requires_grad=False) |
|
|
|
if bias: |
|
self.bias = Parameter(torch.Tensor(self.out_features*4)) |
|
else: |
|
self.register_parameter('bias', None) |
|
self.init_criterion = init_criterion |
|
self.weight_init = weight_init |
|
self.seed = seed if seed is not None else np.random.randint(0, 1234) |
|
self.rng = RandomState(self.seed) |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
winit = {'quaternion': hp_ops.quaternion_init, 'unitary': hp_ops.unitary_init, |
|
'random': hp_ops.random_init}[self.weight_init] |
|
if self.scale_param is not None: |
|
torch.nn.init.xavier_uniform_(self.scale_param.data) |
|
if self.bias is not None: |
|
self.bias.data.fill_(0) |
|
hp_ops.affect_init(self.r_weight, self.i_weight, self.j_weight, self.k_weight, winit, |
|
self.rng, self.init_criterion) |
|
|
|
def forward(self, input): |
|
|
|
if self.rotation: |
|
return hp_ops.quaternion_linear_rotation(input, self.zero_kernel, self.r_weight, self.i_weight, self.j_weight, self.k_weight, self.bias, self.quaternion_format, self.scale_param) |
|
else: |
|
return hp_ops.quaternion_linear(input, self.r_weight, self.i_weight, self.j_weight, self.k_weight, self.bias) |
|
|
|
def __repr__(self): |
|
return self.__class__.__name__ + '(' \ |
|
+ 'in_features=' + str(self.in_features) \ |
|
+ ', out_features=' + str(self.out_features) \ |
|
+ ', bias=' + str(self.bias is not None) \ |
|
+ ', init_criterion=' + str(self.init_criterion) \ |
|
+ ', weight_init=' + str(self.weight_init) \ |
|
+ ', rotation=' + str(self.rotation) \ |
|
+ ', seed=' + str(self.seed) + ')' |
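
# Usage sketch for QuaternionLinearAutograd (feature counts must be divisible
# by 4):
#
#     fc = QuaternionLinearAutograd(in_features=64, out_features=128)
#     x = torch.randn(8, 64)
#     y = fc(x)  # -> torch.Size([8, 128])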
|
|
|
|
|
class QuaternionLinear(Module):
    r"""Applies a quaternion linear transformation to the incoming data.

    Uses the custom QuaternionLinearFunction, which drastically reduces VRAM
    consumption at the cost of slower computation than
    QuaternionLinearAutograd().
    """
|
|
|
def __init__(self, in_features, out_features, bias=True, |
|
init_criterion='he', weight_init='quaternion', |
|
seed=None): |
|
|
|
super().__init__() |
|
self.in_features = in_features//4 |
|
self.out_features = out_features//4 |
|
self.r_weight = Parameter(torch.Tensor( |
|
self.in_features, self.out_features)) |
|
self.i_weight = Parameter(torch.Tensor( |
|
self.in_features, self.out_features)) |
|
self.j_weight = Parameter(torch.Tensor( |
|
self.in_features, self.out_features)) |
|
self.k_weight = Parameter(torch.Tensor( |
|
self.in_features, self.out_features)) |
|
|
|
if bias: |
|
self.bias = Parameter(torch.Tensor(self.out_features*4)) |
|
else: |
|
self.register_parameter('bias', None) |
|
|
|
self.init_criterion = init_criterion |
|
self.weight_init = weight_init |
|
self.seed = seed if seed is not None else np.random.randint(0, 1234) |
|
self.rng = RandomState(self.seed) |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
winit = {'quaternion': hp_ops.quaternion_init, |
|
'unitary': hp_ops.unitary_init}[self.weight_init] |
|
if self.bias is not None: |
|
self.bias.data.fill_(0) |
|
        hp_ops.affect_init(self.r_weight, self.i_weight, self.j_weight, self.k_weight, winit,
                           self.rng, self.init_criterion)
|
|
|
def forward(self, input): |
|
|
|
if input.dim() == 3: |
|
T, N, C = input.size() |
|
input = input.view(T * N, C) |
|
output = hp_ops.QuaternionLinearFunction.apply( |
|
input, self.r_weight, self.i_weight, self.j_weight, self.k_weight, self.bias) |
|
output = output.view(T, N, output.size(1)) |
|
elif input.dim() == 2: |
|
output = hp_ops.QuaternionLinearFunction.apply( |
|
input, self.r_weight, self.i_weight, self.j_weight, self.k_weight, self.bias) |
|
else: |
|
raise NotImplementedError |
|
|
|
return output |
|
|
|
def __repr__(self): |
|
return self.__class__.__name__ + '(' \ |
|
+ 'in_features=' + str(self.in_features) \ |
|
+ ', out_features=' + str(self.out_features) \ |
|
+ ', bias=' + str(self.bias is not None) \ |
|
+ ', init_criterion=' + str(self.init_criterion) \ |
|
+ ', weight_init=' + str(self.weight_init) \ |
|
+ ', seed=' + str(self.seed) + ')' |
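
# Usage sketch for QuaternionLinear (feature counts must be divisible by 4;
# 3D input of shape (T, N, C) is also accepted):
#
#     fc = QuaternionLinear(in_features=64, out_features=128)
#     x = torch.randn(8, 64)
#     y = fc(x)  # -> torch.Size([8, 128])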
|
|