SakuraD committed on
Commit
19c9e2c
·
1 Parent(s): 4aeffda
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .ipynb_checkpoints
2
+ __pycache__/
3
+ .DS_Store
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import Compose, ToTensor, Normalize, ConvertImageDtype
6
+
7
+ import numpy as np
8
+ import cv2
9
+
10
+ import gradio as gr
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ from model import IAT
14
+
15
+
16
def set_example_image(example: list) -> dict:
    """Build the Gradio update that loads a clicked example into an image.

    Args:
        example: The clicked dataset row; only its first element is used.

    Returns:
        The update payload produced by ``gr.Image.update``.
    """
    selected = example[0]
    return gr.Image.update(value=selected)
18
+
19
+
20
def dark_inference(img):
    """Run the low-light enhancement checkpoint on a single image.

    A fresh ``IAT`` model is built and its weights are loaded from
    ``./checkpoint/best_Epoch_lol.pth`` on every call.

    Args:
        img: Image handed over by the Gradio input component, or ``None``
            when the button is pressed before an image was provided.

    Returns:
        The enhanced image as a NumPy array with the channel axis last,
        or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # Nothing was uploaded; return an empty result instead of crashing
        # inside the transform pipeline.
        return None

    model = IAT()
    checkpoint_file_path = './checkpoint/best_Epoch_lol.pth'
    state_dict = torch.load(checkpoint_file_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    transform = Compose([
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ConvertImageDtype(torch.float)
    ])

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        enhanced_img = model(transform(img).unsqueeze(0))
    # Drop the batch axis and move channels last for display.
    return enhanced_img[0].permute(1, 2, 0).detach().numpy()
35
+
36
+
37
def exposure_inference(img):
    """Run the exposure-correction checkpoint on a single image.

    A fresh ``IAT`` model is built and its weights are loaded from
    ``./checkpoint/best_Epoch_exposure.pth`` on every call.

    Args:
        img: Image handed over by the Gradio input component, or ``None``
            when the button is pressed before an image was provided.

    Returns:
        The corrected image as a NumPy array with the channel axis last,
        or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        # Nothing was uploaded; return an empty result instead of crashing
        # inside the transform pipeline.
        return None

    model = IAT()
    checkpoint_file_path = './checkpoint/best_Epoch_exposure.pth'
    state_dict = torch.load(checkpoint_file_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()

    # Unlike the low-light path, no Normalize step is applied here.
    transform = Compose([
        ToTensor(),
        ConvertImageDtype(torch.float)
    ])

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        enhanced_img = model(transform(img).unsqueeze(0))
    # Drop the batch axis and move channels last for display.
    return enhanced_img[0].permute(1, 2, 0).detach().numpy()
51
+
52
+
53
# Gradio UI: one input image, two action buttons, one output image and two
# rows of clickable example images.
# NOTE(review): the nesting of the layout contexts below was reconstructed
# from statement order; confirm it against the intended layout.
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
        # IAT
        Gradio demo for <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>IAT</a>: To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.
        """
    )

    with gr.Box():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label='Input Image', type='numpy')
                with gr.Row():
                    dark_button = gr.Button('Low-light Enhancement')
                with gr.Row():
                    exposure_button = gr.Button('Exposure Correction')
            with gr.Column():
                # Label typo fixed ('Resutls' -> 'Results').
                res_image = gr.Image(type='numpy', label='Results')
        with gr.Row():
            dark_example_images = gr.Dataset(
                components=[input_image],
                samples=[['dark_imgs/1.jpg'], ['dark_imgs/2.jpg'], ['dark_imgs/3.jpg']]
            )
        with gr.Row():
            # The third exposure example is stored with a '.jpeg' extension.
            exposure_example_images = gr.Dataset(
                components=[input_image],
                samples=[['exposure_imgs/1.jpg'], ['exposure_imgs/2.jpg'], ['exposure_imgs/3.jpeg']]
            )

    gr.Markdown(
        """
        <p style='text-align: center'><a href='https://arxiv.org/abs/2205.14871' target='_blank'>You Only Need 90K Parameters to Adapt Light: A Light Weight Transformer for Image Enhancement and Exposure Correction</a> | <a href='https://github.com/cuiziteng/Illumination-Adaptive-Transformer' target='_blank'>Github Repo</a></p>
        """
    )

    # Buttons run inference on the current input image.
    dark_button.click(fn=dark_inference, inputs=input_image, outputs=res_image)
    exposure_button.click(fn=exposure_inference, inputs=input_image, outputs=res_image)
    # Clicking an example copies it into the input image component.
    dark_example_images.click(fn=set_example_image, inputs=dark_example_images, outputs=dark_example_images.components)
    exposure_example_images.click(fn=set_example_image, inputs=exposure_example_images, outputs=exposure_example_images.components)

demo.launch(enable_queue=True)
checkpoint/best_Epoch_exposure.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15a9494582f028bef996d4af7145860eaa5d67799d2b0625ed93ff8c546ea3ee
3
+ size 427160
checkpoint/best_Epoch_lol.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9564b7e10882e688ac817ae6fd164544d05b9f74232de56c33ed7f9dabf7bdc4
3
+ size 427160
dark_imgs/1.jpg ADDED
dark_imgs/2.jpg ADDED
dark_imgs/3.jpg ADDED
exposure_imgs/1.jpg ADDED
exposure_imgs/2.jpg ADDED
exposure_imgs/3.jpeg ADDED
model/IAT.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ import os
6
+ import math
7
+
8
+ from timm.models.layers import trunc_normal_
9
+ from .blocks import CBlock_ln, SwinTransformerBlock
10
+ from .global_net import Global_pred
11
+
12
+
13
class Local_pred(nn.Module):
    """Predict a multiplicative and an additive 3-channel map from an image.

    Args:
        dim: Channel width of the stem convolution and of every block.
        number: How many times the transformer block is repeated when
            ``type == 'ttt'``; ignored by the other layouts.
        type: Block layout, one of ``'ccc'``, ``'ttt'`` or ``'cct'``.

    Raises:
        ValueError: If ``type`` is not one of the supported layouts.
    """

    def __init__(self, dim=16, number=4, type='ccc'):
        super(Local_pred, self).__init__()
        # initial convolution
        self.conv1 = nn.Conv2d(3, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        # Both prototypes are still built unconditionally, as before, so the
        # sequence of random initialisations is unchanged for every layout.
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type == 'ccc':
            # Use ``dim`` rather than a literal 16 so non-default widths work.
            blocks1 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
            blocks2 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
        elif type == 'ttt':
            # NOTE: the same module instance is repeated, so every stage of
            # both branches shares one set of weights.
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        elif type == 'cct':
            # NOTE: same instance sharing as in the 'ttt' layout.
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        else:
            raise ValueError(f"unsupported type {type!r}; expected 'ccc', 'ttt' or 'cct'")
        self.mul_blocks = nn.Sequential(*blocks1, nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_blocks = nn.Sequential(*blocks2, nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())

    def forward(self, img):
        """Return ``(mul, add)`` computed from the shared stem features."""
        img1 = self.relu(self.conv1(img))
        mul = self.mul_blocks(img1)
        add = self.add_blocks(img1)
        return mul, add
39
+
40
+
41
+ # Short Cut Connection on Final Layer
42
class Local_pred_S(nn.Module):
    """Local predictor with a shortcut around each block stack.

    The stem features are added back to the output of the multiplicative
    and additive block stacks before the final 3-channel projections.

    Args:
        in_dim: Number of input image channels.
        dim: Channel width of the stem convolution and of every block.
        number: How many times the transformer block is repeated when
            ``type == 'ttt'``; ignored by the other layouts.
        type: Block layout, one of ``'ccc'``, ``'ttt'`` or ``'cct'``.

    Raises:
        ValueError: If ``type`` is not one of the supported layouts.
    """

    def __init__(self, in_dim=3, dim=16, number=4, type='ccc'):
        super(Local_pred_S, self).__init__()
        # initial convolution
        self.conv1 = nn.Conv2d(in_dim, dim, 3, padding=1, groups=1)
        self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # main blocks
        # Both prototypes are still built unconditionally, as before, so the
        # sequence of random initialisations is unchanged for every layout.
        block = CBlock_ln(dim)
        block_t = SwinTransformerBlock(dim)  # head number
        if type == 'ccc':
            # Use ``dim`` rather than a literal 16 so non-default widths work.
            blocks1 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
            blocks2 = [CBlock_ln(dim, drop_path=0.01), CBlock_ln(dim, drop_path=0.05), CBlock_ln(dim, drop_path=0.1)]
        elif type == 'ttt':
            # NOTE: the same module instance is repeated, so every stage of
            # both branches shares one set of weights.
            blocks1, blocks2 = [block_t for _ in range(number)], [block_t for _ in range(number)]
        elif type == 'cct':
            # NOTE: same instance sharing as in the 'ttt' layout.
            blocks1, blocks2 = [block, block, block_t], [block, block, block_t]
        else:
            raise ValueError(f"unsupported type {type!r}; expected 'ccc', 'ttt' or 'cct'")
        self.mul_blocks = nn.Sequential(*blocks1)
        self.add_blocks = nn.Sequential(*blocks2)

        self.mul_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.ReLU())
        self.add_end = nn.Sequential(nn.Conv2d(dim, 3, 3, 1, 1), nn.Tanh())
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear, LayerNorm and Conv2d submodules in place."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Normal init scaled by the per-group fan-out of the kernel.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, img):
        """Return ``(mul, add)`` maps for ``img``."""
        img1 = self.relu(self.conv1(img))
        # short cut connection
        mul = self.mul_blocks(img1) + img1
        add = self.add_blocks(img1) + img1
        mul = self.mul_end(mul)
        add = self.add_end(add)
        return mul, add
89
+
90
+
91
class IAT(nn.Module):
    """Image enhancement network combining a local and an optional global branch.

    Args:
        in_dim: Number of input channels, forwarded to both branches.
        with_global: When true, a ``Global_pred`` branch is created and its
            colour matrix and gamma are applied after the local correction.
        type: Forwarded to ``Global_pred``.
    """

    def __init__(self, in_dim=3, with_global=True, type='lol'):
        super(IAT, self).__init__()
        self.local_net = Local_pred_S(in_dim=in_dim)
        self.with_global = with_global
        if with_global:
            self.global_net = Global_pred(in_channels=in_dim, type=type)

    def apply_color(self, image, ccm):
        """Apply a colour matrix to the last axis of ``image`` and clamp.

        The image is flattened to rows of three values, contracted with the
        last axis of ``ccm``, restored to its original shape and clamped to
        ``[1e-8, 1.0]``.
        """
        original_shape = image.shape
        pixels = image.view(-1, 3)
        mixed = torch.tensordot(pixels, ccm, dims=[[-1], [-1]])
        return torch.clamp(mixed.view(original_shape), 1e-8, 1.0)

    def forward(self, img_low):
        """Enhance ``img_low`` and return a tensor in the same axis order."""
        mul, add = self.local_net(img_low)
        img_high = img_low.mul(mul).add(add)

        if not self.with_global:
            return img_high

        gamma, color = self.global_net(img_low)
        batch = img_high.shape[0]
        channels_last = img_high.permute(0, 2, 3, 1)  # (B,C,H,W) -> (B,H,W,C)
        # Each sample gets its own colour matrix followed by its own gamma.
        corrected = [
            self.apply_color(channels_last[idx, :, :, :], color[idx, :, :]) ** gamma[idx, :]
            for idx in range(batch)
        ]
        stacked = torch.stack(corrected, dim=0)
        return stacked.permute(0, 3, 1, 2)  # (B,H,W,C) -> (B,C,H,W)
120
+
121
+
122
if __name__ == "__main__":
    # Smoke test: build the network, report its size and run one forward pass.
    # torch.rand gives defined values; torch.Tensor(...) would hand back
    # uninitialized memory.
    img = torch.rand(1, 3, 400, 600)
    net = IAT()
    print('total parameters:', sum(param.numel() for param in net.parameters()))
    high = net(img)
model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .IAT import IAT
model/blocks.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code copied from the UniFormer source code:
3
+ https://github.com/Sense-X/UniFormer
4
+ """
5
+ import os
6
+ import torch
7
+ import torch.nn as nn
8
+ from functools import partial
9
+ import math
10
+ from timm.models.vision_transformer import VisionTransformer, _cfg
11
+ from timm.models.registry import register_model
12
+ from timm.models.layers import trunc_normal_, DropPath, to_2tuple
13
+
14
+ # ResMLP's normalization
15
class Aff(nn.Module):
    """Learnable affine transform ``x * alpha + beta``.

    ``alpha`` starts at ones and ``beta`` at zeros, both with shape
    ``(1, 1, dim)``.
    """

    def __init__(self, dim):
        super().__init__()
        param_shape = [1, 1, dim]
        # learnable scale and offset
        self.alpha = nn.Parameter(torch.ones(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        """Scale and shift ``x`` with the learned parameters."""
        return x * self.alpha + self.beta
25
+
26
+ # Color Normalization
27
class Aff_channel(nn.Module):
    """Affine transform combined with a learnable ``dim x dim`` mixing matrix.

    Args:
        dim: Size of the last axis of the input.
        channel_first: When true the mixing matrix is applied before the
            affine step; otherwise it is applied after it.
    """

    def __init__(self, dim, channel_first = True):
        super().__init__()
        param_shape = [1, 1, dim]
        # learnable scale, offset and mixing matrix (identity at start)
        self.alpha = nn.Parameter(torch.ones(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))
        self.color = nn.Parameter(torch.eye(dim))
        self.channel_first = channel_first

    def forward(self, x):
        """Apply the mixing matrix and the affine step in the configured order."""
        if not self.channel_first:
            scaled = x * self.alpha + self.beta
            return torch.tensordot(scaled, self.color, dims=[[-1], [-1]])
        mixed = torch.tensordot(x, self.color, dims=[[-1], [-1]])
        return mixed * self.alpha + self.beta
44
+
45
class Mlp(nn.Module):
    """Two-layer perceptron with an activation and dropout after each layer.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are not given.
    """
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_width = out_features or in_features
        hidden_width = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``fc1 -> act -> drop -> fc2 -> drop``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
63
+
64
class CMlp(nn.Module):
    """Convolutional counterpart of ``Mlp`` built from two 1x1 convolutions.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are not given.
    """
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_width = out_features or in_features
        hidden_width = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_width, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_width, out_width, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``fc1 -> act -> drop -> fc2 -> drop``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
82
+
83
class CBlock_ln(nn.Module):
    """Convolutional block with a depthwise-conv mixer and a conv MLP.

    Both residual branches are scaled by learnable per-channel factors
    (``gamma_1``, ``gamma_2``) initialised to ``init_values``.
    ``qkv_bias``, ``qk_scale`` and ``attn_drop`` are accepted but not used
    by this block.
    """

    def __init__(self, dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=Aff_channel, init_values=1e-4):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.conv1 = nn.Conv2d(dim, dim, 1)
        self.conv2 = nn.Conv2d(dim, dim, 1)
        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        scale_shape = (1, dim, 1, 1)
        self.gamma_1 = nn.Parameter(init_values * torch.ones(scale_shape), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(scale_shape), requires_grad=True)
        self.mlp = CMlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)

    def _normalize(self, norm, feat):
        """Apply ``norm`` to flattened tokens and restore a 4-D layout."""
        B, C, H, W = feat.shape
        tokens = norm(feat.flatten(2).transpose(1, 2))
        return tokens.view(B, H, W, C).permute(0, 3, 1, 2)

    def forward(self, x):
        """Run the positional conv, the mixer branch and the MLP branch."""
        x = x + self.pos_embed(x)

        mixed = self.conv2(self.attn(self.conv1(self._normalize(self.norm1, x))))
        x = x + self.drop_path(self.gamma_1 * mixed)

        x = x + self.drop_path(self.gamma_2 * self.mlp(self._normalize(self.norm2, x)))
        return x
118
+
119
+
120
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    rows, cols = H // window_size, W // window_size
    grid = x.view(B, rows, window_size, cols, window_size, C)
    # Bring the two window-grid axes together before flattening them.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, C)
133
+
134
+
135
def window_reverse(windows, window_size, H, W):
    """Reassemble windows into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    windows_per_image = H * W / window_size / window_size
    B = int(windows.shape[0] / windows_per_image)
    grid = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # Interleave window rows/columns back into image rows/columns.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(B, H, W, -1)
149
+
150
+
151
class WindowAttention(nn.Module):
    r""" Multi-head self attention applied to windows of tokens.

    No relative position bias and no attention mask are applied in this
    implementation.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window (stored only).
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale or per_head ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Attend within each window; input and output are ``(B_, N, C)``."""
        B_, N, C = x.shape
        packed = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        packed = packed.permute(2, 0, 3, 1, 4)
        query, key, value = packed[0], packed[1], packed[2]

        scores = (query * self.scale) @ key.transpose(-2, -1)
        weights = self.attn_drop(self.softmax(scores))

        merged = (weights @ value).transpose(1, 2).reshape(B_, N, C)
        return self.proj_drop(self.proj(merged))
195
+
196
+ ## Layer_norm, Aff_norm, Aff_channel_norm
197
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    The input is a ``(B, C, H, W)`` feature map and the output has the same
    shape. ``H`` and ``W`` must be multiples of ``window_size`` because the
    map is split into whole windows. No attention mask is applied.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: Aff_channel
    """

    def __init__(self, dim, num_heads=2, window_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=Aff_channel):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Apply window attention and the MLP, each with a residual connection."""
        x = x + self.pos_embed(x)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift so the attention output lines up with
        # ``shortcut`` again; without it a shifted block adds misaligned
        # features to the residual.
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        x = x.transpose(1, 2).reshape(B, C, H, W)

        return x
273
+
274
+
275
if __name__ == "__main__":
    # Smoke test: run one convolutional block and print the output shape.
    os.environ['CUDA_VISIBLE_DEVICES']='1'
    cb_block = CBlock_ln(dim = 16)
    # torch.rand gives defined values; torch.Tensor(...) would hand back
    # uninitialized memory.
    x = torch.rand(1, 16, 400, 600)
    # Constructed only to check that the block can be built; it is not run.
    swin = SwinTransformerBlock(dim=16, num_heads=4)
    x = cb_block(x)
    print(x.shape)
model/global_net.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import imp
2
+ import torch
3
+ import torch.nn as nn
4
+ from timm.models.layers import trunc_normal_, DropPath, to_2tuple
5
+ import os
6
+ from .blocks import Mlp
7
+
8
+
9
class query_Attention(nn.Module):
    """Attention with ten learned queries over the input tokens.

    Keys and values are projected from the input; the queries are a learned
    ``(1, 10, dim)`` parameter shared across the batch, so the output always
    has ten tokens per sample.
    """

    def __init__(self, dim, num_heads=2, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or per_head ** -0.5

        self.q = nn.Parameter(torch.ones((1, 10, dim)), requires_grad=True)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Map ``(B, N, C)`` tokens to ``(B, 10, C)`` query outputs."""
        B, N, C = x.shape
        heads = self.num_heads
        width = C // heads

        keys = self.k(x).reshape(B, N, heads, width).permute(0, 2, 1, 3)
        values = self.v(x).reshape(B, N, heads, width).permute(0, 2, 1, 3)
        queries = self.q.expand(B, -1, -1).view(B, -1, heads, width).permute(0, 2, 1, 3)

        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = (weights @ values).transpose(1, 2).reshape(B, 10, C)
        return self.proj_drop(self.proj(merged))
38
+
39
+
40
class query_SABlock(nn.Module):
    """Block that turns a feature map into query tokens via ``query_Attention``.

    A depthwise positional convolution is added to the input, the map is
    flattened to tokens, the attention output replaces the tokens (there is
    no residual around the attention) and an MLP is applied with a residual.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = query_Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Return the query tokens computed from feature map ``x``."""
        tokens = (x + self.pos_embed(x)).flatten(2).transpose(1, 2)
        queried = self.drop_path(self.attn(self.norm1(tokens)))
        return queried + self.drop_path(self.mlp(self.norm2(queried)))
62
+
63
+
64
class conv_embedding(nn.Module):
    """Two stride-2 3x3 convolutions that embed an image at reduced resolution.

    The first convolution outputs ``out_channels // 2`` channels and is
    followed by batch norm and GELU; the second outputs ``out_channels``
    channels and is followed by batch norm only.
    """

    def __init__(self, in_channels, out_channels):
        super(conv_embedding, self).__init__()
        half = out_channels // 2
        stem = [
            nn.Conv2d(in_channels, half, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(half),
            nn.GELU(),
            nn.Conv2d(half, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(out_channels),
        ]
        self.proj = nn.Sequential(*stem)

    def forward(self, x):
        """Return the embedded feature map."""
        return self.proj(x)
81
+
82
+
83
class Global_pred(nn.Module):
    """Predict a per-image gamma value and a 3x3 colour matrix.

    Args:
        in_channels: Number of input image channels.
        out_channels: Embedding width used by the attention block.
        num_heads: Number of attention heads.
        type: When ``'exp'`` the base gamma is frozen; otherwise it is
            trainable.
    """

    def __init__(self, in_channels=3, out_channels=64, num_heads=4, type='exp'):
        super(Global_pred, self).__init__()
        # Base gamma is fixed (no gradient) only for exposure correction.
        self.gamma_base = nn.Parameter(torch.ones((1)), requires_grad=(type != 'exp'))
        self.color_base = nn.Parameter(torch.eye((3)), requires_grad=True)  # basic color matrix
        # main blocks
        self.conv_large = conv_embedding(in_channels, out_channels)
        self.generator = query_SABlock(dim=out_channels, num_heads=num_heads)
        self.gamma_linear = nn.Linear(out_channels, 1)
        self.color_linear = nn.Linear(out_channels, 1)

        self.apply(self._init_weights)

        # Zero the value projection of the query attention after the
        # generic initialisation above.
        for param_name, param in self.named_parameters():
            if param_name == 'generator.attn.v.weight':
                nn.init.constant_(param, 0)

    def _init_weights(self, m):
        """Initialise Linear and LayerNorm submodules in place."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return ``(gamma, color)`` for the batch ``x``."""
        tokens = self.generator(self.conv_large(x))
        # The first token drives gamma, the remaining ones the colour matrix.
        gamma_token = tokens[:, 0].unsqueeze(1)
        color_tokens = tokens[:, 1:]
        gamma = self.gamma_linear(gamma_token).squeeze(-1) + self.gamma_base
        color = self.color_linear(color_tokens).squeeze(-1).view(-1, 3, 3) + self.color_base
        return gamma, color
122
+
123
if __name__ == "__main__":
    # Smoke test: run the global branch on one batch and print output shapes.
    os.environ['CUDA_VISIBLE_DEVICES']='3'
    # torch.rand gives defined values; torch.Tensor(...) would hand back
    # uninitialized memory.
    img = torch.rand(8, 3, 400, 600)
    global_net = Global_pred()
    gamma, color = global_net(img)
    print(gamma.shape, color.shape)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm
4
+ Pillow
5
+ opencv-python
test_dark.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
test_exposure.ipynb ADDED
The diff for this file is too large to render. See raw diff