try first commit
- .gitignore +8 -0
- app.py +71 -0
- assets/00344.mp4 +0 -0
- assets/jackie_chan.jpg +0 -0
- models/decoder.py +170 -0
- models/encoder.py +106 -0
- models/model.py +58 -0
- requirements.txt +15 -0
- utils_/loss.py +48 -0
- utils_/visualization.py +98 -0
.gitignore
ADDED
@@ -0,0 +1,8 @@
+checkpoints
+diffusers_cache
+hub
+wandb
+__pycache__
+*.pyc
+flagged
+gif
app.py
ADDED
@@ -0,0 +1,71 @@
+from models.model import Model as AutoLink
+import gradio as gr
+import PIL
+import torch
+import os
+import imageio
+import numpy as np
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+autolink = AutoLink.load_from_checkpoint(os.path.join("checkpoints", "celeba_wild_k32_m0.8_b16_t0.00075_sklr512", "model.ckpt"))
+autolink.to(device)
+
+
+def predict_image(image_in: PIL.Image.Image) -> PIL.Image.Image:
+    if image_in is None:
+        raise gr.Error("Please upload a video or image.")
+    edge_map = autolink(image_in)
+    return edge_map
+
+
+def predict_video(video_in: str) -> str:
+    if video_in is None:
+        raise gr.Error("Please upload a video or image.")
+    video_out = video_in[:-4] + '_out.mp4'
+    video_in = imageio.get_reader(video_in)
+    writer = imageio.get_writer(video_out, mode='I', fps=video_in.get_meta_data()['fps'])
+    for image_in in video_in:
+        image_in = PIL.Image.fromarray(image_in)
+        edge_map = autolink(image_in)
+        writer.append_data(np.array(edge_map))
+    writer.close()
+    return video_out
+
+
+with gr.Blocks() as blocks:
+    gr.Markdown("""
+# AutoLink
+## Self-supervised Learning of Human Skeletons and Object Outlines by Linking Keypoints
+* [Paper](https://arxiv.org/abs/2205.10636)
+* [Project Page](https://xingzhehe.github.io/autolink/)
+* [GitHub](https://github.com/xingzhehe/AutoLink-Self-supervised-Learning-of-Human-Skeletons-and-Object-Outlines-by-Linking-Keypoints)
+""")
+
+    with gr.Tab("Image"):
+        with gr.Row():
+            with gr.Column():
+                image_in = gr.Image(source="upload", type="pil", visible=True)
+            with gr.Column():
+                image_out = gr.Image()
+                run_btn = gr.Button("Run")
+        run_btn.click(fn=predict_image, inputs=[image_in], outputs=[image_out])
+        gr.Examples(fn=predict_image, examples=[["assets/jackie_chan.jpg", None]],
+                    inputs=[image_in], outputs=[image_out],
+                    cache_examples=False)
+
+    with gr.Tab("Video") as tab:
+        with gr.Row():
+            with gr.Column():
+                video_in = gr.Video(source="upload", type="mp4")
+            with gr.Column():
+                video_out = gr.Video()
+                run_btn = gr.Button("Run")
+        run_btn.click(fn=predict_video, inputs=[video_in], outputs=[video_out])
+        gr.Examples(fn=predict_video, examples=[["assets/00344.mp4"],],
+                    inputs=[video_in], outputs=[video_out],
+                    cache_examples=False)
+
+blocks.launch()
+
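
Not part of the commit: a minimal smoke-test sketch for the model behind this demo, assuming the checkpoint directory referenced above exists locally and the packages from requirements.txt are installed. It exercises the AutoLink model directly instead of going through the Gradio UI.

# Hedged sketch (not in the commit): run the checkpoint without launching Gradio.
import os
import PIL.Image
import torch
from models.model import Model as AutoLink

ckpt = os.path.join("checkpoints", "celeba_wild_k32_m0.8_b16_t0.00075_sklr512", "model.ckpt")
autolink = AutoLink.load_from_checkpoint(ckpt)            # same checkpoint app.py loads
autolink.to('cuda' if torch.cuda.is_available() else 'cpu')

image = PIL.Image.open("assets/jackie_chan.jpg").convert("RGB")
edge_map = autolink(image)                                # numpy uint8 array, see Model.forward in models/model.py
PIL.Image.fromarray(edge_map).save("jackie_chan_edges.png")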
assets/00344.mp4
ADDED
Binary file (165 kB).
assets/jackie_chan.jpg
ADDED
models/decoder.py
ADDED
@@ -0,0 +1,170 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from typing import Union
+import pytorch_lightning as pl
+
+
+def gen_grid2d(grid_size: int, left_end: float=-1, right_end: float=1) -> torch.Tensor:
+    """
+    Generate a grid of size (grid_size, grid_size, 2) with coordinate values in the range [left_end, right_end]
+    """
+    x = torch.linspace(left_end, right_end, grid_size)
+    x, y = torch.meshgrid([x, x], indexing='ij')
+    grid = torch.cat((x.reshape(-1, 1), y.reshape(-1, 1)), dim=1).reshape(grid_size, grid_size, 2)
+    return grid
+
+
+def draw_lines(paired_joints: torch.Tensor, heatmap_size: int=16, thick: Union[float, torch.Tensor]=1e-2) -> torch.Tensor:
+    """
+    Draw lines on a grid.
+    :param paired_joints: (batch_size, n_points, 2, 2)
+    :return: (batch_size, n_points, grid_size, grid_size)
+    """
+    bs, n_points, _, _ = paired_joints.shape
+    start = paired_joints[:, :, 0, :]  # (batch_size, n_points, 2)
+    end = paired_joints[:, :, 1, :]  # (batch_size, n_points, 2)
+    paired_diff = end - start  # (batch_size, n_points, 2)
+    grid = gen_grid2d(heatmap_size).to(paired_joints.device).reshape(1, 1, -1, 2)
+    diff_to_start = grid - start.unsqueeze(-2)  # (batch_size, n_points, heatmap_size**2, 2)
+    # (batch_size, n_points, heatmap_size**2)
+    t = (diff_to_start @ paired_diff.unsqueeze(-1)).squeeze(-1) / (1e-8 + paired_diff.square().sum(dim=-1, keepdim=True))
+
+    diff_to_end = grid - end.unsqueeze(-2)  # (batch_size, n_points, heatmap_size**2, 2)
+
+    before_start = (t <= 0).float() * diff_to_start.square().sum(dim=-1)
+    after_end = (t >= 1).float() * diff_to_end.square().sum(dim=-1)
+    between_start_end = (0 < t).float() * (t < 1).float() * (grid - (start.unsqueeze(-2) + t.unsqueeze(-1) * paired_diff.unsqueeze(-2))).square().sum(dim=-1)
+
+    squared_dist = (before_start + after_end + between_start_end).reshape(bs, n_points, heatmap_size, heatmap_size)
+    heatmaps = torch.exp(-squared_dist / thick)
+    return heatmaps
+
+
+class DownBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int) -> None:
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
+            nn.LeakyReLU(0.2, True),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.net(x)
+        return x
+
+
+class UpBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int) -> None:
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(0.2, True),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.net(x)
+        return x
+
+
+class Decoder(nn.Module):
+    def __init__(self, hyper_paras: pl.LightningModule.hparams) -> None:
+        super().__init__()
+        self.n_parts = hyper_paras['n_parts']
+        self.thick = hyper_paras['thick']
+        self.sklr = hyper_paras['sklr']
+        self.skeleton_idx = torch.triu_indices(self.n_parts, self.n_parts, offset=1)
+        self.n_skeleton = len(self.skeleton_idx[0])
+
+        self.alpha = nn.Parameter(torch.tensor(1.0), requires_grad=True)
+
+        skeleton_scalar = (torch.randn(self.n_parts, self.n_parts) / 10 - 4) / self.sklr
+        self.skeleton_scalar = nn.Parameter(skeleton_scalar, requires_grad=True)
+
+        self.down0 = nn.Sequential(
+            nn.Conv2d(3 + 1, 64, kernel_size=(3, 3), padding=1),
+            nn.LeakyReLU(0.2, True),
+        )
+
+        self.down1 = DownBlock(64, 128)  # 64
+        self.down2 = DownBlock(128, 256)  # 32
+        self.down3 = DownBlock(256, 512)  # 16
+        self.down4 = DownBlock(512, 512)  # 8
+
+        self.up1 = UpBlock(512, 512)  # 16
+        self.up2 = UpBlock(512 + 512, 256)  # 32
+        self.up3 = UpBlock(256 + 256, 128)  # 64
+        self.up4 = UpBlock(128 + 128, 64)  # 128
+
+        self.conv = nn.Conv2d(64 + 64, 3, kernel_size=(3, 3), padding=1)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight, a=0.2)
+                if m.bias is not None:
+                    m.bias.data.zero_()
+
+    def skeleton_scalar_matrix(self) -> torch.Tensor:
+        """
+        Give the skeleton scalar matrix
+        :return: (n_parts, n_parts)
+        """
+        skeleton_scalar = F.softplus(self.skeleton_scalar * self.sklr)
+        skeleton_scalar = torch.triu(skeleton_scalar, diagonal=1)
+        skeleton_scalar = skeleton_scalar + skeleton_scalar.transpose(1, 0)
+        return skeleton_scalar
+
+    def rasterize(self, keypoints: torch.Tensor, output_size: int=128) -> torch.Tensor:
+        """
+        Generate edge heatmap from keypoints, where edges are weighted by the learned scalars.
+        :param keypoints: (batch_size, n_points, 2)
+        :return: (batch_size, 1, heatmap_size, heatmap_size)
+        """
+        paired_joints = torch.stack([keypoints[:, self.skeleton_idx[0], :2], keypoints[:, self.skeleton_idx[1], :2]], dim=2)
+
+        skeleton_scalar = F.softplus(self.skeleton_scalar * self.sklr)
+        skeleton_scalar = torch.triu(skeleton_scalar, diagonal=1)
+        skeleton_scalar = skeleton_scalar[self.skeleton_idx[0], self.skeleton_idx[1]].reshape(1, self.n_skeleton, 1, 1)
+
+        skeleton_heatmap_sep = draw_lines(paired_joints, heatmap_size=output_size, thick=self.thick)
+        skeleton_heatmap_sep = skeleton_heatmap_sep * skeleton_scalar.reshape(1, self.n_skeleton, 1, 1)
+        skeleton_heatmap = skeleton_heatmap_sep.max(dim=1, keepdim=True)[0]
+        return skeleton_heatmap
+
+    def forward(self, input_dict: dict) -> dict:
+        skeleton_heatmap = self.rasterize(input_dict['keypoints'])
+
+        x = torch.cat([input_dict['damaged_img'] * self.alpha, skeleton_heatmap], dim=1)
+
+        down_128 = self.down0(x)
+        down_64 = self.down1(down_128)
+        down_32 = self.down2(down_64)
+        down_16 = self.down3(down_32)
+        down_8 = self.down4(down_16)
+        up_8 = down_8
+        up_16 = torch.cat([self.up1(up_8), down_16], dim=1)
+        up_32 = torch.cat([self.up2(up_16), down_32], dim=1)
+        up_64 = torch.cat([self.up3(up_32), down_64], dim=1)
+        up_128 = torch.cat([self.up4(up_64), down_128], dim=1)
+        img = self.conv(up_128)
+
+        input_dict['heatmap'] = skeleton_heatmap
+        input_dict['img'] = img
+        return input_dict
+
+
+if __name__ == '__main__':
+    # 'thick' and 'sklr' are required by __init__; the values below are placeholders so this smoke test runs.
+    model = Decoder({'z_dim': 256, 'n_parts': 10, 'n_embedding': 128, 'tau': 0.01, 'thick': 1e-2, 'sklr': 512})
+    print(sum(p.numel() for p in model.parameters() if p.requires_grad))
models/encoder.py
ADDED
@@ -0,0 +1,106 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+import pytorch_lightning as pl
+
+
+def gen_grid2d(grid_size: int, left_end: float=-1, right_end: float=1) -> torch.Tensor:
+    """
+    Generate a grid of size (grid_size, grid_size, 2) with coordinate values in the range [left_end, right_end]
+    """
+    x = torch.linspace(left_end, right_end, grid_size)
+    x, y = torch.meshgrid([x, x], indexing='ij')
+    grid = torch.cat((x.reshape(-1, 1), y.reshape(-1, 1)), dim=1).reshape(grid_size, grid_size, 2)
+    return grid
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int) -> None:
+        super().__init__()
+        self.conv_res = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
+            nn.BatchNorm2d(out_channels)
+        )
+
+        self.net = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
+            nn.BatchNorm2d(out_channels)
+        )
+
+        self.relu = nn.LeakyReLU(0.2, True)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        res = self.conv_res(x)
+        x = self.net(x)
+        return self.relu(x + res)
+
+
+class TransposedBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int) -> None:
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.LeakyReLU(0.2, True),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.net(x)
+        return x
+
+
+class Detector(nn.Module):
+    def __init__(self, hyper_paras: pl.utilities.parsing.AttributeDict) -> None:
+        super().__init__()
+        self.n_parts = hyper_paras.n_parts
+        self.output_size = 32
+
+        self.conv = nn.Sequential(
+            ResBlock(3, 64),  # 64
+            ResBlock(64, 128),  # 32
+            ResBlock(128, 256),  # 16
+            ResBlock(256, 512),  # 8
+            TransposedBlock(512, 256),  # 16
+            TransposedBlock(256, 128),  # 32
+            nn.Conv2d(128, self.n_parts, kernel_size=3, padding=1),
+        )
+
+        grid = gen_grid2d(self.output_size).reshape(1, 1, self.output_size ** 2, 2)
+        self.coord = nn.Parameter(grid, requires_grad=False)
+
+    def forward(self, input_dict: dict) -> dict:
+        img = F.interpolate(input_dict['img'], size=(128, 128), mode='bilinear', align_corners=False)
+        prob_map = self.conv(img).reshape(img.shape[0], self.n_parts, -1, 1)
+        prob_map = F.softmax(prob_map, dim=2)
+        keypoints = self.coord * prob_map
+        keypoints = keypoints.sum(dim=2)
+        prob_map = prob_map.reshape(keypoints.shape[0], self.n_parts, self.output_size, self.output_size)
+        return {'keypoints': keypoints, 'prob_map': prob_map}
+
+
+class Encoder(nn.Module):
+    def __init__(self, hyper_paras: pl.utilities.parsing.AttributeDict) -> None:
+        super().__init__()
+        self.detector = Detector(hyper_paras)
+        self.missing = 0.8  # hyper_paras.missing
+        self.block = 16  # hyper_paras.block
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight, a=0.2)
+                if m.bias is not None:
+                    m.bias.data.zero_()
+
+    def forward(self, input_dict: dict, need_masked_img: bool=False) -> dict:
+        mask_batch = self.detector(input_dict)
+        if need_masked_img:
+            damage_mask = torch.zeros(input_dict['img'].shape[0], 1, self.block, self.block, device=input_dict['img'].device).uniform_() > self.missing
+            damage_mask = F.interpolate(damage_mask.to(input_dict['img']), size=input_dict['img'].shape[-1], mode='nearest')
+            mask_batch['damaged_img'] = input_dict['img'] * damage_mask
+        return mask_batch
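
Not part of the commit: a shape-check sketch for the keypoint detector above. The n_parts value matches the "k32" in the checkpoint name used by app.py but is otherwise an assumption.

# Hedged sketch: run the encoder on a random batch and inspect output shapes.
import torch
import pytorch_lightning as pl
from models.encoder import Encoder

hparams = pl.utilities.parsing.AttributeDict(n_parts=32)
encoder = Encoder(hparams)
out = encoder({'img': torch.randn(2, 3, 128, 128)}, need_masked_img=True)
print(out['keypoints'].shape)    # torch.Size([2, 32, 2]), coordinates in [-1, 1]
print(out['prob_map'].shape)     # torch.Size([2, 32, 32, 32])
print(out['damaged_img'].shape)  # torch.Size([2, 3, 128, 128]), most 8x8 blocks zeroed out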
models/model.py
ADDED
@@ -0,0 +1,58 @@
+import importlib
+import PIL
+import pytorch_lightning as pl
+import torch.utils.data
+import wandb
+from typing import Union
+from torchvision import transforms
+from utils_.loss import VGGPerceptualLoss
+from utils_.visualization import *
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+
+
+class Model(pl.LightningModule):
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.save_hyperparameters()
+        self.encoder = importlib.import_module('models.' + self.hparams.encoder).Encoder(self.hparams)
+        self.decoder = importlib.import_module('models.' + self.hparams.decoder).Decoder(self.hparams)
+        self.batch_size = self.hparams.batch_size
+
+        self.vgg_loss = VGGPerceptualLoss()
+
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(0.5, 0.5)
+        ])
+
+    def forward(self, x: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        :param x: a PIL image
+        :return: an edge map of the same size as x with values in [0, 1] (normalized by max)
+        """
+        w, h = x.size
+        x = self.transform(x).unsqueeze(0)
+        x = x.to(self.device)
+        kp = self.encoder({'img': x})['keypoints']
+        edge_map = self.decoder.rasterize(kp, output_size=64)
+        bs = edge_map.shape[0]
+        edge_map = edge_map / (1e-8 + edge_map.reshape(bs, 1, -1).max(dim=2, keepdim=True)[0].reshape(bs, 1, 1, 1))
+        edge_map = torch.cat([edge_map] * 3, dim=1)
+        edge_map = F.interpolate(edge_map, size=(h, w), mode='bilinear', align_corners=False)
+        x = torch.clamp(edge_map + (x * 0.5 + 0.5) * 0.5, min=0, max=1)
+        x = transforms.ToPILImage()(x[0].detach().cpu())
+
+        fig = plt.figure(figsize=(1, h / w), dpi=w)
+        fig.tight_layout(pad=0)
+        plt.axis('off')
+        plt.imshow(x)
+        kp = kp[0].detach().cpu() * 0.5 + 0.5
+        kp[:, 1] *= w
+        kp[:, 0] *= h
+        plt.scatter(kp[:, 1], kp[:, 0], s=min(w / h, min(1, h / w)), marker='o')
+        ncols, nrows = fig.canvas.get_width_height()
+        fig.canvas.draw()
+        plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3)
+        plt.close(fig)
+        return plot
requirements.txt
ADDED
@@ -0,0 +1,15 @@
+numpy
+torch==1.13.1
+torchvision
+matplotlib
+scipy
+h5py
+pandas
+kornia
+wandb
+pytorch-lightning==1.5.10
+seaborn
+scikit-learn
+imageio
+imageio-ffmpeg
+gradio
utils_/loss.py
ADDED
@@ -0,0 +1,48 @@
+import os
+
+import torch
+import torch.nn.functional as F
+import torchvision
+
+
+class VGGPerceptualLoss(torch.nn.Module):
+    def __init__(self):
+        super(VGGPerceptualLoss, self).__init__()
+        os.environ['TORCH_HOME'] = os.path.abspath(os.getcwd())
+        blocks = [torchvision.models.vgg16().features[:4].eval(),
+                  torchvision.models.vgg16().features[4:9].eval(),
+                  torchvision.models.vgg16().features[9:16].eval(),
+                  torchvision.models.vgg16().features[16:23].eval()]
+        for bl in blocks:
+            for p in bl.parameters():
+                p.requires_grad = False
+        self.blocks = torch.nn.ModuleList(blocks)
+
+        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+    def forward(self, x, y):
+        x = x * 0.5 + 0.5
+        y = y * 0.5 + 0.5
+        x = (x - self.mean) / self.std
+        y = (y - self.mean) / self.std
+
+        x = F.interpolate(x, mode='bilinear', size=(224, 224), align_corners=False)
+        y = F.interpolate(y, mode='bilinear', size=(224, 224), align_corners=False)
+        perceptual_loss = 0.0
+        style_loss = 0.0
+
+        for i, block in enumerate(self.blocks):
+            x = block(x)
+            y = block(y)
+
+            perceptual_loss += torch.nn.functional.l1_loss(x, y)
+
+            # b, ch, h, w = x.shape
+            # act_x = x.reshape(x.shape[0], x.shape[1], -1)
+            # act_y = y.reshape(y.shape[0], y.shape[1], -1)
+            # gram_x = act_x @ act_x.permute(0, 2, 1) / (ch * h * w)
+            # gram_y = act_y @ act_y.permute(0, 2, 1) / (ch * h * w)
+            # style_loss += torch.nn.functional.l1_loss(gram_x, gram_y)
+
+        return perceptual_loss  # , style_loss
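
Not part of the commit: a sketch of how the perceptual loss above is meant to be called, on inputs normalized to [-1, 1] as in the rest of the code (batch and image sizes are assumptions).

# Hedged sketch: compare two random batches with the VGG16 feature-space L1 loss.
import torch
from utils_.loss import VGGPerceptualLoss

criterion = VGGPerceptualLoss()
pred = torch.rand(2, 3, 128, 128) * 2 - 1     # stand-in reconstruction in [-1, 1]
target = torch.rand(2, 3, 128, 128) * 2 - 1   # stand-in ground truth in [-1, 1]
loss = criterion(pred, target)                # scalar tensor: summed L1 over four VGG16 feature depths
print(loss)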
utils_/visualization.py
ADDED
@@ -0,0 +1,98 @@
+import matplotlib.gridspec as gridspec
+import matplotlib.pyplot as plt
+import numpy as np
+import seaborn as sns
+import torch
+import torchvision
+from matplotlib import colors
+
+
+def get_part_color(n_parts):
+    colormap = ('red', 'blue', 'yellow', 'magenta', 'green', 'indigo', 'darkorange', 'cyan', 'pink', 'yellowgreen',
+                'rosybrown', 'coral', 'chocolate', 'bisque', 'gold', 'yellowgreen', 'aquamarine', 'deepskyblue', 'navy', 'orchid',
+                'maroon', 'sienna', 'olive', 'lightgreen', 'teal', 'steelblue', 'slateblue', 'darkviolet', 'fuchsia', 'crimson',
+                'honeydew', 'thistle',
+                'red', 'blue', 'yellow', 'magenta', 'green', 'indigo', 'darkorange', 'cyan', 'pink', 'yellowgreen',
+                'rosybrown', 'coral', 'chocolate', 'bisque', 'gold', 'yellowgreen', 'aquamarine', 'deepskyblue', 'navy', 'orchid',
+                'maroon', 'sienna', 'olive', 'lightgreen', 'teal', 'steelblue', 'slateblue', 'darkviolet', 'fuchsia', 'crimson',
+                'honeydew', 'thistle')[:n_parts]
+    part_color = []
+    for i in range(n_parts):
+        part_color.append(colors.to_rgb(colormap[i]))
+    part_color = np.array(part_color)
+
+    return part_color
+
+
+def denormalize(img):
+    mean = torch.tensor((0.5, 0.5, 0.5), device=img.device).reshape(1, 3, 1, 1)
+    std = torch.tensor((0.5, 0.5, 0.5), device=img.device).reshape(1, 3, 1, 1)
+    img = img * std + mean
+    img = torch.clamp(img, min=0, max=1)
+    return img
+
+
+def draw_matrix(mat):
+    fig = plt.figure()
+    sns.heatmap(mat, annot=True, fmt='.2f', cmap="YlGnBu")
+
+    ncols, nrows = fig.canvas.get_width_height()
+    fig.canvas.draw()
+    plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3)
+    plt.close(fig)
+    return plot
+
+
+def draw_kp_grid(img, kp):
+    kp_color = get_part_color(kp.shape[1])
+    img = img[:64].permute(0, 2, 3, 1).detach().cpu()
+    kp = kp.detach().cpu()[:64]
+
+    fig = plt.figure(figsize=(8, 8))
+    gs = gridspec.GridSpec(8, 8)
+    gs.update(wspace=0, hspace=0)
+
+    for i, sample in enumerate(img):
+        ax = plt.subplot(gs[i])
+        plt.axis('off')
+        ax.set_xticklabels([])
+        ax.set_yticklabels([])
+        ax.imshow(sample, vmin=0, vmax=1)
+        ax.scatter(kp[i, :, 1], kp[i, :, 0], c=kp_color, s=20, marker='+')
+
+    ncols, nrows = fig.canvas.get_width_height()
+    fig.canvas.draw()
+    plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3)
+    plt.close(fig)
+    return plot
+
+
+def draw_kp_grid_unnorm(img, kp):
+    kp_color = get_part_color(kp.shape[1])
+    img = img[:64].permute(0, 2, 3, 1).detach().cpu()
+    kp = kp.detach().cpu()[:64]
+
+    fig = plt.figure(figsize=(8, 8))
+    gs = gridspec.GridSpec(8, 8)
+    gs.update(wspace=0, hspace=0)
+
+    for i, sample in enumerate(img):
+        ax = plt.subplot(gs[i])
+        plt.axis('off')
+        ax.set_xticklabels([])
+        ax.set_yticklabels([])
+        ax.imshow(sample)
+        ax.scatter(kp[i, :, 1], kp[i, :, 0], c=kp_color, s=20, marker='+')
+
+    ncols, nrows = fig.canvas.get_width_height()
+    fig.canvas.draw()
+    plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3)
+    plt.close(fig)
+    return plot
+
+
+def draw_img_grid(img):
+    img = img[:64].detach().cpu()
+    nrow = min(8, img.shape[0])
+    img = torchvision.utils.make_grid(img[:64], nrow=nrow).permute(1, 2, 0)
+    return torch.clamp(img * 255, min=0, max=255).numpy().astype(np.uint8)
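
Not part of the commit: a sketch of the image-grid helper above on a random batch; the batch and image sizes are assumptions.

# Hedged sketch: tile a batch of [0, 1] images into one uint8 image.
import torch
from utils_.visualization import draw_img_grid

batch = torch.rand(16, 3, 64, 64)   # 16 random images in [0, 1]
grid = draw_img_grid(batch)         # numpy uint8 array: 2 rows of 8 images plus default padding
print(grid.shape, grid.dtype)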