# -*- coding: utf-8 -*- """ Created on Sun Jun 20 16:14:37 2021 @author: Administrator """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from torchvision import transforms import torch, math import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat import numbers from thop import profile import numpy as np import time from torchvision import transforms class OneRestore(nn.Module): def __init__(self, channel = 32): super(OneRestore,self).__init__() self.norm = lambda x: (x-0.5)/0.5 self.denorm = lambda x: (x+1)/2 self.in_conv = nn.Conv2d(3,channel,kernel_size=1,stride=1,padding=0,bias=False) self.encoder = encoder(channel) self.middle = backbone(channel) self.decoder = decoder(channel) self.out_conv = nn.Conv2d(channel,3,kernel_size=1,stride=1,padding=0,bias=False) def forward(self,x,embedding): x_in = self.in_conv(self.norm(x)) x_l, x_m, x_s, x_ss = self.encoder(x_in, embedding) x_mid = self.middle(x_ss, embedding) x_out = self.decoder(x_mid, x_ss, x_s, x_m, x_l, embedding) out = self.out_conv(x_out) + x return self.denorm(out) class encoder(nn.Module): def __init__(self,channel): super(encoder,self).__init__() self.el = ResidualBlock(channel)#16 self.em = ResidualBlock(channel*2)#32 self.es = ResidualBlock(channel*4)#64 self.ess = ResidualBlock(channel*8)#128 self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1) self.conv_eltem = nn.Conv2d(channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False)#16 32 self.conv_emtes = nn.Conv2d(2*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False)#32 64 self.conv_estess = nn.Conv2d(4*channel,8*channel,kernel_size=1,stride=1,padding=0,bias=False)#64 128 self.conv_esstesss = nn.Conv2d(8*channel,16*channel,kernel_size=1,stride=1,padding=0,bias=False)#128 256 def forward(self,x,embedding): elout = self.el(x, embedding)#16 x_emin = self.conv_eltem(self.maxpool(elout))#32 emout = self.em(x_emin, embedding) x_esin = self.conv_emtes(self.maxpool(emout)) esout = self.es(x_esin, embedding) x_esin = self.conv_estess(self.maxpool(esout)) essout = self.ess(x_esin, embedding)#128 return elout, emout, esout, essout#,esssout class backbone(nn.Module): def __init__(self,channel): super(backbone,self).__init__() self.s1 = ResidualBlock(channel*8)#128 self.s2 = ResidualBlock(channel*8)#128 def forward(self,x,embedding): share1 = self.s1(x, embedding) share2 = self.s2(share1, embedding) return share2 class decoder(nn.Module): def __init__(self,channel): super(decoder,self).__init__() self.dss = ResidualBlock(channel*8)#128 self.ds = ResidualBlock(channel*4)#64 self.dm = ResidualBlock(channel*2)#32 self.dl = ResidualBlock(channel)#16 #self.conv_dssstdss = nn.Conv2d(16*channel,8*channel,kernel_size=1,stride=1,padding=0,bias=False)#256 128 self.conv_dsstds = nn.Conv2d(8*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False)#128 64 self.conv_dstdm = nn.Conv2d(4*channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False)#64 32 self.conv_dmtdl = nn.Conv2d(2*channel,channel,kernel_size=1,stride=1,padding=0,bias=False)#32 16 def _upsample(self,x,y): _,_,H0,W0 = y.size() return F.interpolate(x,size=(H0,W0),mode='bilinear') def forward(self, x, x_ss, x_s, x_m, x_l, embedding): dssout = self.dss(x + x_ss, embedding) x_dsin = self.conv_dsstds(self._upsample(dssout, x_s)) dsout = self.ds(x_dsin + x_s, embedding) x_dmin = self.conv_dstdm(self._upsample(dsout, x_m)) dmout = self.dm(x_dmin + x_m, embedding) x_dlin = self.conv_dmtdl(self._upsample(dmout, x_l)) dlout = self.dl(x_dlin + x_l, embedding) return dlout class ResidualBlock(nn.Module): # Edge-oriented Residual Convolution Block 面向边缘的残差网络块 解决梯度消失的问题 def __init__(self, channel, norm=False): super(ResidualBlock, self).__init__() self.el = TransformerBlock(channel, num_heads=8, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias') def forward(self, x,embedding): return self.el(x,embedding) def to_3d(x): return rearrange(x, 'b c h w -> b (h w) c') def to_4d(x, h, w): return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) class BiasFree_LayerNorm(nn.Module): def __init__(self, normalized_shape): super(BiasFree_LayerNorm, self).__init__() if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape,) normalized_shape = torch.Size(normalized_shape) assert len(normalized_shape) == 1 self.weight = nn.Parameter(torch.ones(normalized_shape)) self.normalized_shape = normalized_shape def forward(self, x): sigma = x.var(-1, keepdim=True, unbiased=False) return x / torch.sqrt(sigma + 1e-5) * self.weight class WithBias_LayerNorm(nn.Module): def __init__(self, normalized_shape): super(WithBias_LayerNorm, self).__init__() if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape,) normalized_shape = torch.Size(normalized_shape) assert len(normalized_shape) == 1 self.weight = nn.Parameter(torch.ones(normalized_shape)) self.bias = nn.Parameter(torch.zeros(normalized_shape)) self.normalized_shape = normalized_shape def forward(self, x): mu = x.mean(-1, keepdim=True) sigma = x.var(-1, keepdim=True, unbiased=False) return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias class LayerNorm(nn.Module): def __init__(self, dim, LayerNorm_type): super(LayerNorm, self).__init__() if LayerNorm_type == 'BiasFree': self.body = BiasFree_LayerNorm(dim) else: self.body = WithBias_LayerNorm(dim) def forward(self, x): h, w = x.shape[-2:] return to_4d(self.body(to_3d(x)), h, w) class Cross_Attention(nn.Module): def __init__(self, dim, num_heads, bias, q_dim = 324): super(Cross_Attention, self).__init__() self.dim = dim self.num_heads = num_heads sqrt_q_dim = int(math.sqrt(q_dim)) self.resize = transforms.Resize([sqrt_q_dim, sqrt_q_dim]) self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) self.q = nn.Linear(q_dim, q_dim, bias=bias) self.kv = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias) self.kv_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias) self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) def forward(self, x, query): b,c,h,w = x.shape q = self.q(query) k, v = self.kv_dwconv(self.kv(x)).chunk(2, dim=1) k = self.resize(k) q = repeat(q, 'b l -> b head c l', head=self.num_heads, c=self.dim//self.num_heads) k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) q = torch.nn.functional.normalize(q, dim=-1) k = torch.nn.functional.normalize(k, dim=-1) attn = (q @ k.transpose(-2, -1)) * self.temperature attn = attn.softmax(dim=-1) out = (attn @ v) out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) out = self.project_out(out) return out class Self_Attention(nn.Module): def __init__(self, dim, num_heads, bias): super(Self_Attention, self).__init__() self.num_heads = num_heads self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) def forward(self, x): b,c,h,w = x.shape qkv = self.qkv_dwconv(self.qkv(x)) q,k,v = qkv.chunk(3, dim=1) q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) q = torch.nn.functional.normalize(q, dim=-1) k = torch.nn.functional.normalize(k, dim=-1) attn = (q @ k.transpose(-2, -1)) * self.temperature attn = attn.softmax(dim=-1) out = (attn @ v) out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) out = self.project_out(out) return out class FeedForward(nn.Module): def __init__(self, dim, ffn_expansion_factor, bias): super(FeedForward, self).__init__() hidden_features = int(dim * ffn_expansion_factor) self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, groups=hidden_features * 2, bias=bias) self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) def forward(self, x): x = self.project_in(x) x1, x2 = self.dwconv(x).chunk(2, dim=1) x = F.gelu(x1) * x2 x = self.project_out(x) return x class TransformerBlock(nn.Module): def __init__(self, dim, num_heads=8, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias'): super(TransformerBlock, self).__init__() self.norm1 = LayerNorm(dim, LayerNorm_type) self.cross_attn = Cross_Attention(dim, num_heads, bias) self.norm2 = LayerNorm(dim, LayerNorm_type) self.self_attn = Self_Attention(dim, num_heads, bias) self.norm3 = LayerNorm(dim, LayerNorm_type) self.ffn = FeedForward(dim, ffn_expansion_factor, bias) def forward(self, x, query): x = x + self.cross_attn(self.norm1(x),query) x = x + self.self_attn(self.norm2(x)) x = x + self.ffn(self.norm3(x)) return x if __name__ == '__main__': net = OneRestore().to("cuda" if torch.cuda.is_available() else "cpu") # x = torch.Tensor(np.random.random((2,3,256,256))).to("cuda" if torch.cuda.is_available() else "cpu") # query = torch.Tensor(np.random.random((2, 324))).to("cuda" if torch.cuda.is_available() else "cpu") # out = net(x, query) # print(out.shape) input = torch.randn(1, 3, 512, 512).to("cuda" if torch.cuda.is_available() else "cpu") query = torch.Tensor(np.random.random((1, 324))).to("cuda" if torch.cuda.is_available() else "cpu") macs, _ = profile(net, inputs=(input, query)) total = sum([param.nelement() for param in net.parameters()]) print('Macs = ' + str(macs/1000**3) + 'G') print('Params = ' + str(total/1e6) + 'M') from fvcore.nn import FlopCountAnalysis, parameter_count_table flops = FlopCountAnalysis(net, (input, query)) print("FLOPs", flops.total()/1000**3)