from typing import (
    Tuple,
    List,
    Optional,
    Dict,
    Callable,
    Union,
    cast,
)
from collections import namedtuple
from abc import ABC, abstractmethod
from dataclasses import dataclass

import numpy as np

import torch as T
from torch import nn
from torch.nn import functional as F
from torch import Tensor

from .rnn_base import (
    IRecurrentCell,
    IRecurrentCellBuilder,
    RecurrentLayer,
    RecurrentLayerStack,
)

__all__ = [
    'K_LSTM',
    'K_LSTM_Cell',
    'K_LSTM_Cell_Builder',
]
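
# This module implements a configurable LSTM on top of the rnn_base
# abstractions: a dataclass builder (K_LSTM_Cell_Builder), a single cell
# (K_LSTM_Cell) with selectable recurrent-dropout modes ('gal_tied',
# 'gal_gates', 'semeniuta'), vertical (inter-layer) dropout, an optional
# tied forget gate, and a multi-layer wrapper (K_LSTM) over
# RecurrentLayerStack.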

ACTIVATIONS = {
    'sigmoid': nn.Sigmoid(),
    'tanh': nn.Tanh(),
    'hard_tanh': nn.Hardtanh(),
    'relu': nn.ReLU(),
}

GateSpans = namedtuple('GateSpans', ['I', 'F', 'G', 'O'])
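# The gate order I, F, G, O (input, forget, cell candidate, output) matches
# the 4-way chunking of the fused kernels below: the first Dh rows (or output
# columns) belong to the input gate, the next Dh to the forget gate, and so on.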


@dataclass
class K_LSTM_Cell_Builder(IRecurrentCellBuilder):
    vertical_dropout : float = 0.0
    recurrent_dropout : float = 0.0
    recurrent_dropout_mode : str = 'gal_tied'
    input_kernel_initialization : str = 'xavier_uniform'
    recurrent_activation : Union[str, nn.Module] = 'sigmoid'
    tied_forget_gate : bool = False

    def make(self, input_size: int) -> 'K_LSTM_Cell':
        return K_LSTM_Cell(input_size, self)
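
    # Notes: `recurrent_activation` may also be given as a module/callable
    # (see K_LSTM_Cell.__init__), and `input_kernel_initialization` is not
    # currently consulted by reset_parameters_(), which always applies
    # xavier_uniform to the input kernel.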


class K_LSTM_Cell(IRecurrentCell):
    def __repr__(self):
        return (
            f'{self.__class__.__name__}('
            + ', '.join(
                [
                    f'in: {self.Dx}',
                    f'hid: {self.Dh}',
                    f'rdo: {self.recurrent_dropout_p} @{self.recurrent_dropout_mode}',
                    f'vdo: {self.vertical_dropout_p}',
                ]
            )
            + ')'
        )

    def __init__(
        self,
        input_size: int,
        args: K_LSTM_Cell_Builder,
    ):
        super().__init__()
        self._args = args
        self.Dx = input_size
        self.Dh = args.hidden_size
        # Fused kernels: each produces the pre-activations of all four gates
        # (I, F, G, O) as a single tensor of width 4 * Dh.
        self.recurrent_kernel = nn.Linear(self.Dh, self.Dh * 4)
        self.input_kernel = nn.Linear(self.Dx, self.Dh * 4)

        self.recurrent_dropout_p = args.recurrent_dropout or 0.0
        self.vertical_dropout_p = args.vertical_dropout or 0.0
        self.recurrent_dropout_mode = args.recurrent_dropout_mode

        self.recurrent_dropout = nn.Dropout(self.recurrent_dropout_p)
        self.vertical_dropout = nn.Dropout(self.vertical_dropout_p)

        self.tied_forget_gate = args.tied_forget_gate

        # The recurrent (gate) activation can be given by name or as a callable.
        if isinstance(args.recurrent_activation, str):
            self.fun_rec = ACTIVATIONS[args.recurrent_activation]
        else:
            self.fun_rec = args.recurrent_activation

        self.reset_parameters_()

    def get_recurrent_weights(self):
        W = self.recurrent_kernel.weight.chunk(4, 0)
        b = self.recurrent_kernel.bias.chunk(4, 0)
        W = GateSpans(W[0], W[1], W[2], W[3])
        b = GateSpans(b[0], b[1], b[2], b[3])
        return W, b

    def get_input_weights(self):
        W = self.input_kernel.weight.chunk(4, 0)
        b = self.input_kernel.bias.chunk(4, 0)
        W = GateSpans(W[0], W[1], W[2], W[3])
        b = GateSpans(b[0], b[1], b[2], b[3])
        return W, b

    @T.jit.ignore
    def reset_parameters_(self):
        rw, rb = self.get_recurrent_weights()
        iw, ib = self.get_input_weights()

        # Zero all biases, then set the forget-gate bias to 1 so the cell
        # initially tends to remember (a standard LSTM initialization).
        nn.init.zeros_(self.input_kernel.bias)
        nn.init.zeros_(self.recurrent_kernel.bias)
        nn.init.ones_(rb.F)

        # Orthogonal recurrent weights, Xavier-uniform input weights, per gate.
        for W in rw:
            nn.init.orthogonal_(W)
        for W in iw:
            nn.init.xavier_uniform_(W)

    @T.jit.export
    def get_init_state(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        # `input` is time-major (seq_len, batch, features), so batch is dim 1.
        batch_size = input.shape[1]
        h0 = T.zeros(batch_size, self.Dh, device=input.device)
        c0 = T.zeros(batch_size, self.Dh, device=input.device)
        return (h0, c0)

    def apply_input_kernel(self, xt: Tensor) -> List[Tensor]:
        # Vertical (inter-layer) dropout on the input, then the fused input
        # kernel, split into the four gate pre-activations.
        xto = self.vertical_dropout(xt)
        out = self.input_kernel(xto).chunk(4, 1)
        return out

    def apply_recurrent_kernel(self, h_tm1: Tensor):
        mode = self.recurrent_dropout_mode
        if mode == 'gal_tied':
            hto = self.recurrent_dropout(h_tm1)
            out = self.recurrent_kernel(hto)
            outs = out.chunk(4, -1)
        elif mode == 'gal_gates':
            outs = []
            WW, bb = self.get_recurrent_weights()
            for i in range(4):
                hto = self.recurrent_dropout(h_tm1)
                outs.append(F.linear(hto, WW[i], bb[i]))
        else:
            outs = self.recurrent_kernel(h_tm1).chunk(4, -1)
        return outs
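
    # Recurrent-dropout modes (used here and in forward()):
    #   'gal_tied'  -- one dropout mask on h_{t-1}, shared by all four gates
    #   'gal_gates' -- an independent dropout mask on h_{t-1} for each gate
    #   'semeniuta' -- no dropout here; forward() drops the candidate g_t instead
    #   any other value -- no recurrent dropout at all
    # ('gal_*' presumably follows Gal & Ghahramani's variational RNN dropout,
    #  'semeniuta' the cell-candidate dropout of Semeniuta et al.)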

    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        (h_tm1, c_tm1) = state

        Xi, Xf, Xg, Xo = self.apply_input_kernel(input)
        Hi, Hf, Hg, Ho = self.apply_recurrent_kernel(h_tm1)

        # Gates: f_t and o_t always; i_t = 1 - f_t when the forget gate is tied
        # (a coupled input/forget gate), otherwise its own gate.
        ft = self.fun_rec(Xf + Hf)
        ot = self.fun_rec(Xo + Ho)
        if self.tied_forget_gate:
            it = 1.0 - ft
        else:
            it = self.fun_rec(Xi + Hi)

        # Candidate cell state; 'semeniuta' recurrent dropout acts here.
        gt = T.tanh(Xg + Hg)
        if self.recurrent_dropout_mode == 'semeniuta':
            gt = self.recurrent_dropout(gt)

        # c_t = f_t * c_{t-1} + i_t * g_t ;  h_t = o_t * tanh(c_t)
        ct = (ft * c_tm1) + (it * gt)
        ht = ot * T.tanh(ct)

        return ht, (ht, ct)

    @T.jit.export
    def loop(self, inputs, state_t0, mask=None):
        '''
        Run the cell over a time-major sequence, looping over the t (time) steps.

        Returns the list of per-step hidden states and the final (h, c) state.
        '''
        # `mask` is currently unused.
        state = state_t0
        outs = []
        for xt in inputs:
            ht, state = self(xt, state)
            outs.append(ht)
        return outs, state


class K_LSTM(RecurrentLayerStack):
    '''
    Multi-layer LSTM assembled by RecurrentLayerStack from K_LSTM_Cell_Builder;
    all arguments are forwarded to the stack unchanged.
    '''
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        builder = K_LSTM_Cell_Builder
        super().__init__(
            builder,
            *args, **kwargs,
        )
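

# Usage sketch (illustrative, not executed here). It assumes `hidden_size`
# is a field inherited from IRecurrentCellBuilder, since K_LSTM_Cell reads
# `args.hidden_size`:
#
#   builder = K_LSTM_Cell_Builder(hidden_size=128,
#                                 recurrent_dropout=0.25,
#                                 recurrent_dropout_mode='gal_tied')
#   cell = builder.make(input_size=64)
#   x = T.randn(10, 32, 64)                # (time, batch, features)
#   state = cell.get_init_state(x)         # zero (h0, c0) for a batch of 32
#   outs, state = cell.loop(x.unbind(0), state)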