|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy.random as random |
|
import matplotlib.pyplot as plt |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.autograd import Variable |
|
from math import exp |
|
|
|
class SiLogLoss(nn.Module):
    """Scale-invariant logarithmic loss.

    Computes ``sqrt(mean(d**2) - lambd * mean(d)**2 + eps)`` where
    ``d = log(target + eps) - log(pred + eps)``.

    Args:
        lambd: Weight of the squared-mean term (0 gives plain log-RMSE).
        eps: Added inside the logs and under the square root for stability.
    """

    def __init__(self, lambd=0.5, eps=1e-6):
        super().__init__()
        self.lambd = lambd
        self.eps = eps

    def forward(self, pred, target):
        """Return the scalar loss between ``pred`` and ``target`` (cast to float)."""
        log_gap = torch.log(target.float() + self.eps) - torch.log(pred.float() + self.eps)
        second_moment = log_gap.pow(2).mean()
        squared_mean = log_gap.mean().pow(2)
        return torch.sqrt(second_moment - self.lambd * squared_mean + self.eps)
|
|
|
class IntegrityPriorLoss(nn.Module):
    """Integrity prior combining a depth-variance term and a depth-gradient term.

    Both terms are weighted by ``-log(py)``, where ``py`` is the probability
    the predicted ``mask`` assigns to the ground-truth label ``gt``.

    Args:
        epsilon: Stabiliser added inside the log and used as a lower bound on
            the foreground pixel count when averaging the depth.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()

        self.epsilon = epsilon
        # NOTE(review): these two thresholds are not read anywhere in this
        # class; kept so external code that inspects them keeps working.
        self.max_variance = 0.05
        self.max_grad = 0.05

        # Fixed (non-trainable) single-channel Sobel filters.
        self.sobel_x = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.sobel_y = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)

        sobel_kernel_x = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32)
        sobel_kernel_y = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32)

        self.sobel_x.weight.data = sobel_kernel_x
        self.sobel_y.weight.data = sobel_kernel_y

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, mask, depth_map, gt):
        """Return the scalar integrity loss.

        Args:
            mask: Predicted foreground probability map.
            depth_map: Depth map; must have exactly one channel because the
                Sobel convolutions are ``Conv2d(1, 1, ...)``.
            gt: Ground-truth foreground map (presumably binary — TODO confirm).
        """
        # Probability of the correct label at every pixel.
        py = gt * mask + (1 - gt) * (1 - mask)
        false_pos = (1 - py) * mask
        false_neg = (1 - py) * gt
        log_p = -torch.log(py + self.epsilon)

        # Mean depth over the foreground. Clamping the count avoids 0/0 -> NaN
        # when gt contains no foreground; the result is unchanged otherwise.
        fg_count = gt.sum().clamp_min(self.epsilon)
        fg_depth = (depth_map * gt).sum() / fg_count
        diff = (depth_map - fg_depth) ** 2

        fp_diff = diff * false_pos
        fn_diff = (1 - diff) * false_neg
        var_weight = (fp_diff + fn_diff) * py
        variance_loss = torch.mean(log_p * var_weight)

        grad_x = torch.abs(self.sobel_x(depth_map))
        grad_y = torch.abs(self.sobel_y(depth_map))
        grad_loss = torch.mean(grad_x * log_p + grad_y * log_p)

        return variance_loss + grad_loss
|
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian of length ``window_size`` normalised to sum to 1.

    The peak sits at index ``window_size // 2`` with standard deviation ``sigma``.
    """
    centre = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = [exp(-(i - centre) ** 2 / denom) for i in range(window_size)]
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()
|
|
|
def create_window(window_size, channel):
    """Build a depthwise 2-D Gaussian kernel (sigma 1.5) for SSIM.

    Returns:
        A float tensor of shape ``(channel, 1, window_size, window_size)``,
        suitable for ``F.conv2d(..., groups=channel)``.
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    # Outer product of the 1-D kernel with itself gives the separable 2-D kernel.
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # `torch.autograd.Variable` is a deprecated no-op wrapper; return the tensor directly.
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()
|
|
|
def _ssim(img1, img2, window, window_size, channel, size_average=True): |
|
mu1 = F.conv2d(img1, window, padding = window_size//2, groups=channel) |
|
mu2 = F.conv2d(img2, window, padding = window_size//2, groups=channel) |
|
|
|
mu1_sq = mu1.pow(2) |
|
mu2_sq = mu2.pow(2) |
|
mu1_mu2 = mu1*mu2 |
|
|
|
sigma1_sq = F.conv2d(img1*img1, window, padding=window_size//2, groups=channel) - mu1_sq |
|
sigma2_sq = F.conv2d(img2*img2, window, padding=window_size//2, groups=channel) - mu2_sq |
|
sigma12 = F.conv2d(img1*img2, window, padding=window_size//2, groups=channel) - mu1_mu2 |
|
|
|
C1 = 0.01**2 |
|
C2 = 0.03**2 |
|
|
|
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) |
|
|
|
if size_average: |
|
return ssim_map.mean() |
|
else: |
|
return ssim_map.mean(1).mean(1).mean(1) |
|
|
|
class SSIMLoss(torch.nn.Module):
    """SSIM-based loss: returns ``1 - (1 + SSIM) / 2``, i.e. ``(1 - SSIM) / 2``.

    The Gaussian window is cached and rebuilt whenever the input's channel
    count, dtype or device changes.

    Args:
        window_size: Side length of the Gaussian window.
        size_average: Forwarded to ``_ssim``; a scalar when true, per-sample otherwise.
    """

    def __init__(self, window_size=11, size_average=True):
        super(SSIMLoss, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()
        cached = self.window
        # Compare dtype AND exact device: the old `.type()` string check ignored
        # the device index, so a window on cuda:0 was reused for cuda:1 inputs.
        reusable = (
            channel == self.channel
            and cached.dtype == img1.dtype
            and cached.device == img1.device
        )
        if reusable:
            window = cached
        else:
            window = create_window(self.window_size, channel).to(device=img1.device, dtype=img1.dtype)
            self.window = window
            self.channel = channel
        return 1 - (1 + _ssim(img1, img2, window, self.window_size, channel, self.size_average)) / 2
|
|
|
def circular_highPassFiltering(img, ratio, eps=1e-12):
    """Apply a soft circular high-pass filter in the centred frequency domain.

    A Gaussian centred on the zero frequency is min-max normalised per sample.
    It is then turned into a soft mask ``sigmoid((0.5 - g) * 100)``, which is
    about 0 near the centre and about 1 far from it. The mask is multiplied
    with the shifted FFT, and the magnitude of the inverse FFT is min-max
    normalised per sample.

    Args:
        img: Batch shaped ``(B, C, H, W)``.
        ratio: Per-sample Gaussian width factor; sigma = ``H * ratio / 4``.
            NOTE(review): ``ratio[..., None, None]`` must broadcast against
            ``(B, 1, H, W)``, which suggests ``ratio`` is ``(B, 1)`` — confirm.
        eps: Lower bound on the min-max ranges. It prevents 0/0 -> NaN when a
            map is constant, and has no effect on ordinary inputs.

    Returns:
        ``(mask, filtered)``: the frequency mask and the normalised filtered image.
    """
    device = img.device
    batch_size, _, height, width = img.shape
    sigma = (height * (ratio[..., None, None])) / 4
    center_h = height // 2
    center_w = width // 2

    # Build the integer frequency grid directly on the target device; use
    # explicit 'ij' indexing (the legacy default) to avoid the deprecation warning.
    grid_y, grid_x = torch.meshgrid(
        torch.arange(-center_h, height - center_h, device=device),
        torch.arange(-center_w, width - center_w, device=device),
        indexing='ij',
    )
    grid_y = grid_y[None, None, ...].repeat(batch_size, 1, 1, 1)
    grid_x = grid_x[None, None, ...].repeat(batch_size, 1, 1, 1)

    gaussian_values = (1 / (2 * torch.pi * sigma ** 2)) * torch.exp(-(grid_x ** 2 + grid_y ** 2) / (2 * sigma ** 2))
    gmin = gaussian_values.flatten(-2).min(dim=-1)[0][..., None, None]
    gmax = gaussian_values.flatten(-2).max(dim=-1)[0][..., None, None]
    decreasing_matrix = (gaussian_values - gmin) / (gmax - gmin).clamp_min(eps)
    mask = ((0.5 - decreasing_matrix) * 100).sigmoid()

    fft_shift = torch.fft.fftshift(torch.fft.fft2(img), dim=(2, 3))
    fft_shift = torch.mul(fft_shift, mask)
    ifimg = torch.fft.ifft2(torch.fft.ifftshift(fft_shift, dim=(2, 3)))
    ifimg = torch.abs(ifimg)

    ifmin = ifimg.flatten(-2).min(dim=-1)[0][..., None, None]
    ifmax = ifimg.flatten(-2).max(dim=-1)[0][..., None, None]
    ifimg = (ifimg - ifmin) / (ifmax - ifmin).clamp_min(eps)
    return mask, ifimg
|
|
|
def _upsample_like(src,tar,mode='bilinear'): |
|
if mode == 'bilinear': |
|
src = F.upsample(src,size=tar.shape[2:],mode=mode,align_corners=True) |
|
else: |
|
src = F.upsample(src,size=tar.shape[2:],mode=mode) |
|
return src |
|
|
|
def _upsample_(src,size,mode='bilinear'): |
|
if mode == 'bilinear': |
|
src = F.upsample(src,size=size,mode=mode,align_corners=True) |
|
else: |
|
src = F.upsample(src,size=size,mode=mode) |
|
return src |
|
|
|
def patchfy(x, p=4, c=4):
    """Split ``(N, c, H, W)`` into non-overlapping ``p x p`` patches.

    Returns:
        ``(N, (H // p) * (W // p), p * p * c)``. Patches are ordered row-major,
        and each patch is flattened in ``(row, col, channel)`` order.

    H and W may differ (previously H was assumed to equal W).
    """
    h = x.shape[2] // p
    w = x.shape[3] // p
    x = x.reshape(x.shape[0], c, h, p, w, p)
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(x.shape[0], h * w, p ** 2 * c)
|
|
|
def unpatchfy(x, p=4, c=4, hw=None):
    """Inverse of ``patchfy``: ``(N, h*w, p*p*c)`` -> ``(N, c, h*p, w*p)``.

    Args:
        x: Patch tokens.
        p: Patch side length.
        c: Number of channels.
        hw: Optional ``(h, w)`` patch-grid size. When omitted, the grid is
            assumed square with ``h = w = round(sqrt(num_patches))``, as before.
    """
    if hw is None:
        h = w = round(x.shape[1] ** 0.5)
    else:
        h, w = hw
    x = x.reshape(x.shape[0], h, w, p, p, c)
    x = torch.einsum('nhwpqc->nchpwq', x)
    # Use w (not h) for the width so non-square grids are reconstructed correctly.
    return x.reshape(x.shape[0], c, h * p, w * p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def structure_loss(pred, mask):
    """Boundary-weighted BCE plus weighted IoU loss.

    Pixels whose mask differs from its 31x31 local average (i.e. near edges)
    get weight up to 6. ``pred`` is treated as logits: the BCE uses the
    with-logits form and the IoU term applies ``sigmoid``.

    Returns:
        A scalar: the mean over batch/channel of weighted BCE + weighted IoU.
    """
    spatial = (2, 3)
    local_mean = F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)
    weights = 1 + 5 * torch.abs(local_mean - mask)

    bce_map = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    weighted_bce = (weights * bce_map).sum(dim=spatial) / weights.sum(dim=spatial)

    prob = torch.sigmoid(pred)
    overlap = (prob * mask * weights).sum(dim=spatial)
    combined = ((prob + mask) * weights).sum(dim=spatial)
    weighted_iou = 1 - (overlap + 1) / (combined - overlap + 1)

    return (weighted_bce + weighted_iou).mean()
|
|
|
def iou_loss(pred, mask):
    """Soft IoU loss averaged over batch and channel.

    Intersection and union are summed over dims 2 and 3, and a 1e-6 smoothing
    term is added to both.
    """
    smooth = 1e-6
    overlap = (pred * mask).sum(dim=(2, 3))
    total = (pred + mask).sum(dim=(2, 3))
    per_map = 1 - (overlap + smooth) / (total - overlap + smooth)
    return per_map.mean()
|
|
|
def dice_loss(pred, mask):
    """Soft Dice loss: ``1 - mean_n(dice_n)``, with each sample flattened.

    ``reshape`` is used instead of ``view`` so that non-contiguous inputs
    (e.g. after a transpose or permute) work too.
    """
    eps = 1e-6
    N = pred.size()[0]
    pred_flat = pred.reshape(N, -1)
    mask_flat = mask.reshape(N, -1)

    intersection = (pred_flat * mask_flat).sum(1)
    dice_coefficient = (2. * intersection + eps) / (pred_flat.sum(1) + mask_flat.sum(1) + eps)
    return 1 - dice_coefficient.sum() / N
|
|
|
class LargeK(nn.Module):
    """Large-receptive-field block built from three dilated depthwise convs.

    A 1x1 conv expands ``dim`` to ``3 * dim`` channels. The result is split
    into three groups, each group goes through a 7x7 depthwise conv with
    dilation 2, 4 or 8 (padding keeps the spatial size), and a 1x1 conv mixes
    the concatenation back down to ``dim``.

    Args:
        dim (int): Number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.channel_split = nn.Conv2d(dim, dim * 3, kernel_size=1)
        self.dwconv1 = self._depthwise(dim, 2)
        self.dwconv2 = self._depthwise(dim, 4)
        self.dwconv3 = self._depthwise(dim, 8)
        self.channel_mix = nn.Conv2d(dim * 3, dim, kernel_size=1)

    @staticmethod
    def _depthwise(dim, dilation):
        """7x7 depthwise conv; padding 3 * dilation preserves H and W."""
        return nn.Conv2d(dim, dim, kernel_size=7, dilation=dilation, padding=3 * dilation, groups=dim)

    def forward(self, x):
        branches = torch.chunk(self.channel_split(x), 3, dim=1)
        convs = (self.dwconv1, self.dwconv2, self.dwconv3)
        merged = torch.cat([conv(part) for conv, part in zip(convs, branches)], dim=1)
        return self.channel_mix(merged)
|
|
|
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str) -- the type of GAN objective: 'vanilla', 'lsgan' or 'wgangp'.
            target_real_label (float) -- label value for real images
            target_fake_label (float) -- label value for fake images

        Raises:
            NotImplementedError: if ``gan_mode`` is not one of the supported modes.

        Note: Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create a label tensor with the same size as the input.

        Parameters:
            prediction (tensor) -- typically the prediction from a discriminator
            target_is_real (bool) -- if the ground truth label is for real images or fake images

        Returns:
            The real or fake label buffer expanded (without copying) to ``prediction``'s shape.
        """
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """Calculate the loss given the Discriminator's output and ground-truth labels.

        Implemented as ``forward`` (not a ``__call__`` override) so
        ``nn.Module`` hooks run; ``criterion(pred, True)`` still works.

        Parameters:
            prediction (tensor) -- typically the prediction output from a discriminator
            target_is_real (bool) -- if the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            return self.loss(prediction, target_tensor)
        # 'wgangp': raise the critic score for real samples, lower it for fake ones.
        if target_is_real:
            return -prediction.mean()
        return prediction.mean()
|
|
|
def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):
    """Build interleaved sin/cos position embeddings.

    Frequency ``i`` is ``10000 ** (-2 i / output_dim)`` for
    ``i < output_dim // 2``. In the last dimension, even slots hold ``sin``
    and odd slots hold ``cos``.

    Returns:
        Tensor of shape ``(batch_size, nums_head, max_len, output_dim)`` on
        ``device``. It is computed on the CPU and then moved.
    """
    positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
    freq_idx = torch.arange(0, output_dim // 2, dtype=torch.float)
    inv_freq = torch.pow(10000, -2 * freq_idx / output_dim)
    angles = positions * inv_freq
    # (max_len, output_dim // 2, 2): pairs of (sin, cos) for every frequency.
    table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
    table = table.expand(batch_size, nums_head, *table.shape)
    table = table.reshape(batch_size, nums_head, max_len, output_dim)
    return table.to(device)
|
|
|
def RoPE(q, k):
    """Apply rotary position embedding to queries and keys.

    Accepts ``(B, H, L, D)`` tensors, or ``(B, L, D)`` tensors (both q and k
    3-D), which are treated as a single head and returned as 3-D. Adjacent
    feature pairs ``(2i, 2i+1)`` are rotated by position-dependent angles.

    NOTE(review): positions come from ``q``'s length; this assumes ``k`` has
    the same shape as ``q`` and that D is even — confirm with callers.

    Returns:
        ``(q_rot, k_rot)`` with the same shapes as the inputs.
    """
    use_multi_head = True
    # `.dim()` gives the rank; the previous `q.size() == 3` compared a
    # torch.Size to an int and was always False, breaking 3-D inputs.
    if q.dim() == 3 and k.dim() == 3:
        use_multi_head = False
        q, k = q[:, None, ...], k[:, None, ...]
    batch_size = q.shape[0]
    nums_head = q.shape[1]
    max_len = q.shape[2]
    output_dim = q.shape[-1]

    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)

    # Duplicate each cos/sin value so it lines up with both members of a pair.
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)

    # (x0, x1) -> (-x1, x0) for every pair: the 90-degree-rotated partner.
    q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1).reshape(q.shape)
    q = q * cos_pos + q2 * sin_pos

    k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1).reshape(k.shape)
    k = k * cos_pos + k2 * sin_pos

    if not use_multi_head:
        q, k = q[:, 0], k[:, 0]
    return q, k
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: ``w2(silu(w1 x) * w3 x)``, with bias-free projections.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Width of the gated hidden layer.
    """

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
|
|
|
class LayerNorm(nn.Module):
    """LayerNorm over the channel axis for two layouts.

    ``channels_last`` expects ``(N, H, W, C)`` and uses ``F.layer_norm``.
    ``channels_first`` expects ``(N, C, H, W)`` and normalises over dim 1 by hand.

    Raises:
        NotImplementedError: if ``data_format`` is not one of the two layouts.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format != "channels_last" and data_format != "channels_first":
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            # Biased variance over the channel axis, matching F.layer_norm.
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            normed = (x - mean) / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel scale.

    The normalisation is computed in float32 and cast back to the input dtype.

    Args:
        hidden_size: Number of channels (length of ``weight``).
        eps: Added to the mean square before ``rsqrt``.
        data_format: ``"channels_first"`` (normalise over dim 1, e.g. NCHW) or
            ``"channels_last"`` (normalise over the last dim).

    Raises:
        NotImplementedError: for any other ``data_format``, as in ``LayerNorm``.
        Previously an invalid value only failed later, with an
        ``UnboundLocalError`` or a silent ``None`` result from ``forward``.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6, data_format="channels_first") -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.data_format = data_format
        if self.data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError

    def _norm(self, hidden_states):
        """Divide by the RMS over the channel axis."""
        dim = 1 if self.data_format == "channels_first" else -1
        variance = hidden_states.pow(2).mean(dim=dim, keepdim=True)
        return hidden_states * torch.rsqrt(variance + self.eps)

    def forward(self, hidden_states):
        normed = self._norm(hidden_states.float()).type_as(hidden_states)
        if self.data_format == "channels_first":
            # Broadcast the (C,) weight over the trailing spatial dims (H, W).
            return self.weight[..., None, None] * normed
        return self.weight * normed
|
|
|
class GRN(nn.Module):
    """Global Response Normalization layer.

    The parameter shapes ``(1, 1, 1, dim)`` and the norm over dims (1, 2)
    indicate channels-last input ``(N, H, W, dim)``. The output is
    ``gamma * (x * Nx) + beta + x``, where ``Nx`` is each channel's spatial L2
    norm divided by its mean over channels. With the zero initialisation the
    layer starts as the identity.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        spatial_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        relative = spatial_norm / (spatial_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * relative) + self.beta + x
|
|
|
class DUpsampling(nn.Module):
    """Data-dependent upsampling: a 1x1 conv followed by a depth-to-space rearrangement.

    The conv maps ``inplanes`` to ``inplanes * scale**2`` channels, and the
    output is reshaped to ``(N, inplanes, H * scale, W * scale)``.

    Args:
        inplanes: Input (and output) channel count.
        scale: Integer upsampling factor.
        pad: Padding of the 1x1 conv.
    """

    def __init__(self, inplanes, scale, pad=0):
        super(DUpsampling, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, inplanes * scale * scale, kernel_size=1, padding=pad)
        self.scale = scale

    def forward(self, x):
        y = self.conv1(x)
        n, ch, rows, cols = y.size()
        s = self.scale
        # Fold channel groups into the width, then into the height.
        y = y.permute(0, 2, 3, 1).contiguous().view((n, rows, cols * s, ch // s))
        y = y.permute(0, 2, 1, 3).contiguous().view((n, cols * s, rows * s, ch // (s * s)))
        return y.permute(0, 3, 2, 1)
|
|
|
class REsampling(nn.Module):
    """Parameter-free depth-to-space rearrangement.

    Reshapes ``(N, C, H, W)`` into ``(N, C / scale**2, H * scale, W * scale)``
    using the same layout as ``DUpsampling``, but without the conv.
    """

    def __init__(self, scale):
        super(REsampling, self).__init__()
        self.scale = scale

    def forward(self, x):
        n, ch, rows, cols = x.size()
        s = self.scale
        y = x.permute(0, 2, 3, 1).contiguous().view((n, rows, cols * s, ch // s))
        y = y.permute(0, 2, 1, 3).contiguous().view((n, cols * s, rows * s, ch // (s * s)))
        return y.permute(0, 3, 2, 1)
|
|
|
class Dcrop(nn.Module):
    """Space-to-depth rearrangement followed by a residual 3x3 conv.

    ``(B, C, H, W)`` becomes ``(B, C * s**2, H / s, W / s)`` with
    ``s = cropscale``, and the output is ``conv(y) + y``. H and W must be
    divisible by ``s``.
    """

    def __init__(self, inplanes, cropscale=2):
        super(Dcrop, self).__init__()
        self.conv = nn.Conv2d(inplanes * cropscale * cropscale, inplanes * cropscale * cropscale, kernel_size=3, padding=1)
        self.cropscale = cropscale

    def forward(self, x):
        b, ch, rows, cols = x.size()
        s = self.cropscale
        y = x.permute(0, 2, 3, 1).contiguous().view((b, rows, cols // s, ch * s))
        y = y.permute(0, 2, 1, 3).contiguous().view((b, cols // s, rows // s, ch * s * s))
        y = y.permute(0, 3, 2, 1)
        return self.conv(y) + y
|
|
|
def show_gray_images(images, m=8, alpha=3, cmap='coolwarm', save_path=None):
    """Plot a single 2-D map or a grid of 2-D maps with matplotlib.

    Args:
        images: A 2-D array ``(H, W)`` or a 3-D stack ``(n, H, W)``. ``.ndim``
            is used, so torch tensors and numpy arrays both work.
        m: Maximum number of columns in the grid (capped at ``n``).
        alpha: Scale factor for the figure size (each cell is ``2 * alpha`` inches).
        cmap: Matplotlib colormap name.
        save_path: If given, save the figure there and close it; otherwise show it.
    """
    if images.ndim == 2:
        plt.imshow(images, cmap=cmap)
        plt.axis('off')
    else:
        n, _, _ = images.shape  # enforces a 3-D stack, as before
        if n == 1:
            plt.imshow(images[0], cmap=cmap)
            plt.axis('off')
        else:
            cols = min(m, n)
            num_rows = (n + cols - 1) // cols
            # squeeze=False gives a 2-D axes array for every grid shape, so one
            # indexing path covers single-row and single-column layouts.
            _, axes = plt.subplots(num_rows, cols, figsize=(cols * 2 * alpha, num_rows * 2 * alpha), squeeze=False)
            plt.subplots_adjust(wspace=0.05, hspace=0.05)
            for idx, ax in enumerate(axes.flat):
                if idx < n:
                    ax.imshow(images[idx], cmap=cmap)
                # Hide axes on every cell, including unused trailing cells,
                # which previously kept their frames and ticks.
                ax.axis('off')
    if save_path is not None:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()
|
|
|
|
|
|
|
|