# frequency_enhance_006.py (various changes)
import math
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
import time


class ReshapeLayerNorm(nn.Module):
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super(ReshapeLayerNorm, self).__init__()
        self.dim = dim
        self.norm = norm_layer(dim)

    def forward(self, x):
        B, C, H, W = x.size()
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.norm(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H)
        return x


class ChannelSelfAttention(nn.Module):
    def __init__(self, dim, num_head, attn_drop=0.0, proj_drop=0.0):
        super(ChannelSelfAttention, self).__init__()
        self.dim = dim
        self.num_head = num_head
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_head, 1, 1))), requires_grad=True)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, k, v, sp=None):
        B, C, H, W = q.size()
        q, k, v = map(lambda x: rearrange(x, 'b (l c) h w -> b l c (h w)', l=self.num_head), [q, k, v])  # [B, L, C/L, HW]

        # channel-to-channel attention with a learned, clamped temperature
        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(2, -1)  # [B, L, C/L, C/L]
        logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp()
        attn = attn * logit_scale
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v  # [B, L, C/L, HW]

        # head merge
        x = rearrange(x, 'b l c (h w) -> b (l c) h w', h=H)  # [B, C, H, W]
        x = self.proj_drop(self.proj(x))  # [B, C, H, W]
        return x


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_ratio, act_layer=nn.GELU, bias=True, drop=0.0):
        super(FeedForward, self).__init__()
        self.dim = dim
        self.hidden_ratio = hidden_ratio

        self.hidden = nn.Conv2d(dim, int(dim * hidden_ratio), 1, bias=bias)
        self.drop1 = nn.Dropout(drop)
        self.out = nn.Conv2d(int(dim * hidden_ratio), dim, 1, bias=bias)
        self.drop2 = nn.Dropout(drop)
        self.act = act_layer()

    def forward(self, x):
        return self.drop2(self.out(self.drop1(self.act(self.hidden(x)))))


def dft(x, fftshift=False):
    # 2D FFT over the spatial dims, decomposed into amplitude and phase.
    fft = torch.fft.fft2(x, dim=(2, 3), norm='ortho')
    fft = torch.fft.fftshift(fft, dim=(2, 3)) if fftshift else fft
    amplitude = torch.abs(fft)
    phase = torch.angle(fft)
    return amplitude, phase


def idft(amplitude, phase):
    # Rebuild the complex spectrum from amplitude/phase and invert it.
    # Note: no ifftshift is applied here, and torch.abs keeps only the
    # magnitude of the (nearly real) inverse transform.
    real = amplitude * torch.cos(phase)
    imag = amplitude * torch.sin(phase)
    out = torch.fft.ifft2(torch.complex(real, imag), dim=(2, 3), norm='ortho')
    out = torch.abs(out)
    return out
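
# --- Sanity-check sketch (added for illustration, not part of the original model) ---
# A minimal, hedged example of how dft/idft above pair up: a non-negative random
# tensor should survive the amplitude/phase round trip when fftshift is disabled.
# The batch/channel/spatial sizes and the tolerance are arbitrary assumptions,
# and the check only runs if called explicitly.
def _dft_roundtrip_check(atol=1e-4):
    x = torch.rand(2, 32, 16, 16)   # assumed [B, C, H, W]; rand keeps values >= 0,
                                    # which matters because idft() ends with torch.abs
    amp, pha = dft(x, fftshift=False)
    x_rec = idft(amp, pha)
    return torch.allclose(x, x_rec, atol=atol)
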
class FrequencyEnhancementTransformer(nn.Module):
    def __init__(self, c_dim, feat_dim, num_head, hidden_ratio, fftshift=False, *args, **kwargs):
        super(FrequencyEnhancementTransformer, self).__init__()
        self.c_dim = c_dim
        self.feat_dim = feat_dim
        self.num_head = num_head
        self.hidden_ratio = hidden_ratio
        self.fftshift = fftshift

        self.c_conv = nn.Sequential(
            nn.Conv2d(in_channels=c_dim * 2 + 4, out_channels=c_dim * 2 + 4, kernel_size=3, stride=1, padding=1, groups=c_dim * 2 + 4),
            nn.Conv2d(in_channels=c_dim * 2 + 4, out_channels=32, kernel_size=1, stride=1),
            nn.LeakyReLU())
        self.feat_conv = nn.Sequential(
            nn.Conv2d(in_channels=feat_dim, out_channels=feat_dim, kernel_size=3, stride=1, padding=1, groups=feat_dim),
            nn.Conv2d(in_channels=feat_dim, out_channels=32, kernel_size=1, stride=1),
            nn.LeakyReLU())

        self.q_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1)
        self.k_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1)
        self.v_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1)

        self.attn = ChannelSelfAttention(32, num_head)
        self.norm1 = ReshapeLayerNorm(32)
        self.ffn = FeedForward(32, hidden_ratio)
        self.norm2 = ReshapeLayerNorm(32)

        self.phase_conv = nn.Sequential(
            nn.Conv2d(in_channels=32 + 32, out_channels=32, kernel_size=3, stride=1, padding=1))
        self.out_conv = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, groups=32),
            nn.Conv2d(in_channels=32, out_channels=feat_dim, kernel_size=1, stride=1),
            nn.LeakyReLU())

    def forward(self, c0, c1, feat, flow, *args, **kwargs):
        B, D, H, W = feat.size()

        c = self.c_conv(torch.cat([c0, c1, flow], dim=1))  # [B, 32, H, W]
        feat_ = self.feat_conv(feat)  # [B, 32, H, W]

        amp_c, pha_c = dft(c, self.fftshift)  # [B, 32, H, W]
        amp_f, pha_f = dft(feat_, self.fftshift)  # [B, 32, H, W]

        # amplitude branch: channel attention between context amplitude (q, k)
        # and feature amplitude (v)
        amp_q = self.q_proj(amp_c)  # [B, 32, H, W]
        amp_k = self.k_proj(amp_c)  # [B, 32, H, W]
        amp_v = self.v_proj(amp_f)  # [B, 32, H, W]

        amp_attn = self.norm1(self.attn(amp_q, amp_k, amp_v))  # [B, 32, H, W]
        amp = self.norm2(self.ffn(amp_attn))  # [B, 32, H, W]

        # phase branch: simple convolutional fusion of the two phase maps
        pha = self.phase_conv(torch.cat([pha_c, pha_f], dim=1))  # [B, 32, H, W]

        out = idft(amp, pha)  # [B, 32, H, W]
        out = self.out_conv(out)  # [B, D, H, W]
        return out


class FrequencyEnhancementDecoder(nn.Module):
    def __init__(self, concat_dim, dim, fftshift, *args, **kwargs):
        super(FrequencyEnhancementDecoder, self).__init__()
        self.concat_dim = concat_dim
        self.dim = dim
        self.fftshift = fftshift
        self.act = nn.LeakyReLU()

        self.in_conv1 = nn.Sequential(
            nn.Conv2d(concat_dim, concat_dim, 3, 1, 1, groups=concat_dim),
            nn.Conv2d(concat_dim, dim, 1, 1),
            nn.LeakyReLU())
        self.in_conv2 = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, groups=dim),
            nn.Conv2d(dim, dim, 1, 1),
            nn.LeakyReLU())

        self.amp_conv = nn.Conv2d(dim, dim, 3, 1, 1)
        self.pha_conv = nn.Conv2d(dim, dim, 3, 1, 1)

        self.out_conv1 = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, groups=dim),
            nn.Conv2d(dim, dim, 1, 1),
            nn.LeakyReLU())
        self.out_conv2 = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, groups=dim),
            nn.Conv2d(dim, dim, 1, 1),
            nn.LeakyReLU())

    def forward(self, enc_feats, warped_feats, flow):
        _, _, H0, W0 = enc_feats[0].size()

        # Bring the coarser pyramid levels up to the finest resolution via pixel shuffle
        # (trades channels for spatial resolution by the squared upscale factor).
        for i, feat in enumerate(enc_feats[1:]):
            enc_feats[i + 1] = F.pixel_shuffle(feat, H0 // feat.size(2))
        for i, feat in enumerate(warped_feats[2:]):
            warped_feats[i + 2] = F.pixel_shuffle(feat, H0 // feat.size(2))

        x = torch.cat(enc_feats + warped_feats + [flow], dim=1)
        x = self.in_conv1(x)
        x = self.in_conv2(x) + x

        amp, pha = dft(x, self.fftshift)
        amp = self.amp_conv(amp) + amp
        pha = self.pha_conv(pha) + pha

        out = idft(amp, pha) + x
        out = self.out_conv1(out) + out
        out = self.out_conv2(out) + out
        return out
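

# --- Usage sketch (added for illustration; every shape below is an assumption) ---
# The channel sizes, pyramid layout, and the 4-channel flow tensor are guesses that
# merely satisfy the constructors above; they are not taken from the original
# training configuration.
if __name__ == '__main__':
    B, H, W = 1, 64, 64
    c_dim, feat_dim = 16, 64

    fet = FrequencyEnhancementTransformer(c_dim=c_dim, feat_dim=feat_dim, num_head=4, hidden_ratio=2.0)
    c0 = torch.rand(B, c_dim, H, W)
    c1 = torch.rand(B, c_dim, H, W)
    feat = torch.rand(B, feat_dim, H, W)
    flow = torch.rand(B, 4, H, W)                    # c_conv expects c_dim*2 + 4 input channels
    print(fet(c0, c1, feat, flow).shape)             # -> [B, feat_dim, H, W]

    # Decoder: enc_feats[1:] and warped_feats[2:] are coarser levels that get
    # pixel-shuffled to the finest resolution, so their channel counts must be
    # divisible by the squared upscale factor. concat_dim is the channel total
    # after that upsampling: 32 + 64/4 + 32 + 32 + 64/4 + 4 = 132.
    enc_feats = [torch.rand(B, 32, H, W), torch.rand(B, 64, H // 2, W // 2)]
    warped_feats = [torch.rand(B, 32, H, W), torch.rand(B, 32, H, W), torch.rand(B, 64, H // 2, W // 2)]
    fed = FrequencyEnhancementDecoder(concat_dim=132, dim=64, fftshift=False)
    print(fed(enc_feats, warped_feats, flow).shape)  # -> [B, 64, H, W]

    print(_dft_roundtrip_check())                    # expected: True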