|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.fft as fft |
|
import cv2 |
|
import numpy as np |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
|
|
|
|
def lowpass(input, limit):
    """Low-pass filter the last two dimensions of ``input`` in the Fourier domain.

    The filter is a separable (rectangular) mask: a frequency bin is kept only
    when the absolute frequency along *both* of the last two axes is strictly
    below ``limit``. Frequencies are expressed in cycles per sample, as
    returned by ``torch.fft.fftfreq`` with its default sample spacing.

    Args:
        input: Real-valued tensor with at least two dimensions; the transform
            is taken over the last two (``torch.fft.rfft2`` defaults).
        limit: Cut-off frequency. Bins whose absolute frequency equals
            ``limit`` are removed because the comparison is strict.

    Returns:
        Tensor with the same trailing two sizes as ``input`` containing the
        filtered signal.
    """
    # Build the masks on the input's device; otherwise they default to the CPU
    # and multiplying them with a CUDA spectrum raises a device mismatch.
    device = input.device

    # rfft2 stores only the non-negative half of the last axis, hence rfftfreq
    # for the columns and the full fftfreq for the rows.
    keep_cols = torch.abs(fft.rfftfreq(input.shape[-1], device=device)) < limit
    keep_rows = torch.abs(fft.fftfreq(input.shape[-2], device=device)) < limit

    # Outer product of the two boolean masks -> 2-D mask matching rfft2's layout.
    kernel = torch.outer(keep_rows, keep_cols)

    spectrum = fft.rfft2(input)
    # Pass the original sizes explicitly so an odd last dimension is restored.
    return fft.irfft2(spectrum * kernel, s=input.shape[-2:])
|
|
|
class HighFrequencyLoss(nn.Module):
    """Loss equal to the mean magnitude of the 2-D Fourier spectrum.

    NOTE: despite the name, no high-pass mask is applied. Every frequency bin,
    including the DC component, contributes to the mean.
    """

    def __init__(self, size=(224,224)):
        """Create the loss module.

        Args:
            size: ``(height, width)`` pair. Currently unused; kept so existing
                callers that pass it continue to work.
        """
        super().__init__()

    def forward(self, x):
        """Return the mean absolute value of the FFT of ``x`` over dims 2 and 3.

        Args:
            x: Tensor with at least four dimensions, since the transform is
                taken over dims 2 and 3 (presumably ``(N, C, H, W)`` — confirm
                against callers).

        Returns:
            Zero-dimensional tensor holding the mean spectral magnitude.
        """
        f = fft.fftn(x, dim=(2,3))
        loss = torch.abs(f).mean()
        return loss
|
|
|
if __name__ == '__main__':
    # Demo: cut '../tmp.jpg' into vertical strips up to 224 px wide (at most
    # 10), score each strip with HighFrequencyLoss, draw the score onto the
    # strip, and write the strips side by side to 'tmp.jpg'.
    TILE_WIDTH = 224
    MAX_TILES = 10

    HF = HighFrequencyLoss()
    transform = transforms.Compose([transforms.ToTensor()])

    img = cv2.imread('../tmp.jpg')
    if img is None:
        # cv2.imread reports an unreadable or missing file by returning None.
        raise FileNotFoundError("could not read image '../tmp.jpg'")
    width = img.shape[1]

    imgs = []
    for i in range(MAX_TILES):
        left = TILE_WIDTH * i
        if left >= width:
            # Past the right edge: a further slice would be empty.
            break
        img_ = img[:, left:left + TILE_WIDTH, :]
        print(img_.shape)
        # Reverse the channel axis before handing the strip to PIL, then add a
        # leading batch dimension so forward() can transform dims 2 and 3.
        img_tensor = transform(Image.fromarray(img_[:, :, ::-1])).unsqueeze(0)
        loss = HF(img_tensor).item()
        cv2.putText(img_, str(loss)[:6], (5, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)
        imgs.append(img_)

    cv2.imwrite('tmp.jpg', cv2.hconcat(imgs))
|
|
|
|
|
|
|
|
|
|