import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange

# my layer normalization
class LayerNorm(nn.Module):
    def __init__(self, eps=1e-5):
        super(LayerNorm, self).__init__()
        self.eps = eps

    def forward(self, input):
        # normalize over every dimension except the batch dimension
        shape = tuple(input.size()[1:])
        return F.layer_norm(input, shape, eps=self.eps)

    def extra_repr(self):
        return f'eps={self.eps}'

# helper functions
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# lambda layer
class LambdaLayer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        dim_k,
        n = None,
        r = None,
        heads = 4,
        dim_out = None,
        dim_u = 1,
        normalization = "batch"
    ):
        super().__init__()
        dim_out = default(dim_out, dim)
        self.u = dim_u  # intra-depth dimension
        self.heads = heads

        assert (dim_out % heads) == 0, 'values dimension must be divisible by number of heads for multi-head query'
        dim_v = dim_out // heads

        self.to_q = nn.Conv2d(dim, dim_k * heads, 1, bias = False)
        self.to_k = nn.Conv2d(dim, dim_k * dim_u, 1, bias = False)
        self.to_v = nn.Conv2d(dim, dim_v * dim_u, 1, bias = False)

        print(f"using {normalization} normalization in lambda layer")
        if normalization == "none":
            self.norm_q = nn.Identity()
            self.norm_v = nn.Identity()
        elif normalization == "instance":
            self.norm_q = nn.InstanceNorm2d(dim_k * heads)
            self.norm_v = nn.InstanceNorm2d(dim_v * dim_u)
        elif normalization == "layer":
            self.norm_q = LayerNorm()
            self.norm_v = LayerNorm()
        else:
            self.norm_q = nn.BatchNorm2d(dim_k * heads)
            self.norm_v = nn.BatchNorm2d(dim_v * dim_u)

        self.local_contexts = exists(r)
        if exists(r):
            assert (r % 2) == 1, 'Receptive kernel size should be odd'
            self.pos_conv = nn.Conv3d(dim_u, dim_k, (1, r, r), padding = (0, r // 2, r // 2))
        else:
            assert exists(n), 'You must specify the total sequence length (h x w)'
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))

    def forward(self, x):
        b, c, hh, ww, u, h = *x.shape, self.u, self.heads

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.norm_q(q)
        v = self.norm_v(v)

        q = rearrange(q, 'b (h k) hh ww -> b h k (hh ww)', h = h)
        k = rearrange(k, 'b (u k) hh ww -> b u k (hh ww)', u = u)
        v = rearrange(v, 'b (u v) hh ww -> b u v (hh ww)', u = u)

        k = k.softmax(dim=-1)

        # content lambda: values aggregated with softmax-normalized keys
        λc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b h v n', q, λc)

        if self.local_contexts:
            # position lambdas from a local 3d convolution over the value map
            v = rearrange(v, 'b u v (hh ww) -> b u v hh ww', hh = hh, ww = ww)
            λp = self.pos_conv(v)
            Yp = einsum('b h k n, b k v n -> b h v n', q, λp.flatten(3))
        else:
            # position lambdas from learned relative position embeddings
            λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b h v n', q, λp)

        Y = Yc + Yp
        out = rearrange(Y, 'b h v (hh ww) -> b (h v) hh ww', hh = hh, ww = ww)
        return out

# i'm not sure whether this will work or not
class Recursion(nn.Module):
    def __init__(self, N: int, hidden_dim: int = 64):
        super(Recursion, self).__init__()
        self.N = N
        self.lambdaNxN_identity = LambdaLayer(dim=hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
        # merge upstream information here
        self.lambdaNxN_merge = LambdaLayer(dim=2 * hidden_dim, dim_out=hidden_dim, n=N * N, dim_k=16, heads=2, dim_u=1)
        self.downscale_conv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=N, stride=N)
        self.upscale_conv = nn.Conv2d(hidden_dim, hidden_dim * N * N, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(N)

    def forward(self, x: torch.Tensor):
        N = self.N

        def to_patch(blocks: torch.Tensor):
            # split the feature map into non-overlapping NxN patches,
            # folded into the batch dimension
            shape = blocks.shape
            blocks_patch = F.unfold(blocks, kernel_size=N, stride=N)
            blocks_patch = blocks_patch.view(shape[0], shape[1], N, N, -1)
            num_patch = blocks_patch.shape[-1]
            blocks_patch = blocks_patch.permute(0, 4, 1, 2, 3).reshape(-1, shape[1], N, N).contiguous()
            return blocks_patch, num_patch

        def combine_patch(processed_patch, shape, num_patch):
            # inverse of to_patch: reassemble the patches into a full feature map
            processed_patch = processed_patch.reshape(shape[0], num_patch, shape[1], N, N)
            processed_patch = processed_patch.permute(0, 2, 3, 4, 1).reshape(shape[0], shape[1] * N * N, num_patch).contiguous()
            processed = F.fold(processed_patch, output_size=(shape[-2], shape[-1]), kernel_size=N, stride=N)
            return processed

        def process(blocks: torch.Tensor) -> torch.Tensor:
            shape = blocks.shape
            if blocks.shape[-1] == N:
                # base case: the map is already a single NxN patch
                processed = self.lambdaNxN_identity(blocks)
                return processed
            # to NxN patches
            blocks_patch, num_patch = to_patch(blocks)
            # pass through identity
            processed_patch = self.lambdaNxN_identity(blocks_patch)
            # back to HxW
            processed = combine_patch(processed_patch, shape, num_patch)
            # get feedback
            feedback = process(self.downscale_conv(processed))
            # upscale feedback
            upscale_feedback = self.upscale_conv(feedback)
            upscale_feedback = self.pixel_shuffle(upscale_feedback)
            # combine results
            combined = torch.cat([processed, upscale_feedback], dim=1)
            combined_shape = combined.shape
            combined_patch, num_patch = to_patch(combined)
            combined_patch_reduced = self.lambdaNxN_merge(combined_patch)
            ret_shape = (combined_shape[0], combined_shape[1] // 2, combined_shape[2], combined_shape[3])
            ret = combine_patch(combined_patch_reduced, ret_shape, num_patch)
            return ret

        return process(x)
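
# ---------------------------------------------------------------------------
# Minimal usage sketch. The sizes below are illustrative assumptions, not part
# of the module: N=4 with a 16x16 input is chosen so the recursion terminates
# (the spatial size must be a power of N and divisible by N at every level),
# and hidden_dim=64 matches the default.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    N, hidden_dim = 4, 64                         # assumed hyperparameters
    model = Recursion(N=N, hidden_dim=hidden_dim)
    x = torch.randn(2, hidden_dim, 16, 16)        # (batch, channels, H, W)
    with torch.no_grad():
        y = model(x)
    # the recursion preserves the input resolution and channel count
    print(y.shape)  # torch.Size([2, 64, 16, 16])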