SakuraD's picture
init
19c9e2c
raw
history blame
5.21 kB
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import os
import math
from timm.models.layers import trunc_normal_
from .blocks import CBlock_ln, SwinTransformerBlock
from .global_net import Global_pred
class Local_pred(nn.Module):
    """Local branch: predict per-pixel multiplicative and additive maps.

    A 3x3 convolution lifts the 3-channel input to ``dim`` channels, then two
    parallel block stacks produce the ``mul`` map (3 channels, ReLU, so
    non-negative) and the ``add`` map (3 channels, Tanh, so in [-1, 1]).

    Args:
        dim: feature width of the hidden blocks.
        number: repeat count of the transformer block for ``type='ttt'``
            (ignored by the other configurations).
        type: block configuration -- ``'ccc'`` (three independent CBlock_ln
            with drop-path 0.01 / 0.05 / 0.1 per stack), ``'ttt'`` (``number``
            repeats of one SwinTransformerBlock) or ``'cct'`` (one CBlock_ln
            used twice, then one SwinTransformerBlock).

    Raises:
        ValueError: if ``type`` is not one of the configurations above.
    """

    def __init__(self, dim=16, number=4, type='ccc'):
        super(Local_pred, self).__init__()
        # initial convolution: 3 input channels -> `dim` feature channels
        self.conv1 = nn.Conv2d(3, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks, built only for the selected configuration
        if type == 'ccc':
            # Block width follows `dim` so the stacks accept conv1's output.
            drop_rates = (0.01, 0.05, 0.1)
            blocks1 = [CBlock_ln(dim, drop_path=p) for p in drop_rates]
            blocks2 = [CBlock_ln(dim, drop_path=p) for p in drop_rates]
        elif type == 'ttt':
            block_t = SwinTransformerBlock(dim)
            # NOTE(review): one instance is repeated, so its weights are shared
            # across all repeats and between mul_blocks and add_blocks.
            # Kept as in the original -- confirm sharing is intended.
            blocks1, blocks2 = [block_t] * number, [block_t] * number
        elif type == 'cct':
            block = CBlock_ln(dim)
            block_t = SwinTransformerBlock(dim)
            # NOTE(review): same instances reused in both stacks (shared weights).
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        else:
            raise ValueError(
                f"unknown block type {type!r}; expected 'ccc', 'ttt' or 'cct'"
            )
        self.mul_blocks = nn.Sequential(*blocks1, nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_blocks = nn.Sequential(*blocks2, nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())

    def forward(self, img):
        """Return the ``(mul, add)`` maps predicted from ``img``."""
        img1 = self.relu(self.conv1(img))
        mul = self.mul_blocks(img1)
        add = self.add_blocks(img1)
        return mul, add
# Short-cut (residual) connection on the final layer.
class Local_pred_S(nn.Module):
    """Local branch with a residual short-cut before each output head.

    ``conv1`` lifts the ``in_dim``-channel input to ``dim`` channels. Its
    activated output is added back to the output of each block stack before
    the final 3-channel projection. The ``mul`` head ends in ReLU, so it is
    non-negative; the ``add`` head ends in Tanh, so it lies in [-1, 1].
    All weights are re-initialised by ``_init_weights`` after construction.

    Args:
        in_dim: number of input image channels.
        dim: feature width of the hidden blocks.
        number: repeat count of the transformer block for ``type='ttt'``
            (ignored by the other configurations).
        type: block configuration -- ``'ccc'`` (three independent CBlock_ln
            with drop-path 0.01 / 0.05 / 0.1 per stack), ``'ttt'`` (``number``
            repeats of one SwinTransformerBlock) or ``'cct'`` (one CBlock_ln
            used twice, then one SwinTransformerBlock).

    Raises:
        ValueError: if ``type`` is not one of the configurations above.
    """

    def __init__(self, in_dim=3, dim=16, number=4, type='ccc'):
        super(Local_pred_S, self).__init__()
        # initial convolution: `in_dim` channels -> `dim` feature channels
        self.conv1 = nn.Conv2d(in_dim, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks, built only for the selected configuration
        if type == 'ccc':
            # Block width follows `dim` so the stacks accept conv1's output.
            drop_rates = (0.01, 0.05, 0.1)
            blocks1 = [CBlock_ln(dim, drop_path=p) for p in drop_rates]
            blocks2 = [CBlock_ln(dim, drop_path=p) for p in drop_rates]
        elif type == 'ttt':
            block_t = SwinTransformerBlock(dim)
            # NOTE(review): one instance is repeated, so its weights are shared
            # across all repeats and between mul_blocks and add_blocks.
            # Kept as in the original -- confirm sharing is intended.
            blocks1, blocks2 = [block_t] * number, [block_t] * number
        elif type == 'cct':
            block = CBlock_ln(dim)
            block_t = SwinTransformerBlock(dim)
            # NOTE(review): same instances reused in both stacks (shared weights).
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        else:
            raise ValueError(
                f"unknown block type {type!r}; expected 'ccc', 'ttt' or 'cct'"
            )
        self.mul_blocks = nn.Sequential(*blocks1)
        self.add_blocks = nn.Sequential(*blocks2)
        self.mul_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise a single sub-module; applied to every module via ``self.apply``.

        - ``nn.Linear``: truncated normal (std 0.02) weight, zero bias.
        - ``nn.LayerNorm``: weight 1, bias 0.
        - ``nn.Conv2d``: normal weight with std ``sqrt(2 / fan_out)``, where
          ``fan_out = k_h * k_w * out_channels // groups``; zero bias.
        Other module types are left untouched.
        """
        if isinstance(m, nn.Linear):
            # torch's built-in has the same signature/defaults (a=-2, b=2) as
            # timm's trunc_normal_.
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            nn.init.normal_(m.weight, 0.0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, img):
        """Return the ``(mul, add)`` maps predicted from ``img``."""
        img1 = self.relu(self.conv1(img))
        # short cut connection: add conv1's features back before each head
        mul = self.mul_blocks(img1) + img1
        add = self.add_blocks(img1) + img1
        mul = self.mul_end(mul)
        add = self.add_end(add)
        return mul, add
class IAT(nn.Module):
    """Local pixel-wise enhancement followed by an optional global correction.

    The local net predicts ``mul`` and ``add`` maps and the image becomes
    ``img_low * mul + add``. When ``with_global`` is True, a global net
    additionally predicts a per-sample colour matrix and gamma. Each image
    is then colour-transformed, clamped to [1e-8, 1] and raised to its gamma.

    Args:
        in_dim: number of input channels, passed to both sub-networks.
        with_global: whether to build and apply the global branch.
        type: configuration name forwarded to ``Global_pred``.
    """

    def __init__(self, in_dim=3, with_global=True, type='lol'):
        super(IAT, self).__init__()
        self.local_net = Local_pred_S(in_dim=in_dim)
        self.with_global = with_global
        if self.with_global:
            self.global_net = Global_pred(in_channels=in_dim, type=type)

    def apply_color(self, image, ccm):
        """Apply colour matrix ``ccm`` to a channels-last ``image`` and clamp.

        Output channel ``k`` is ``sum_c image[..., c] * ccm[k, c]``.

        Args:
            image: tensor whose last axis is the colour axis (size 3).
            ccm: (3, 3) colour-correction matrix.

        Returns:
            Tensor with the same shape as ``image``, clamped to [1e-8, 1.0].
        """
        # tensordot contracts the last axis of both operands and keeps the
        # leading axes in order, so no flatten/unflatten is needed. This also
        # accepts non-contiguous inputs, on which `.view(-1, 3)` raises.
        image = torch.tensordot(image, ccm, dims=[[-1], [-1]])
        # strictly positive lower bound before the gamma power in forward()
        return torch.clamp(image, 1e-8, 1.0)

    def forward(self, img_low):
        """Enhance ``img_low`` of shape (B, C, H, W); returns the same layout."""
        mul, add = self.local_net(img_low)
        img_high = img_low.mul(mul).add(add)
        if not self.with_global:
            return img_high

        gamma, color = self.global_net(img_low)
        batch = img_high.shape[0]
        # (B, C, H, W) -> (B, H, W, C): colour becomes the last axis for apply_color
        img_high = img_high.permute(0, 2, 3, 1)
        # NOTE(review): gamma[i] must broadcast against (H, W, C) -- presumably
        # shape (1,) or (3,); confirm against Global_pred.
        corrected = [
            self.apply_color(img_high[i], color[i]) ** gamma[i]
            for i in range(batch)
        ]
        # (B, H, W, C) -> (B, C, H, W)
        return torch.stack(corrected, dim=0).permute(0, 3, 1, 2)
if __name__ == "__main__":
    # Smoke test: build the default model, report its size, run one forward pass.
    # torch.rand gives a defined input; torch.Tensor(...) is uninitialised memory.
    img = torch.rand(1, 3, 400, 600)
    net = IAT()
    print('total parameters:', sum(param.numel() for param in net.parameters()))
    with torch.no_grad():  # inference only; no autograd graph needed
        high = net(img)
    print('output shape:', tuple(high.shape))