# Hosted-file page header (not part of the module):
# uploaded by haritsahm, commit 861e32a "Add model files", 34.2 kB.
##########################################################
# pytorch-qnn v1.0
# Titouan Parcollet
# LIA, Université d'Avignon et des Pays du Vaucluse
# ORKIS, Aix-en-provence
# October 2018
##########################################################
import numpy as np
import torch
import torch.nn.functional as F
from numpy.random import RandomState
from scipy.stats import chi
from torch.autograd import Variable
def q_normalize(input, channel=1):
    """Scale every quaternion of ``input`` to (approximately) unit norm.

    The four components are extracted with ``get_r`` / ``get_i`` /
    ``get_j`` / ``get_k``, divided by the element-wise quaternion norm and
    concatenated back together.

    Args:
        input: Tensor accepted by ``check_input``.
        channel: Dimension along which the normalised components are
            concatenated (default 1).

    Returns:
        Tensor holding the normalised ``[r, i, j, k]`` components.
    """
    components = (get_r(input), get_i(input), get_j(input), get_k(input))
    real, imag_i, imag_j, imag_k = components
    # The 0.0001 under the square root keeps the division finite for
    # all-zero quaternions.
    magnitude = torch.sqrt(
        real * real + imag_i * imag_i + imag_j * imag_j + imag_k * imag_k + 0.0001
    )
    return torch.cat([part / magnitude for part in components], dim=channel)
def check_input(input):
    """Validate that ``input`` can be interpreted as a quaternion tensor.

    The tensor must have 2 to 5 dimensions, and the size of its feature
    axis (the last axis for 2-D/3-D tensors, axis 1 for 4-D/5-D tensors)
    must be a multiple of 4.

    Raises:
        RuntimeError: If the rank or the feature size is not acceptable.
    """
    rank = input.dim()
    if rank < 2 or rank > 5:
        raise RuntimeError(
            'Quaternion linear accepts only input of dimension 2 or 3. '
            'Quaternion conv accepts up to 5 dim  input.dim = ' + str(rank)
        )

    feature_axis = -1 if rank < 4 else 1
    nb_hidden = input.size()[feature_axis]
    if nb_hidden % 4:
        raise RuntimeError(
            'Quaternion Tensors must be divisible by 4. input.size()[1] = '
            + str(nb_hidden)
        )
#
# Getters
#
def get_r(input):
    """Return the real part: the first quarter of the feature axis.

    The feature axis is the last axis for 2-D and 3-D tensors and axis 1
    for 4-D and 5-D tensors. ``input`` is validated with ``check_input``
    first. The result is a ``narrow`` of ``input``.
    """
    check_input(input)
    feature_axis = 1 if input.dim() >= 4 else input.dim() - 1
    quarter = input.size()[feature_axis] // 4
    return input.narrow(feature_axis, 0, quarter)
def get_i(input):
    """Return the ``i`` part: the second quarter of the feature axis.

    The feature axis is the last axis for 2-D and 3-D tensors and axis 1
    for 4-D and 5-D tensors. The result is a ``narrow`` of ``input``.

    Raises:
        RuntimeError: Propagated from ``check_input`` when ``input`` is not
            a valid quaternion tensor.
    """
    # Validate like get_r / get_j / get_k do; without this an unsupported
    # rank fell through every branch and silently returned None.
    check_input(input)
    feature_axis = 1 if input.dim() >= 4 else input.dim() - 1
    quarter = input.size()[feature_axis] // 4
    return input.narrow(feature_axis, quarter, quarter)
def get_j(input):
    """Return the ``j`` part: the third quarter of the feature axis.

    The feature axis is the last axis for 2-D and 3-D tensors and axis 1
    for 4-D and 5-D tensors. ``input`` is validated with ``check_input``
    first. The result is a ``narrow`` of ``input``.
    """
    check_input(input)
    feature_axis = 1 if input.dim() >= 4 else input.dim() - 1
    nb_hidden = input.size()[feature_axis]
    return input.narrow(feature_axis, nb_hidden // 2, nb_hidden // 4)
def get_k(input):
    """Return the ``k`` part: the last quarter of the feature axis.

    The feature axis is the last axis for 2-D and 3-D tensors and axis 1
    for 4-D and 5-D tensors. ``input`` is validated with ``check_input``
    first. The result is a ``narrow`` of ``input``.
    """
    check_input(input)
    feature_axis = 1 if input.dim() >= 4 else input.dim() - 1
    nb_hidden = input.size()[feature_axis]
    quarter = nb_hidden // 4
    return input.narrow(feature_axis, nb_hidden - quarter, quarter)
def get_modulus(input, vector_form=False):
    """Return the modulus of the quaternions stored in ``input``.

    Args:
        input: Tensor accepted by ``check_input``.
        vector_form: When true, return the element-wise modulus
            ``sqrt(r^2 + i^2 + j^2 + k^2)``. Otherwise the squared moduli
            are summed over dimension 0 before the square root is taken.

    Returns:
        Tensor of moduli (one dimension fewer than the component tensors
        when ``vector_form`` is false).
    """
    check_input(input)
    real, imag_i, imag_j, imag_k = (
        get_r(input), get_i(input), get_j(input), get_k(input))
    squared = real * real + imag_i * imag_i + imag_j * imag_j + imag_k * imag_k
    if not vector_form:
        squared = squared.sum(dim=0)
    return torch.sqrt(squared)
def get_normalized(input, eps=0.0001):
    """Divide ``input`` by its quaternion modulus.

    The modulus comes from ``get_modulus(input)`` (non-vector form, i.e.
    reduced over dimension 0). It is tiled four times along the feature
    axis so that it lines up with the ``[r, i, j, k]`` layout, then
    broadcast against ``input``.

    Args:
        input: Tensor accepted by ``check_input`` (2-D to 5-D).
        eps: Added to the modulus to avoid division by zero.

    Returns:
        Tensor with the same shape as ``input``.
    """
    check_input(input)
    data_modulus = get_modulus(input)
    if input.dim() == 2:
        data_modulus_repeated = data_modulus.repeat(1, 4)
    elif input.dim() == 3:
        data_modulus_repeated = data_modulus.repeat(1, 1, 4)
    else:
        # 4-D / 5-D inputs keep their quaternion components on axis 1, which
        # becomes axis 0 of the modulus once dimension 0 has been summed
        # away. Previously this case left the variable unbound and crashed.
        data_modulus_repeated = data_modulus.repeat(
            4, *([1] * (input.dim() - 2)))
    return input / (data_modulus_repeated.expand_as(input) + eps)
def quaternion_exp(input):
    """Return the quaternion exponential of ``input``.

    For a quaternion ``r + v`` the result is
    ``exp(r) * (cos|v| + (v / |v|) * sin|v|)``, where ``|v|`` is the norm of
    the imaginary part plus 0.0001 (to keep the division finite).

    The four resulting components are concatenated along dimension 1.
    """
    real = get_r(input)
    imag_i = get_i(input)
    imag_j = get_j(input)
    imag_k = get_k(input)

    norm_v = torch.sqrt(imag_i * imag_i + imag_j * imag_j + imag_k * imag_k) + 0.0001
    exp = torch.exp(real)

    out_r = torch.cos(norm_v)
    out_i = (imag_i / norm_v) * torch.sin(norm_v)
    out_j = (imag_j / norm_v) * torch.sin(norm_v)
    out_k = (imag_k / norm_v) * torch.sin(norm_v)
    return torch.cat([exp * out_r, exp * out_i, exp * out_j, exp * out_k], dim=1)
def kronecker_conv(input, r_weight, i_weight, j_weight, k_weight, bias, stride,
                   padding, groups, dilatation, learn_A, cuda, first_layer=False):
    """Apply a quaternion convolution built from Kronecker products.

    The Hamilton-product kernel is the sum of four Kronecker products
    between fixed 4x4 sign matrices and the four weight components.

    Args:
        input: 3-D, 4-D or 5-D tensor; selects conv1d / conv2d / conv3d.
        r_weight, i_weight, j_weight, k_weight: Component kernels, all of
            the same shape ``(out, in, *kernel_size)``.
        bias, stride, padding, groups, dilatation: Forwarded to the
            convolution function.
        learn_A: Unused; kept for interface compatibility.
        cuda: Kept for interface compatibility. The basis matrices are
            created on the device of ``r_weight``, so the flag no longer
            needs to be consulted.
        first_layer: When true the real-part basis is all zeros instead of
            the identity, so ``r_weight`` does not contribute.

    Returns:
        Result of the convolution.

    Raises:
        Exception: If ``input`` is not 3-D, 4-D or 5-D.
    """
    conv_by_rank = {3: F.conv1d, 4: F.conv2d, 5: F.conv3d}
    convfunc = conv_by_rank.get(input.dim())
    if convfunc is None:
        raise Exception('The convolutional input is either 3, 4 or 5 dimensions.'
                        ' input.dim = ' + str(input.dim()))

    # torch.kron aligns operands from the trailing dimension, so the 4x4
    # basis needs exactly one singleton axis per kernel axis. The previous
    # fixed (4, 4, 1, 1) view was only right for 2-D convolutions.
    block_shape = (4, 4) + (1,) * (r_weight.dim() - 2)
    like_weight = {'dtype': r_weight.dtype, 'device': r_weight.device}

    def _basis(rows):
        return torch.tensor(rows, **like_weight).view(block_shape)

    if first_layer:
        mat1 = torch.zeros((4, 4), **like_weight).view(block_shape)
    else:
        mat1 = torch.eye(4, **like_weight).view(block_shape)

    # The four matrices that, summed up, encode the Hamilton product rule.
    mat2 = _basis([[0, -1, 0, 0],
                   [1, 0, 0, 0],
                   [0, 0, 0, -1],
                   [0, 0, 1, 0]])
    mat3 = _basis([[0, 0, -1, 0],
                   [0, 0, 0, 1],
                   [1, 0, 0, 0],
                   [0, -1, 0, 0]])
    mat4 = _basis([[0, 0, 0, -1],
                   [0, 0, -1, 0],
                   [0, 1, 0, 0],
                   [1, 0, 0, 0]])

    cat_kernels_4_quaternion = (torch.kron(mat1, r_weight)
                                + torch.kron(mat2, i_weight)
                                + torch.kron(mat3, j_weight)
                                + torch.kron(mat4, k_weight))

    return convfunc(input, cat_kernels_4_quaternion, bias, stride, padding,
                    dilatation, groups)
def quaternion_conv(input, r_weight, i_weight, j_weight, k_weight, bias, stride,
                    padding, groups, dilatation):
    """Apply a quaternion convolution to ``input``.

    The four component kernels are assembled into one real-valued kernel
    laid out as the Hamilton product matrix (blocks concatenated along
    dim 1 within a row, rows concatenated along dim 0). The kernel is then
    passed to conv1d / conv2d / conv3d according to the rank of ``input``.

    Raises:
        Exception: If ``input`` is not 3-D, 4-D or 5-D.
    """
    block_rows = (
        (r_weight, -i_weight, -j_weight, -k_weight),
        (i_weight, r_weight, -k_weight, j_weight),
        (j_weight, k_weight, r_weight, -i_weight),
        (k_weight, -j_weight, i_weight, r_weight),
    )
    hamilton_kernel = torch.cat(
        [torch.cat(blocks, dim=1) for blocks in block_rows], dim=0)

    conv_by_rank = {3: F.conv1d, 4: F.conv2d, 5: F.conv3d}
    rank = input.dim()
    if rank not in conv_by_rank:
        raise Exception('The convolutional input is either 3, 4 or 5 dimensions.'
                        ' input.dim = ' + str(rank))
    return conv_by_rank[rank](input, hamilton_kernel, bias, stride, padding,
                              dilatation, groups)
def quaternion_transpose_conv(input, r_weight, i_weight, j_weight, k_weight, bias, stride,
                              padding, output_padding, groups, dilatation):
    """Apply a quaternion transposed convolution to ``input``.

    The four component kernels are assembled into one real-valued kernel
    laid out as the Hamilton product matrix (blocks concatenated along
    dim 1 within a row, rows concatenated along dim 0). The kernel is then
    passed to conv_transpose1d / 2d / 3d according to the rank of ``input``.

    Raises:
        Exception: If ``input`` is not 3-D, 4-D or 5-D.
    """
    block_rows = (
        (r_weight, -i_weight, -j_weight, -k_weight),
        (i_weight, r_weight, -k_weight, j_weight),
        (j_weight, k_weight, r_weight, -i_weight),
        (k_weight, -j_weight, i_weight, r_weight),
    )
    hamilton_kernel = torch.cat(
        [torch.cat(blocks, dim=1) for blocks in block_rows], dim=0)

    conv_by_rank = {
        3: F.conv_transpose1d,
        4: F.conv_transpose2d,
        5: F.conv_transpose3d,
    }
    rank = input.dim()
    if rank not in conv_by_rank:
        raise Exception('The convolutional input is either 3, 4 or 5 dimensions.'
                        ' input.dim = ' + str(rank))
    return conv_by_rank[rank](input, hamilton_kernel, bias, stride, padding,
                              output_padding, groups, dilatation)
def quaternion_conv_rotation(input, zero_kernel, r_weight, i_weight, j_weight, k_weight, bias, stride,
                             padding, groups, dilatation, quaternion_format, scale=None):
    """Apply a quaternion rotation and convolution to ``input``.

    The rotation W*x*W^t can be replaced by R*x following:
    https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation

    Works for unitary and non unitary weights (the weights are normalised
    internally). The channel size of the input must be a multiple of 3 if
    ``quaternion_format`` is False and of 4 if it is True.

    Args:
        input: 3-D, 4-D or 5-D tensor; selects conv1d / conv2d / conv3d.
        zero_kernel: Zero tensor shaped like one weight component; only
            used when ``quaternion_format`` is True, where it fills the
            real-part row and column of the kernel.
        r_weight, i_weight, j_weight, k_weight: Quaternion weight
            components, all of the same shape.
        bias, stride, padding, groups, dilatation: Forwarded to the
            convolution function.
        quaternion_format: Build a 4x4-block kernel (True) or a 3x3-block
            kernel (False).
        scale: Optional factor multiplied into every rotation entry.

    Returns:
        Result of the convolution.

    Raises:
        Exception: If ``input`` is not 3-D, 4-D or 5-D.
    """
    # 0.0001 keeps the normalisation finite for all-zero quaternions.
    norm = torch.sqrt(r_weight * r_weight + i_weight * i_weight
                      + j_weight * j_weight + k_weight * k_weight + 0.0001)
    r_n_weight = r_weight / norm
    i_n_weight = i_weight / norm
    j_n_weight = j_weight / norm
    k_n_weight = k_weight / norm

    norm_factor = 2.0
    square_i = norm_factor * (i_n_weight * i_n_weight)
    square_j = norm_factor * (j_n_weight * j_n_weight)
    square_k = norm_factor * (k_n_weight * k_n_weight)

    ri = norm_factor * r_n_weight * i_n_weight
    rj = norm_factor * r_n_weight * j_n_weight
    rk = norm_factor * r_n_weight * k_n_weight

    ij = norm_factor * i_n_weight * j_n_weight
    ik = norm_factor * i_n_weight * k_n_weight
    jk = norm_factor * j_n_weight * k_n_weight

    # Entries of the 3x3 rotation matrix, one tuple per row.
    rotation_rows = (
        (1.0 - (square_j + square_k), ij - rk, ik + rj),
        (ij + rk, 1.0 - (square_i + square_k), jk - ri),
        (ik - rj, jk + ri, 1.0 - (square_i + square_j)),
    )
    if scale is not None:
        rotation_rows = tuple(
            tuple(scale * entry for entry in row) for row in rotation_rows)

    if quaternion_format:
        # Pad with a zero real part: the first row and column of blocks
        # are zero.
        zero_row = torch.cat([zero_kernel] * 4, dim=1)
        rot_rows = [torch.cat((zero_kernel,) + row, dim=1)
                    for row in rotation_rows]
        global_rot_kernel = torch.cat([zero_row] + rot_rows, dim=0)
    else:
        # Blocks of one row go side by side on the input-channel axis
        # (dim 1), exactly as in the quaternion_format branch. They used to
        # be stacked on dim 0, producing a (9*out, in, ...) kernel that
        # cannot consume an input whose channel count is a multiple of 3.
        rot_rows = [torch.cat(row, dim=1) for row in rotation_rows]
        global_rot_kernel = torch.cat(rot_rows, dim=0)

    conv_by_rank = {3: F.conv1d, 4: F.conv2d, 5: F.conv3d}
    convfunc = conv_by_rank.get(input.dim())
    if convfunc is None:
        raise Exception('The convolutional input is either 3, 4 or 5 dimensions.'
                        ' input.dim = ' + str(input.dim()))

    return convfunc(input, global_rot_kernel, bias, stride, padding, dilatation, groups)
def quaternion_transpose_conv_rotation(
        input, zero_kernel, r_weight, i_weight, j_weight, k_weight, bias, stride,
        padding, output_padding, groups, dilatation, quaternion_format):
    """Apply a quaternion rotation and transposed convolution to ``input``.

    The rotation W*x*W^t can be replaced by R*x following:
    https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation

    Works for unitary and non unitary weights (the weights are normalised
    internally). The channel size of the input must be a multiple of 3 if
    ``quaternion_format`` is False and of 4 if it is True.

    Args:
        input: 3-D, 4-D or 5-D tensor; selects conv_transpose1d / 2d / 3d.
        zero_kernel: Zero tensor shaped like one weight component; only
            used when ``quaternion_format`` is True.
        r_weight, i_weight, j_weight, k_weight: Quaternion weight
            components, all of the same shape.
        bias, stride, padding, output_padding, groups, dilatation:
            Forwarded to the transposed convolution function.
        quaternion_format: Build a 4x4-block kernel (True) or a 3x3-block
            kernel (False).

    Returns:
        Result of the transposed convolution.

    Raises:
        Exception: If ``input`` is not 3-D, 4-D or 5-D.
    """
    # 0.0001 keeps the normalisation finite for all-zero quaternions.
    norm = torch.sqrt(r_weight * r_weight + i_weight * i_weight
                      + j_weight * j_weight + k_weight * k_weight + 0.0001)
    r_weight = r_weight / norm
    i_weight = i_weight / norm
    j_weight = j_weight / norm
    k_weight = k_weight / norm

    norm_factor = 2.0
    square_i = norm_factor * (i_weight * i_weight)
    square_j = norm_factor * (j_weight * j_weight)
    square_k = norm_factor * (k_weight * k_weight)

    ri = norm_factor * r_weight * i_weight
    rj = norm_factor * r_weight * j_weight
    rk = norm_factor * r_weight * k_weight

    ij = norm_factor * i_weight * j_weight
    ik = norm_factor * i_weight * k_weight
    jk = norm_factor * j_weight * k_weight

    if quaternion_format:
        rot_kernel_1 = torch.cat(
            [zero_kernel, 1.0 - (square_j + square_k), ij - rk, ik + rj], dim=1)
        rot_kernel_2 = torch.cat(
            [zero_kernel, ij + rk, 1.0 - (square_i + square_k), jk - ri], dim=1)
        rot_kernel_3 = torch.cat(
            [zero_kernel, ik - rj, jk + ri, 1.0 - (square_i + square_j)], dim=1)
        # Follow the dtype/device of the kernel instead of forcing CUDA, so
        # the function also runs on CPU-only machines.
        zero_kernel2 = torch.zeros_like(rot_kernel_1)
        global_rot_kernel = torch.cat(
            [zero_kernel2, rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=0)
    else:
        rot_kernel_1 = torch.cat(
            [1.0 - (square_j + square_k), ij - rk, ik + rj], dim=1)
        rot_kernel_2 = torch.cat(
            [ij + rk, 1.0 - (square_i + square_k), jk - ri], dim=1)
        rot_kernel_3 = torch.cat(
            [ik - rj, jk + ri, 1.0 - (square_i + square_j)], dim=1)
        global_rot_kernel = torch.cat(
            [rot_kernel_1, rot_kernel_2, rot_kernel_3], dim=0)

    conv_by_rank = {
        3: F.conv_transpose1d,
        4: F.conv_transpose2d,
        5: F.conv_transpose3d,
    }
    convfunc = conv_by_rank.get(input.dim())
    if convfunc is None:
        raise Exception('The convolutional input is either 3, 4 or 5 dimensions.'
                        ' input.dim = ' + str(input.dim()))

    # The kernel built above is global_rot_kernel; the previous code passed
    # an undefined name here and always raised NameError.
    return convfunc(input, global_rot_kernel, bias, stride, padding,
                    output_padding, groups, dilatation)
def quaternion_linear(input, r_weight, i_weight, j_weight, k_weight, bias=True):
    """Apply a quaternion linear transformation to ``input``.

    The forward phase of a QNN is defined as W * Inputs (with * the
    Hamilton product). The kernel assembled here is a rearranged version of
    the quaternion matrix so that ``input @ kernel`` is equivalent to
    W * Inputs.

    2-D inputs go through ``torch.mm`` / ``torch.addmm``; any other rank
    goes through ``torch.matmul`` followed by an explicit bias addition.

    NOTE(review): the default ``bias=True`` is not ``None`` and is therefore
    handed to ``addmm`` / added to the output as-is. Callers presumably
    always pass a tensor or ``None`` explicitly; confirm.
    """
    block_columns = (
        (r_weight, -i_weight, -j_weight, -k_weight),
        (i_weight, r_weight, -k_weight, j_weight),
        (j_weight, k_weight, r_weight, -i_weight),
        (k_weight, -j_weight, i_weight, r_weight),
    )
    hamilton_kernel = torch.cat(
        [torch.cat(blocks, dim=0) for blocks in block_columns], dim=1)

    if input.dim() == 2:
        if bias is None:
            return torch.mm(input, hamilton_kernel)
        return torch.addmm(bias, input, hamilton_kernel)

    output = torch.matmul(input, hamilton_kernel)
    return output if bias is None else output + bias
def quaternion_linear_rotation(input, zero_kernel, r_weight, i_weight, j_weight, k_weight, bias=None,
                               quaternion_format=False, scale=None):
    """Apply a quaternion rotation (as a linear map) to ``input``.

    The rotation W*x*W^t can be replaced by R*x following:
    https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation

    Works for unitary and non unitary weights (the weights are normalised
    internally). The feature size of the input must be a multiple of 3 if
    ``quaternion_format`` is False and of 4 if it is True.

    ``zero_kernel`` is only used when ``quaternion_format`` is True, where
    it fills the real-part blocks. ``scale``, when given, multiplies every
    rotation entry. 2-D inputs use ``torch.mm`` / ``torch.addmm``; other
    ranks use ``torch.matmul`` plus an explicit bias addition.
    """
    # 0.0001 keeps the normalisation finite for all-zero quaternions.
    norm = torch.sqrt(r_weight * r_weight + i_weight * i_weight
                      + j_weight * j_weight + k_weight * k_weight + 0.0001)
    r_unit = r_weight / norm
    i_unit = i_weight / norm
    j_unit = j_weight / norm
    k_unit = k_weight / norm

    two = 2.0
    sq_i = two * (i_unit * i_unit)
    sq_j = two * (j_unit * j_unit)
    sq_k = two * (k_unit * k_unit)
    ri = two * r_unit * i_unit
    rj = two * r_unit * j_unit
    rk = two * r_unit * k_unit
    ij = two * i_unit * j_unit
    ik = two * i_unit * k_unit
    jk = two * j_unit * k_unit

    # Entries of the rotation matrix, grouped per block column of the kernel.
    entries = (
        (1.0 - (sq_j + sq_k), ij - rk, ik + rj),
        (ij + rk, 1.0 - (sq_i + sq_k), jk - ri),
        (ik - rj, jk + ri, 1.0 - (sq_i + sq_j)),
    )
    if scale is not None:
        entries = tuple(tuple(scale * value for value in group) for group in entries)

    if quaternion_format:
        zero_column = torch.cat([zero_kernel] * 4, dim=0)
        columns = [zero_column] + [
            torch.cat((zero_kernel,) + group, dim=0) for group in entries]
    else:
        columns = [torch.cat(group, dim=0) for group in entries]
    global_rot_kernel = torch.cat(columns, dim=1)

    if input.dim() == 2:
        if bias is None:
            return torch.mm(input, global_rot_kernel)
        return torch.addmm(bias, input, global_rot_kernel)

    output = torch.matmul(input, global_rot_kernel)
    return output if bias is None else output + bias
# Custom AUTOGRAD for lower VRAM consumption
class QuaternionLinearFunction(torch.autograd.Function):
    """Quaternion linear transformation with a hand-written backward pass.

    ``forward`` computes the same ``input @ kernel (+ bias)`` product as the
    functional quaternion linear layer; ``backward`` derives the gradients
    explicitly instead of relying on the autograd graph of the kernel
    construction.
    """

    @staticmethod
    def forward(ctx, input, r_weight, i_weight, j_weight, k_weight, bias=None):
        """Compute the quaternion linear map and save tensors for backward.

        Args:
            ctx: Autograd context used to stash the tensors for ``backward``.
            input: Tensor validated by ``check_input``.
            r_weight, i_weight, j_weight, k_weight: Quaternion weight
                components (2-D; ``backward`` reads ``size(0)`` and
                ``size(1)``).
            bias: Optional bias tensor, or ``None``.

        Returns:
            ``input @ kernel``, plus ``bias`` when it is not ``None``.
        """
        ctx.save_for_backward(input, r_weight, i_weight,
                              j_weight, k_weight, bias)
        check_input(input)
        # Assemble the real-valued kernel: four block columns (each built by
        # concatenating along dim 0) joined along dim 1.
        cat_kernels_4_r = torch.cat(
            [r_weight, -i_weight, -j_weight, -k_weight], dim=0)
        cat_kernels_4_i = torch.cat(
            [i_weight, r_weight, -k_weight, j_weight], dim=0)
        cat_kernels_4_j = torch.cat(
            [j_weight, k_weight, r_weight, -i_weight], dim=0)
        cat_kernels_4_k = torch.cat(
            [k_weight, -j_weight, i_weight, r_weight], dim=0)
        cat_kernels_4_quaternion = torch.cat(
            [cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k], dim=1)
        if input.dim() == 2:
            if bias is not None:
                return torch.addmm(bias, input, cat_kernels_4_quaternion)
            else:
                return torch.mm(input, cat_kernels_4_quaternion)
        else:
            output = torch.matmul(input, cat_kernels_4_quaternion)
            if bias is not None:
                return output+bias
            else:
                return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients w.r.t. the six ``forward`` arguments.

        Gradients that are not requested (per ``ctx.needs_input_grad``) are
        returned as ``None``.

        NOTE(review): ``grad_output.mm`` and the ``permute(1, 0)`` calls only
        work on 2-D tensors, so this presumably supports 2-D inputs only,
        although ``forward`` also accepts higher ranks. Confirm.
        """
        input, r_weight, i_weight, j_weight, k_weight, bias = ctx.saved_tensors
        grad_input = grad_weight_r = grad_weight_i = grad_weight_j = grad_weight_k = grad_bias = None
        # Rebuild the forward kernel and transpose it for the input gradient.
        input_r = torch.cat([r_weight, -i_weight, -j_weight, -k_weight], dim=0)
        input_i = torch.cat([i_weight, r_weight, -k_weight, j_weight], dim=0)
        input_j = torch.cat([j_weight, k_weight, r_weight, -i_weight], dim=0)
        input_k = torch.cat([k_weight, -j_weight, i_weight, r_weight], dim=0)
        cat_kernels_4_quaternion_T = Variable(
            torch.cat([input_r, input_i, input_j, input_k], dim=1).permute(1, 0), requires_grad=False)
        # Arrange the input components in the same signed block pattern
        # (the input_* names are reused as scratch variables from here on).
        r = get_r(input)
        i = get_i(input)
        j = get_j(input)
        k = get_k(input)
        input_r = torch.cat([r, -i, -j, -k], dim=0)
        input_i = torch.cat([i, r, -k, j], dim=0)
        input_j = torch.cat([j, k, r, -i], dim=0)
        input_k = torch.cat([k, -j, i, r], dim=0)
        input_mat = Variable(
            torch.cat([input_r, input_i, input_j, input_k], dim=1), requires_grad=False)
        # Arrange the components of the incoming gradient in a signed block
        # matrix as well (note the different sign pattern and concatenation
        # axes compared with the input matrix above).
        r = get_r(grad_output)
        i = get_i(grad_output)
        j = get_j(grad_output)
        k = get_k(grad_output)
        input_r = torch.cat([r, i, j, k], dim=1)
        input_i = torch.cat([-i, r, k, -j], dim=1)
        input_j = torch.cat([-j, -k, r, i], dim=1)
        input_k = torch.cat([-k, j, -i, r], dim=1)
        grad_mat = torch.cat([input_r, input_i, input_j, input_k], dim=0)
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(cat_kernels_4_quaternion_T)
        # The flag of r_weight alone gates the gradients of all four weight
        # components.
        if ctx.needs_input_grad[1]:
            grad_weight = grad_mat.permute(1, 0).mm(input_mat).permute(1, 0)
            unit_size_x = r_weight.size(0)
            unit_size_y = r_weight.size(1)
            # The four component gradients are the four column blocks of the
            # first unit_size_x rows of grad_weight.
            grad_weight_r = grad_weight.narrow(
                0, 0, unit_size_x).narrow(1, 0, unit_size_y)
            grad_weight_i = grad_weight.narrow(
                0, 0, unit_size_x).narrow(1, unit_size_y, unit_size_y)
            grad_weight_j = grad_weight.narrow(
                0, 0, unit_size_x).narrow(1, unit_size_y*2, unit_size_y)
            grad_weight_k = grad_weight.narrow(
                0, 0, unit_size_x).narrow(1, unit_size_y*3, unit_size_y)
        if ctx.needs_input_grad[5]:
            grad_bias = grad_output.sum(0).squeeze(0)
        return grad_input, grad_weight_r, grad_weight_i, grad_weight_j, grad_weight_k, grad_bias
def hamilton_product(q0, q1):
    """
    Compute the Hamilton product q0 * q1.

    Shape:
        - q0, q1 should be (batch_size, quaternion_number)

    With q0 = (r, x, y, z) and q1 = (r', x', y', z') the result is
        (rr' - xx' - yy' - zz') +
        (rx' + xr' + yz' - zy')i +
        (ry' - xz' + yr' + zx')j +
        (rz' + xy' - yx' + zr')k
    """
    r1 = get_r(q1)
    x1 = get_i(q1)
    y1 = get_j(q1)
    z1 = get_k(q1)

    # Each product pairs the components of q0 with a permutation of q1's
    # components, so its four quarters hold the four terms of one output
    # component.
    # rr', xx', yy', zz'
    real_terms = torch.mul(q0, q1)
    # rx', xr', yz', zy'
    i_terms = torch.mul(q0, torch.cat([x1, r1, z1, y1], dim=1))
    # ry', xz', yr', zx'
    j_terms = torch.mul(q0, torch.cat([y1, z1, r1, x1], dim=1))
    # rz', xy', yx', zr'
    k_terms = torch.mul(q0, torch.cat([z1, y1, x1, r1], dim=1))

    # (rr' - xx' - yy' - zz')
    real = get_r(real_terms) - get_i(real_terms) - get_j(real_terms) - get_k(real_terms)
    # (rx' + xr' + yz' - zy')
    imag_i = get_r(i_terms) + get_i(i_terms) + get_j(i_terms) - get_k(i_terms)
    # (ry' - xz' + yr' + zx')
    imag_j = get_r(j_terms) - get_i(j_terms) + get_j(j_terms) + get_k(j_terms)
    # (rz' + xy' - yx' + zr')
    imag_k = get_r(k_terms) + get_i(k_terms) - get_j(k_terms) + get_k(k_terms)

    return torch.cat([real, imag_i, imag_j, imag_k], dim=1)
#
# PARAMETERS INITIALIZATION
#
def unitary_init(in_features, out_features, rng, kernel_size=None, criterion='he'):
    """Draw random quaternion weights of (approximately) unit norm.

    Each component is sampled uniformly in [-1, 1) from NumPy's global
    random state and every quaternion is then divided by its norm
    (plus 0.0001 to avoid a division by zero).

    Args:
        in_features: Number of input features / channels.
        out_features: Number of output features / channels.
        rng: Unused; sampling goes through ``np.random``. Kept for
            signature compatibility with the other initialisers.
        kernel_size: ``None`` for a linear layer (shape
            ``(in_features, out_features)``), an int or a tuple for a
            convolution (shape ``(out_features, in_features, *kernel_size)``).
        criterion: Unused; kept for signature compatibility.

    Returns:
        Tuple ``(v_r, v_i, v_j, v_k)`` of arrays with the kernel shape.
    """
    if kernel_size is None:
        kernel_shape = (in_features, out_features)
    elif isinstance(kernel_size, int):
        kernel_shape = (out_features, in_features, kernel_size)
    else:
        kernel_shape = (out_features, in_features) + (*kernel_size,)

    number_of_weights = np.prod(kernel_shape)
    v_r = np.random.uniform(-1.0, 1.0, number_of_weights)
    v_i = np.random.uniform(-1.0, 1.0, number_of_weights)
    v_j = np.random.uniform(-1.0, 1.0, number_of_weights)
    v_k = np.random.uniform(-1.0, 1.0, number_of_weights)

    # Unitary quaternion: normalise all weights at once instead of looping
    # over them one by one in Python.
    norm = np.sqrt(v_r**2 + v_i**2 + v_j**2 + v_k**2) + 0.0001

    return tuple((component / norm).reshape(kernel_shape)
                 for component in (v_r, v_i, v_j, v_k))
def random_init(in_features, out_features, rng, kernel_size=None, criterion='glorot'):
    """Draw four independent uniform [-1, 1) weight arrays.

    Sampling uses NumPy's global random state; ``rng`` is not consulted.
    ``criterion`` is only validated ('glorot' or 'he'): the returned values
    are not rescaled by it.

    Args:
        in_features: Number of input features / channels.
        out_features: Number of output features / channels.
        rng: Unused.
        kernel_size: ``None`` for a linear layer (shape
            ``(in_features, out_features)``), an int or a tuple for a
            convolution (shape ``(out_features, in_features, *kernel_size)``).
        criterion: 'glorot' or 'he'.

    Returns:
        Tuple ``(weight_r, weight_i, weight_j, weight_k)`` of arrays.

    Raises:
        ValueError: If ``criterion`` is neither 'glorot' nor 'he'.
    """
    if criterion not in ('glorot', 'he'):
        raise ValueError('Invalid criterion: ' + criterion)

    if kernel_size is None:
        kernel_shape = (in_features, out_features)
    elif type(kernel_size) is int:
        kernel_shape = (out_features, in_features, kernel_size)
    else:
        kernel_shape = (out_features, in_features) + tuple(kernel_size)

    count = np.prod(kernel_shape)
    # Draw order (r, i, j, k) matters for reproducibility under a seed.
    flat_draws = [np.random.uniform(-1.0, 1.0, count) for _ in range(4)]
    weight_r, weight_i, weight_j, weight_k = (
        draw.reshape(kernel_shape) for draw in flat_draws)
    return (weight_r, weight_i, weight_j, weight_k)
def quaternion_init(in_features, out_features, rng, kernel_size=None, criterion='glorot'):
    """Initialise quaternion weights in polar form.

    Every weight is ``modulus * (cos(phase) + u * sin(phase))`` where the
    modulus follows a chi distribution with 4 degrees of freedom scaled by
    the chosen criterion, ``u`` is a random (approximately) unit purely
    imaginary quaternion and the phase is uniform in [-pi, pi).

    Args:
        in_features: Number of input features / channels.
        out_features: Number of output features / channels.
        rng: Ignored; a fresh ``RandomState`` seeded from NumPy's global
            random state is used for the phase.
        kernel_size: ``None`` for a linear layer (shape
            ``(in_features, out_features)``), an int or a tuple for a
            convolution (shape ``(out_features, in_features, *kernel_size)``).
        criterion: 'glorot' (scale ``1 / sqrt(2 * (fan_in + fan_out))``) or
            'he' (scale ``1 / sqrt(2 * fan_in)``).

    Returns:
        Tuple ``(weight_r, weight_i, weight_j, weight_k)`` of arrays.

    Raises:
        ValueError: If ``criterion`` is neither 'glorot' nor 'he'.
    """
    if kernel_size is not None:
        receptive_field = np.prod(kernel_size)
        fan_in = in_features * receptive_field
        fan_out = out_features * receptive_field
    else:
        fan_in = in_features
        fan_out = out_features

    if criterion == 'glorot':
        s = 1. / np.sqrt(2 * (fan_in + fan_out))
    elif criterion == 'he':
        s = 1. / np.sqrt(2 * fan_in)
    else:
        raise ValueError('Invalid criterion: ' + criterion)

    # The order of the random draws below is kept unchanged so that results
    # under a fixed np.random seed stay the same.
    rng = RandomState(np.random.randint(1, 1234))

    if kernel_size is None:
        kernel_shape = (in_features, out_features)
    elif isinstance(kernel_size, int):
        kernel_shape = (out_features, in_features, kernel_size)
    else:
        kernel_shape = (out_features, in_features) + (*kernel_size,)

    modulus = chi.rvs(4, loc=0, scale=s, size=kernel_shape)
    number_of_weights = np.prod(kernel_shape)
    v_i = np.random.uniform(-1.0, 1.0, number_of_weights)
    v_j = np.random.uniform(-1.0, 1.0, number_of_weights)
    v_k = np.random.uniform(-1.0, 1.0, number_of_weights)

    # Purely imaginary unit quaternions: normalise all of them at once
    # instead of looping over the weights one by one in Python.
    norm = np.sqrt(v_i**2 + v_j**2 + v_k**2 + 0.0001)
    v_i = (v_i / norm).reshape(kernel_shape)
    v_j = (v_j / norm).reshape(kernel_shape)
    v_k = (v_k / norm).reshape(kernel_shape)

    phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape)

    weight_r = modulus * np.cos(phase)
    weight_i = modulus * v_i * np.sin(phase)
    weight_j = modulus * v_j * np.sin(phase)
    weight_k = modulus * v_k * np.sin(phase)

    return (weight_r, weight_i, weight_j, weight_k)
def create_dropout_mask(dropout_p, size, rng, as_type, operation='linear'):
    """Sample a binary dropout mask.

    Args:
        dropout_p: Probability of dropping a unit; each mask entry is 1
            with probability ``1 - dropout_p``.
        size: Shape of the mask.
        rng: Object providing ``binomial(n=, p=, size=)``.
        as_type: Tensor type the mask is converted to via ``Tensor.type``.
        operation: Only 'linear' is supported.

    Returns:
        The mask wrapped in a ``Variable``.

    Raises:
        Exception: If ``operation`` is not 'linear'.
    """
    if operation != 'linear':
        raise Exception("create_dropout_mask accepts only 'linear'. Found operation = "
                        + str(operation))

    keep_probability = 1 - dropout_p
    sampled = rng.binomial(n=1, p=keep_probability, size=size)
    mask_tensor = torch.from_numpy(sampled).type(as_type)
    return Variable(mask_tensor)
def affect_init(r_weight, i_weight, j_weight, k_weight, init_func, rng, init_criterion):
    """Initialise the four weight matrices of a quaternion linear layer.

    ``init_func`` is called as
    ``init_func(r_weight.size(0), r_weight.size(1), rng, None, init_criterion)``
    and must return four NumPy arrays. They are converted to the type of
    each weight and assigned to its ``.data``.

    Raises:
        ValueError: If the four weights do not all have the same size.
        Exception: If the weights are not 2-D.
    """
    targets = (r_weight, i_weight, j_weight, k_weight)
    reference_size = r_weight.size()
    if any(weight.size() != reference_size for weight in targets[1:]):
        raise ValueError('The real and imaginary weights '
                         'should have the same size . Found: r:'
                         + str(r_weight.size()) + ' i:'
                         + str(i_weight.size()) + ' j:'
                         + str(j_weight.size()) + ' k:'
                         + str(k_weight.size()))
    if r_weight.dim() != 2:
        raise Exception('affect_init accepts only matrices. Found dimension = '
                        + str(r_weight.dim()))

    no_kernel = None
    init_r, init_i, init_j, init_k = init_func(
        r_weight.size(0), r_weight.size(1), rng, no_kernel, init_criterion)
    sources = [torch.from_numpy(values)
               for values in (init_r, init_i, init_j, init_k)]
    for weight, source in zip(targets, sources):
        weight.data = source.type_as(weight.data)
def affect_init_conv(r_weight, i_weight, j_weight, k_weight, kernel_size, init_func, rng,
                     init_criterion):
    """Initialise the four weight tensors of a quaternion convolution.

    ``init_func`` is called as ``init_func(r_weight.size(1),
    r_weight.size(0), rng=rng, kernel_size=kernel_size,
    criterion=init_criterion)`` and must return four NumPy arrays. They are
    converted to the type of each weight and assigned to its ``.data``.

    Raises:
        ValueError: If the four weights do not all have the same size.
        Exception: If the weights have 2 or fewer dimensions.
    """
    if r_weight.size() != i_weight.size() or r_weight.size() != j_weight.size() or \
            r_weight.size() != k_weight.size():
        raise ValueError('The real and imaginary weights '
                         'should have the same size . Found: r:'
                         + str(r_weight.size()) + ' i:'
                         + str(i_weight.size()) + ' j:'
                         + str(j_weight.size()) + ' k:'
                         + str(k_weight.size()))
    elif 2 >= r_weight.dim():
        # The message used to reference an undefined name (real_weight),
        # which raised NameError instead of this exception.
        raise Exception('affect_conv_init accepts only tensors that have more than 2 dimensions. Found dimension = '
                        + str(r_weight.dim()))

    r, i, j, k = init_func(
        r_weight.size(1),
        r_weight.size(0),
        rng=rng,
        kernel_size=kernel_size,
        criterion=init_criterion
    )
    r, i, j, k = torch.from_numpy(r), torch.from_numpy(
        i), torch.from_numpy(j), torch.from_numpy(k)
    r_weight.data = r.type_as(r_weight.data)
    i_weight.data = i.type_as(i_weight.data)
    j_weight.data = j.type_as(j_weight.data)
    k_weight.data = k.type_as(k_weight.data)
def get_kernel_and_weight_shape(operation, in_channels, out_channels, kernel_size):
    """Return the kernel size and weight shape of a convolution.

    Args:
        operation: 'convolution1d', 'convolution2d' or 'convolution3d'.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: An int, or (2-D / 3-D only) a sequence with one entry
            per spatial dimension. An int is expanded to a tuple for 2-D
            and 3-D convolutions.

    Returns:
        Tuple ``(ks, w_shape)`` where ``ks`` is the (possibly expanded)
        kernel size and ``w_shape`` is ``(out_channels, in_channels, *ks)``.

    Raises:
        ValueError: If ``kernel_size`` does not fit ``operation``, or if an
            integer ``kernel_size`` is given for an unknown operation.
    """
    size_is_int = type(kernel_size) is int

    if operation == 'convolution1d':
        if not size_is_int:
            raise ValueError(
                "An invalid kernel_size was supplied for a 1d convolution. The kernel size\n"
                "must be integer in the case. Found kernel_size = " + str(kernel_size)
            )
        ks = kernel_size
        w_shape = (out_channels, in_channels) + (ks,)
        return ks, w_shape

    # 2d or 3d.
    if size_is_int:
        spatial_dims = {'convolution2d': 2, 'convolution3d': 3}.get(operation)
        if spatial_dims is None:
            # Previously `ks` stayed unbound here and the function crashed
            # with UnboundLocalError further down.
            raise ValueError(
                "Unsupported operation for an integer kernel_size: "
                + str(operation)
            )
        ks = (kernel_size,) * spatial_dims
    else:
        if operation == 'convolution2d' and len(kernel_size) != 2:
            raise ValueError(
                "An invalid kernel_size was supplied for a 2d convolution. The kernel size\n"
                "must be either an integer or a tuple of 2. Found kernel_size = " + str(kernel_size)
            )
        if operation == 'convolution3d' and len(kernel_size) != 3:
            raise ValueError(
                "An invalid kernel_size was supplied for a 3d convolution. The kernel size\n"
                "must be either an integer or a tuple of 3. Found kernel_size = " + str(kernel_size)
            )
        ks = kernel_size

    w_shape = (out_channels, in_channels) + (*ks,)
    return ks, w_shape