# frequency_enhance_006.py (various changes)
import math
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
import time

class ReshapeLayerNorm(nn.Module):
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super(ReshapeLayerNorm, self).__init__()

        self.dim = dim
        self.norm = norm_layer(dim)

    def forward(self, x):
        B, C, H, W = x.size()
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.norm(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H)
        return x

class ChannelSelfAttention(nn.Module):
    """Multi-head self-attention computed across channels (the attention map is C/L x C/L per head)."""
    def __init__(self, dim, num_head, attn_drop=0.0, proj_drop=0.0):
        super(ChannelSelfAttention, self).__init__()

        self.dim = dim
        self.num_head = num_head
        # learnable per-head temperature, clamped in forward to at most 1/0.01
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_head, 1, 1))), requires_grad=True)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, k, v, sp=None):
        B, C, H, W = q.size()

        q, k, v = map(lambda x: rearrange(x, 'b (l c) h w -> b l c (h w)', l=self.num_head), [q, k, v])  # [B, L, C/L, HW]

        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(2, -1)  # [B, L, C/L, C/L]
        logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp()
        attn = attn * logit_scale
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v  # [B, L, C/L, HW]

        # head merge
        x = rearrange(x, 'b l c (h w) -> b (l c) h w', h=H)  # [B, C, H, W]
        x = self.proj_drop(self.proj(x))  # [B, C, H, W]

        return x

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_ratio, act_layer=nn.GELU, bias=True, drop=0.0):
        super(FeedForward, self).__init__()

        self.dim = dim
        self.hidden_ratio = hidden_ratio

        self.hidden = nn.Conv2d(dim, int(dim * hidden_ratio), 1, bias=bias)
        self.drop1 = nn.Dropout(drop)
        self.out = nn.Conv2d(int(dim * hidden_ratio), dim, 1, bias=bias)
        self.drop2 = nn.Dropout(drop)
        self.act = act_layer()

    def forward(self, x):
        return self.drop2(self.out(self.drop1(self.act(self.hidden(x)))))

def dft(x, fftshift=False):
    fft = torch.fft.fft2(x, dim=(2, 3), norm='ortho')
    fft = torch.fft.fftshift(fft, dim=(2, 3)) if fftshift else fft
    amplitude = torch.abs(fft)
    phase = torch.angle(fft)
    return amplitude, phase


def idft(amplitude, phase):
    real = amplitude * torch.cos(phase)
    imag = amplitude * torch.sin(phase)
    out = torch.fft.ifft2(torch.complex(real, imag), dim=(2, 3), norm='ortho')
    out = torch.abs(out)
    return out
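
# Note: idft() returns torch.abs of the inverse transform, i.e. a non-negative real tensor,
# and it does not apply ifftshift, so a shift applied via dft(..., fftshift=True) is not
# explicitly undone before the inverse transform.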

class FrequencyEnhancementTransformer(nn.Module):
    """Refines a feature map in the frequency domain: channel self-attention over amplitudes
    (queries/keys from the contexts, values from the feature), convolutional fusion of the
    phases, then an inverse transform back to the spatial domain."""
    def __init__(self, c_dim, feat_dim, num_head, hidden_ratio, fftshift=False, *args, **kwargs):
        super(FrequencyEnhancementTransformer, self).__init__()

        self.c_dim = c_dim
        self.feat_dim = feat_dim
        self.num_head = num_head
        self.hidden_ratio = hidden_ratio
        self.fftshift = fftshift

        # depthwise + pointwise projection of the concatenated contexts and flow to 32 channels
        self.c_conv = nn.Sequential(nn.Conv2d(in_channels=c_dim*2+4, out_channels=c_dim*2+4, kernel_size=3, stride=1, padding=1, groups=c_dim*2+4),
                                    nn.Conv2d(in_channels=c_dim*2+4, out_channels=32, kernel_size=1, stride=1),
                                    nn.LeakyReLU())
        # depthwise + pointwise projection of the input feature to 32 channels
        self.feat_conv = nn.Sequential(nn.Conv2d(in_channels=feat_dim, out_channels=feat_dim, kernel_size=3, stride=1, padding=1, groups=feat_dim),
                                       nn.Conv2d(in_channels=feat_dim, out_channels=32, kernel_size=1, stride=1),
                                       nn.LeakyReLU())

        self.q_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1)
        self.k_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1)
        self.v_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1)

        self.attn = ChannelSelfAttention(32, num_head)
        self.norm1 = ReshapeLayerNorm(32)
        self.ffn = FeedForward(32, hidden_ratio)
        self.norm2 = ReshapeLayerNorm(32)

        self.phase_conv = nn.Sequential(nn.Conv2d(in_channels=32+32, out_channels=32, kernel_size=3, stride=1, padding=1))

        self.out_conv = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, groups=32),
                                      nn.Conv2d(in_channels=32, out_channels=feat_dim, kernel_size=1, stride=1),
                                      nn.LeakyReLU())

    def forward(self, c0, c1, feat, flow, *args, **kwargs):
        B, D, H, W = feat.size()

        c = self.c_conv(torch.cat([c0, c1, flow], dim=1))  # [B, 32, H, W]
        feat_ = self.feat_conv(feat)                       # [B, 32, H, W]

        amp_c, pha_c = dft(c, self.fftshift)      # [B, 32, H, W]
        amp_f, pha_f = dft(feat_, self.fftshift)  # [B, 32, H, W]

        # amplitude branch: queries/keys from the contexts, values from the feature
        amp_q = self.q_proj(amp_c)  # [B, 32, H, W]
        amp_k = self.k_proj(amp_c)  # [B, 32, H, W]
        amp_v = self.v_proj(amp_f)  # [B, 32, H, W]
        amp_attn = self.norm1(self.attn(amp_q, amp_k, amp_v))  # [B, 32, H, W]
        amp = self.norm2(self.ffn(amp_attn))                   # [B, 32, H, W]

        # phase branch: convolutional fusion of context and feature phases
        pha = self.phase_conv(torch.cat([pha_c, pha_f], dim=1))  # [B, 32, H, W]

        out = idft(amp, pha)      # [B, 32, H, W]
        out = self.out_conv(out)  # [B, D, H, W]

        return out
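
# Expected shapes for FrequencyEnhancementTransformer.forward, inferred from the layer
# definitions above (the 4 flow channels are assumed to hold a pair of 2-channel flows):
#   c0, c1: [B, c_dim, H, W]     context features of the two input frames
#   feat:   [B, feat_dim, H, W]  feature map to be enhanced
#   flow:   [B, 4, H, W]         optical flow(s)
#   output: [B, feat_dim, H, W]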

class FrequencyEnhancementDecoder(nn.Module):
    """Fuses encoder and warped feature pyramids (coarser levels upsampled to the finest
    resolution via pixel shuffle) together with the flow, then applies residual refinement
    in both the spatial and frequency domains."""
    def __init__(self, concat_dim, dim, fftshift, *args, **kwargs):
        super(FrequencyEnhancementDecoder, self).__init__()

        self.concat_dim = concat_dim
        self.dim = dim
        self.fftshift = fftshift
        self.act = nn.LeakyReLU()

        self.in_conv1 = nn.Sequential(nn.Conv2d(concat_dim, concat_dim, 3, 1, 1, groups=concat_dim),
                                      nn.Conv2d(concat_dim, dim, 1, 1),
                                      nn.LeakyReLU())
        self.in_conv2 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim),
                                      nn.Conv2d(dim, dim, 1, 1),
                                      nn.LeakyReLU())

        self.amp_conv = nn.Conv2d(dim, dim, 3, 1, 1)
        self.pha_conv = nn.Conv2d(dim, dim, 3, 1, 1)

        self.out_conv1 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim),
                                       nn.Conv2d(dim, dim, 1, 1),
                                       nn.LeakyReLU())
        self.out_conv2 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim),
                                       nn.Conv2d(dim, dim, 1, 1),
                                       nn.LeakyReLU())

    def forward(self, enc_feats, warped_feats, flow):
        _, _, H0, W0 = enc_feats[0].size()

        # bring the coarser pyramid levels up to the finest resolution via pixel shuffle
        for i, feat in enumerate(enc_feats[1:]):
            enc_feats[i+1] = F.pixel_shuffle(feat, H0 // feat.size(2))
        for i, feat in enumerate(warped_feats[2:]):
            warped_feats[i+2] = F.pixel_shuffle(feat, H0 // feat.size(2))

        x = torch.cat(enc_feats + warped_feats + [flow], dim=1)
        x = self.in_conv1(x)
        x = self.in_conv2(x) + x

        # residual refinement of amplitude and phase in the frequency domain
        amp, pha = dft(x, self.fftshift)
        amp = self.amp_conv(amp) + amp
        pha = self.pha_conv(pha) + pha
        out = idft(amp, pha) + x

        out = self.out_conv1(out) + out
        out = self.out_conv2(out) + out

        return out
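
# Minimal smoke-test sketch for the two modules above. All dimensions here (c_dim=16,
# feat_dim=64, num_head=4, hidden_ratio=2, a 32x32 spatial size, a 4-channel flow, and the
# toy two-/three-level pyramids) are illustrative assumptions, not values taken from the
# repository's model configuration.
if __name__ == "__main__":
    B, H, W = 2, 32, 32
    c_dim, feat_dim = 16, 64

    fet = FrequencyEnhancementTransformer(c_dim=c_dim, feat_dim=feat_dim, num_head=4, hidden_ratio=2)
    c0 = torch.randn(B, c_dim, H, W)
    c1 = torch.randn(B, c_dim, H, W)
    feat = torch.randn(B, feat_dim, H, W)
    flow = torch.randn(B, 4, H, W)
    out = fet(c0, c1, feat, flow)
    print('FrequencyEnhancementTransformer:', out.shape)  # expected: [B, feat_dim, H, W]

    # Decoder: one full-resolution encoder level plus one half-resolution level, and three
    # warped levels (the third at half resolution). pixel_shuffle brings the coarse levels
    # back to H x W, so the concatenation has 16*5 + 4 = 84 channels in this toy setup.
    enc_feats = [torch.randn(B, 16, H, W), torch.randn(B, 64, H // 2, W // 2)]
    warped_feats = [torch.randn(B, 16, H, W), torch.randn(B, 16, H, W), torch.randn(B, 64, H // 2, W // 2)]
    fed = FrequencyEnhancementDecoder(concat_dim=84, dim=64, fftshift=False)
    out = fed(enc_feats, warped_feats, flow)
    print('FrequencyEnhancementDecoder:', out.shape)  # expected: [B, 64, H, W]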