import math
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from timm.models.layers import trunc_normal_

from .blocks import CBlock_ln, SwinTransformerBlock
from .global_net import Global_pred


def _make_branch_blocks(dim, number, block_type):
    """Build a fresh list of feature blocks for one branch of a local predictor.

    A new module instance is created for every position, so the two branches
    (and repeated positions within one branch) never share parameters.

    Args:
        dim: channel width passed to every block.
        number: block count for ``'ttt'``. It is ignored by ``'ccc'`` (always
            three CBlock_ln with drop_path 0.01 / 0.05 / 0.1) and by ``'cct'``
            (always CBlock_ln, CBlock_ln, SwinTransformerBlock).
        block_type: one of ``'ccc'``, ``'ttt'`` or ``'cct'``.

    Returns:
        A list of ``nn.Module`` instances, ready to unpack into ``nn.Sequential``.

    Raises:
        ValueError: if ``block_type`` is not one of the supported layouts.
    """
    if block_type == 'ccc':
        return [CBlock_ln(dim, drop_path=p) for p in (0.01, 0.05, 0.1)]
    if block_type == 'ttt':
        return [SwinTransformerBlock(dim) for _ in range(number)]
    if block_type == 'cct':
        return [CBlock_ln(dim), CBlock_ln(dim), SwinTransformerBlock(dim)]
    raise ValueError(
        f"unknown block type {block_type!r}; expected 'ccc', 'ttt' or 'cct'"
    )


class Local_pred(nn.Module):
    """Predict per-pixel multiplicative and additive maps from a 3-channel image.

    Args:
        dim: width of the hidden feature maps.
        number: blocks per branch; only used when ``type == 'ttt'``.
        type: block layout, ``'ccc'``, ``'ttt'`` or ``'cct'``.
    """

    def __init__(self, dim=16, number=4, type='ccc'):
        super().__init__()
        # initial convolution: 3 -> dim channels (padding=1 keeps H, W)
        self.conv1 = nn.Conv2d(3, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks: each branch gets its own, independent modules
        blocks1 = _make_branch_blocks(dim, number, type)
        blocks2 = _make_branch_blocks(dim, number, type)
        # heads map dim -> 3 channels; ReLU makes mul non-negative, Tanh bounds add
        self.mul_blocks = nn.Sequential(*blocks1, nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_blocks = nn.Sequential(*blocks2, nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())

    def forward(self, img):
        """Return ``(mul, add)``.

        Both are 3-channel maps; ``mul >= 0`` (ReLU head) and ``add`` lies in
        (-1, 1) (Tanh head).
        """
        img1 = self.relu(self.conv1(img))
        mul = self.mul_blocks(img1)
        add = self.add_blocks(img1)
        return mul, add


# Short Cut Connection on Final Layer
class Local_pred_S(nn.Module):
    """Local predictor with a residual shortcut around each branch's blocks.

    Same outputs as ``Local_pred``, but the block stack output is added to the
    stem features before the 3-channel heads. Weights are re-initialised via
    ``_init_weights``.

    Args:
        in_dim: number of input image channels.
        dim: width of the hidden feature maps.
        number: blocks per branch; only used when ``type == 'ttt'``.
        type: block layout, ``'ccc'``, ``'ttt'`` or ``'cct'``.
    """

    def __init__(self, in_dim=3, dim=16, number=4, type='ccc'):
        super().__init__()
        # initial convolution: in_dim -> dim channels (padding=1 keeps H, W)
        self.conv1 = nn.Conv2d(in_dim, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks: each branch gets its own, independent modules
        self.mul_blocks = nn.Sequential(*_make_branch_blocks(dim, number, type))
        self.add_blocks = nn.Sequential(*_make_branch_blocks(dim, number, type))
        self.mul_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one submodule; called on every module via ``self.apply``.

        - Linear: truncated normal (std 0.02), bias zero.
        - LayerNorm: weight one, bias zero.
        - Conv2d: normal(0, sqrt(2 / fan_out)), bias zero.

        Other module types are left untouched.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm without affine params has weight/bias set to None
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, img):
        """Return ``(mul, add)``.

        Both are 3-channel maps; ``mul >= 0`` (ReLU head) and ``add`` lies in
        (-1, 1) (Tanh head).
        """
        img1 = self.relu(self.conv1(img))
        # short cut connection: the block stacks must preserve img1's shape
        mul = self.mul_blocks(img1) + img1
        add = self.add_blocks(img1) + img1
        mul = self.mul_end(mul)
        add = self.add_end(add)
        return mul, add


class IAT(nn.Module):
    """Enhance an image with a local affine stage and an optional global stage.

    The local stage computes ``img_low * mul + add`` with maps from
    ``Local_pred_S``. If ``with_global`` is set, ``Global_pred`` then supplies a
    per-sample colour matrix and exponent that are applied on top.

    Args:
        in_dim: number of input image channels.
        with_global: whether to build and apply the global branch.
        type: forwarded to ``Global_pred``.
    """

    def __init__(self, in_dim=3, with_global=True, type='lol'):
        super().__init__()
        self.local_net = Local_pred_S(in_dim=in_dim)
        self.with_global = with_global
        if self.with_global:
            self.global_net = Global_pred(in_channels=in_dim, type=type)

    def apply_color(self, image, ccm):
        """Apply a 3x3 colour matrix to every pixel of one channels-last image.

        Args:
            image: tensor whose last dimension has size 3, e.g. ``(H, W, 3)``.
                It may be non-contiguous.
            ccm: ``(3, 3)`` matrix. Output channel ``k`` is
                ``sum_c image[..., c] * ccm[k, c]``.

        Returns:
            A tensor with ``image``'s shape, clamped to ``[1e-8, 1.0]``.
        """
        shape = image.shape
        # reshape (not view) so permuted / non-contiguous inputs are accepted
        pixels = image.reshape(-1, 3)
        pixels = torch.tensordot(pixels, ccm, dims=[[-1], [-1]])
        # NOTE(review): the 1e-8 floor presumably keeps the gamma power applied
        # by the caller finite (0 ** negative) — confirm intent.
        return torch.clamp(pixels.view(shape), 1e-8, 1.0)

    def forward(self, img_low):
        """Enhance a ``(B, C, H, W)`` batch and return it in the same layout."""
        mul, add = self.local_net(img_low)
        img_high = img_low.mul(mul).add(add)
        if not self.with_global:
            return img_high

        gamma, color = self.global_net(img_low)
        b = img_high.shape[0]
        img_high = img_high.permute(0, 2, 3, 1)  # (B,C,H,W) -> (B,H,W,C)
        # per sample: colour matrix color[i] (3x3), then element-wise power gamma[i]
        img_high = torch.stack(
            [self.apply_color(img_high[i, :, :, :], color[i, :, :]) ** gamma[i, :]
             for i in range(b)],
            dim=0,
        )
        return img_high.permute(0, 3, 1, 2)  # (B,H,W,C) -> (B,C,H,W)


if __name__ == "__main__":
    # Smoke test. torch.rand instead of torch.Tensor(...), whose uninitialised
    # memory may contain NaN/inf.
    img = torch.rand(1, 3, 400, 600)
    net = IAT()
    print('total parameters:', sum(param.numel() for param in net.parameters()))
    with torch.no_grad():
        high = net(img)
    print('output shape:', tuple(high.shape))