import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class ReshapeLayerNorm(nn.Module):
    """LayerNorm over the channel dimension of a (B, C, H, W) tensor."""

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super(ReshapeLayerNorm, self).__init__()

        self.dim = dim
        self.norm = norm_layer(dim)

    def forward(self, x):
        B, C, H, W = x.size()
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.norm(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H)
        return x

class ChannelSelfAttention(nn.Module):
    """Multi-head self-attention across channels (tokens are channels, not pixels),
    using cosine similarity with a learned, clamped per-head temperature."""

    def __init__(self, dim, num_head, attn_drop=0.0, proj_drop=0.0):
        super(ChannelSelfAttention, self).__init__()
        self.dim = dim
        self.num_head = num_head

        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_head, 1, 1))), requires_grad=True)

        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Conv2d(dim, dim, 1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, k, v, sp=None):
        B, C, H, W = q.size()

        # (B, C, H, W) -> (B, heads, C/heads, H*W): each channel becomes a token.
        q, k, v = map(lambda x: rearrange(x, 'b (l c) h w -> b l c (h w)', l=self.num_head), [q, k, v])

        # Cosine-similarity attention with a learned, clamped temperature.
        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(2, -1)
        logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp()
        attn = attn * logit_scale

        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v

        x = rearrange(x, 'b l c (h w) -> b (l c) h w', h=H)
        x = self.proj_drop(self.proj(x))

        return x

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_ratio, act_layer=nn.GELU, bias=True, drop=0.0):
        super(FeedForward, self).__init__()

        self.dim = dim
        self.hidden_ratio = hidden_ratio

        self.hidden = nn.Conv2d(dim, int(dim * hidden_ratio), 1, bias=bias)
        self.drop1 = nn.Dropout(drop)
        self.out = nn.Conv2d(int(dim * hidden_ratio), dim, 1, bias=bias)
        self.drop2 = nn.Dropout(drop)
        self.act = act_layer()

    def forward(self, x):
        return self.drop2(self.out(self.drop1(self.act(self.hidden(x)))))

def dft(x, fftshift=False):
    """2D FFT over the spatial dimensions; returns the amplitude and phase spectra."""
    fft = torch.fft.fft2(x, dim=(2, 3), norm='ortho')
    fft = torch.fft.fftshift(fft, dim=(2, 3)) if fftshift else fft
    amplitude = torch.abs(fft)
    phase = torch.angle(fft)
    return amplitude, phase


def idft(amplitude, phase):
    """Inverse 2D FFT from amplitude and phase spectra; returns the magnitude of the result."""
    real = amplitude * torch.cos(phase)
    imag = amplitude * torch.sin(phase)
    out = torch.fft.ifft2(torch.complex(real, imag), dim=(2, 3), norm='ortho')
    out = torch.abs(out)
    return out
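
# Usage note (a minimal sketch, not part of the original code): with the default
# fftshift=False, dft/idft round-trip a non-negative tensor, since idft keeps only
# the magnitude of the inverse transform:
#
#   x = torch.rand(2, 8, 16, 16)   # non-negative values
#   amp, pha = dft(x)
#   assert torch.allclose(idft(amp, pha), x, atol=1e-5)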


class FrequencyEnhancementTransformer(nn.Module):
    """Refines a feature map in the frequency domain: channel self-attention on the
    amplitude spectrum, convolutional fusion of the phase spectra, then an inverse DFT."""

    def __init__(self, c_dim, feat_dim, num_head, hidden_ratio, fftshift=False, *args, **kwargs):
        super(FrequencyEnhancementTransformer, self).__init__()
        self.c_dim = c_dim
        self.feat_dim = feat_dim
        self.num_head = num_head
        self.hidden_ratio = hidden_ratio
        self.fftshift = fftshift

        # Depthwise + pointwise projections of the context (c0, c1, flow) and the feature map.
        self.c_conv = nn.Sequential(nn.Conv2d(in_channels=c_dim*2+4, out_channels=c_dim*2+4, kernel_size=3, stride=1, padding=1, groups=c_dim*2+4),
                                    nn.Conv2d(in_channels=c_dim*2+4, out_channels=32, kernel_size=1, stride=1),
                                    nn.LeakyReLU())
        self.feat_conv = nn.Sequential(nn.Conv2d(in_channels=feat_dim, out_channels=feat_dim, kernel_size=3, stride=1, padding=1, groups=feat_dim),
                                       nn.Conv2d(in_channels=feat_dim, out_channels=32, kernel_size=1, stride=1),
                                       nn.LeakyReLU())

        self.q_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1)
        self.k_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1)
        self.v_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1)
        self.attn = ChannelSelfAttention(32, num_head)
        self.norm1 = ReshapeLayerNorm(32)

        self.ffn = FeedForward(32, hidden_ratio)
        self.norm2 = ReshapeLayerNorm(32)

        self.phase_conv = nn.Sequential(nn.Conv2d(in_channels=32+32, out_channels=32, kernel_size=3, stride=1, padding=1))

        self.out_conv = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, groups=32),
                                      nn.Conv2d(in_channels=32, out_channels=feat_dim, kernel_size=1, stride=1),
                                      nn.LeakyReLU())

    def forward(self, c0, c1, feat, flow, *args, **kwargs):
        B, D, H, W = feat.size()

        c = self.c_conv(torch.cat([c0, c1, flow], dim=1))
        feat_ = self.feat_conv(feat)

        # Decompose both branches into amplitude and phase spectra.
        amp_c, pha_c = dft(c, self.fftshift)
        amp_f, pha_f = dft(feat_, self.fftshift)

        # Amplitude branch: queries/keys from the context amplitude, values from the feature amplitude.
        amp_q = self.q_proj(amp_c)
        amp_k = self.k_proj(amp_c)
        amp_v = self.v_proj(amp_f)
        amp_attn = self.norm1(self.attn(amp_q, amp_k, amp_v))
        amp = self.norm2(self.ffn(amp_attn))

        # Phase branch: fuse the two phase spectra with a single convolution.
        pha = self.phase_conv(torch.cat([pha_c, pha_f], dim=1))

        # Back to the spatial domain, then project to the feature dimension.
        out = idft(amp, pha)
        out = self.out_conv(out)

        return out

class FrequencyEnhancementDecoder(nn.Module):
    """Fuses encoder features, warped features, and flow, then refines the result with
    residual amplitude/phase convolutions in the frequency domain."""

    def __init__(self, concat_dim, dim, fftshift, *args, **kwargs):
        super(FrequencyEnhancementDecoder, self).__init__()
        self.concat_dim = concat_dim
        self.dim = dim
        self.fftshift = fftshift

        self.act = nn.LeakyReLU()

        self.in_conv1 = nn.Sequential(nn.Conv2d(concat_dim, concat_dim, 3, 1, 1, groups=concat_dim),
                                      nn.Conv2d(concat_dim, dim, 1, 1),
                                      nn.LeakyReLU())
        self.in_conv2 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim),
                                      nn.Conv2d(dim, dim, 1, 1),
                                      nn.LeakyReLU())

        self.amp_conv = nn.Conv2d(dim, dim, 3, 1, 1)
        self.pha_conv = nn.Conv2d(dim, dim, 3, 1, 1)

        self.out_conv1 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim),
                                       nn.Conv2d(dim, dim, 1, 1),
                                       nn.LeakyReLU())
        self.out_conv2 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim),
                                       nn.Conv2d(dim, dim, 1, 1),
                                       nn.LeakyReLU())

    def forward(self, enc_feats, warped_feats, flow):
        _, _, H0, W0 = enc_feats[0].size()
        # Bring lower-resolution pyramid levels up to the resolution of enc_feats[0]
        # via pixel shuffle (spatial upscale by H0 // h, channels reduced by its square).
        for i, feat in enumerate(enc_feats[1:]):
            enc_feats[i+1] = F.pixel_shuffle(feat, H0 // feat.size(2))
        for i, feat in enumerate(warped_feats[2:]):
            warped_feats[i+2] = F.pixel_shuffle(feat, H0 // feat.size(2))

        x = torch.cat(enc_feats + warped_feats + [flow], dim=1)
        x = self.in_conv1(x)
        x = self.in_conv2(x) + x

        # Residual refinement of the amplitude and phase spectra.
        amp, pha = dft(x, self.fftshift)
        amp = self.amp_conv(amp) + amp
        pha = self.pha_conv(pha) + pha

        out = idft(amp, pha) + x

        out = self.out_conv1(out) + out
        out = self.out_conv2(out) + out

        return out
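

if __name__ == "__main__":
    # Minimal smoke test (not part of the original code). All shapes and channel counts
    # below are illustrative assumptions, not values taken from any particular config.
    B, c_dim, feat_dim, H, W = 1, 24, 48, 32, 32

    c0 = torch.randn(B, c_dim, H, W)
    c1 = torch.randn(B, c_dim, H, W)
    flow = torch.randn(B, 4, H, W)        # assumed bidirectional flow: 2 channels per direction
    feat = torch.randn(B, feat_dim, H, W)

    fet = FrequencyEnhancementTransformer(c_dim=c_dim, feat_dim=feat_dim, num_head=4, hidden_ratio=2.0)
    print('FET output:', fet(c0, c1, feat, flow).shape)   # expected: (B, feat_dim, H, W)

    # Hypothetical feature pyramids; lower levels are pixel-shuffled up to HxW inside the
    # decoder, so concat_dim = 16 + 32/4 + 16 + 16 + 32/4 + 4 = 68 for these shapes.
    enc_feats = [torch.randn(B, 16, H, W), torch.randn(B, 32, H // 2, W // 2)]
    warped_feats = [torch.randn(B, 16, H, W), torch.randn(B, 16, H, W), torch.randn(B, 32, H // 2, W // 2)]
    fed = FrequencyEnhancementDecoder(concat_dim=68, dim=32, fftshift=False)
    print('FED output:', fed(enc_feats, warped_feats, flow).shape)   # expected: (B, 32, H, W)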