import torch


class Linear:
    """Fully connected layer: y = x @ W (+ b)."""

    def __init__(self, in_n, out_n, bias=True) -> None:
        self.params = []
        self.have_bias = bias
        # Scale the weights by 1/sqrt(fan_in) to keep activation variance roughly constant.
        self.weight = torch.randn((in_n, out_n)) / (in_n ** 0.5)
        self.params.append(self.weight)
        self.bias = None
        if self.have_bias:
            self.bias = torch.zeros(out_n)
            self.params.append(self.bias)

    def __call__(self, x, is_training=True):
        self.is_training = is_training
        self.out = x @ self.params[0]
        if self.have_bias:
            self.out += self.params[1]
        return self.out

    def set_parameters(self, p):
        self.params = p

    def parameters(self):
        return self.params


class BatchNorm:
    """1-D batch normalization with learnable gain/bias and running statistics."""

    def __init__(self, in_n, eps=1e-5, momentum=0.1) -> None:
        self.eps = eps
        self.is_training = True
        self.momentum = momentum
        # Running statistics used at inference time (updated during training).
        self.running_mean = torch.zeros(in_n)
        self.running_std = torch.ones(in_n)
        # Learnable scale (gamma) and shift (beta).
        self.gain = torch.ones(in_n)
        self.bias = torch.zeros(in_n)
        self.params = [self.gain, self.bias]

    def __call__(self, x, is_training=True):
        self.is_training = is_training
        if self.is_training:
            # Batch statistics; torch.std defaults to the unbiased (Bessel-corrected) estimator.
            mean = x.mean(0, keepdim=True)
            std = x.std(0, keepdim=True)
            # Normalize, then scale and shift.
            self.out = self.params[0] * (x - mean) / (std + self.eps) + self.params[1]
            with torch.no_grad():
                # Exponential moving average of the batch statistics.
                self.running_mean = (1 - self.momentum) * self.running_mean \
                    + self.momentum * mean.squeeze(0)
                self.running_std = (1 - self.momentum) * self.running_std \
                    + self.momentum * std.squeeze(0)
        else:
            # Inference: normalize with the running statistics instead of batch statistics.
            self.out = self.params[0] * (x - self.running_mean) / (self.running_std + self.eps) + self.params[1]
        return self.out

    def set_parameters(self, p):
        self.params = p

    def set_mean_std(self, conf):
        self.running_mean = conf[0]
        self.running_std = conf[1]

    def get_mean_std(self):
        return [self.running_mean, self.running_std]

    def parameters(self):
        return self.params


class Activation:
    """Parameter-free element-wise nonlinearity (tanh or relu)."""

    def __init__(self, activation='tanh'):
        self.params = []
        if activation == 'tanh':
            self.forward = self._forward_tanh
        elif activation == 'relu':
            self.forward = self._forward_relu
        else:
            raise ValueError('Only tanh and relu activations are supported')

    def _forward_relu(self, x):
        return torch.relu(x)

    def _forward_tanh(self, x):
        return torch.tanh(x)

    def __call__(self, x, is_training=True):
        self.is_training = is_training
        self.out = self.forward(x)
        return self.out

    def set_parameters(self, p):
        self.params = p

    def parameters(self):
        return self.params
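

if __name__ == "__main__":
    # Usage sketch (not part of the original module): build a small MLP from the
    # layers above and run a forward pass in training and evaluation modes.
    # Layer sizes and the random input batch are arbitrary illustrative choices.
    torch.manual_seed(0)
    model = [
        Linear(10, 32),
        BatchNorm(32),
        Activation('tanh'),
        Linear(32, 2),
    ]

    x = torch.randn(16, 10)  # batch of 16 examples with 10 features each

    # Training-mode pass: BatchNorm normalizes with batch statistics and
    # updates its running mean/std as a side effect.
    out = x
    for layer in model:
        out = layer(out, is_training=True)
    print('train output shape:', tuple(out.shape))  # (16, 2)

    # Evaluation-mode pass: BatchNorm uses the running statistics instead.
    out = x
    for layer in model:
        out = layer(out, is_training=False)
    print('eval output shape:', tuple(out.shape))  # (16, 2)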