Commit 91fc62a by xingzhehe
1 Parent(s): 354ef90

try first commit

.gitignore ADDED
@@ -0,0 +1,8 @@
+ checkpoints
+ diffusers_cache
+ hub
+ wandb
+ __pycache__
+ *.pyc
+ flagged
+ gif

app.py ADDED
@@ -0,0 +1,71 @@
+ from models.model import Model as AutoLink
+ import gradio as gr
+ import PIL
+ import torch
+ import os
+ import imageio
+ import numpy as np
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+ autolink = AutoLink.load_from_checkpoint(os.path.join("checkpoints", "celeba_wild_k32_m0.8_b16_t0.00075_sklr512", "model.ckpt"))
+ autolink.to(device)
+
+
+ def predict_image(image_in: PIL.Image.Image) -> np.ndarray:
+     if image_in is None:
+         raise gr.Error("Please upload a video or image.")
+     edge_map = autolink(image_in)
+     return edge_map
+
+
+ def predict_video(video_in: str) -> str:
+     if video_in is None:
+         raise gr.Error("Please upload a video or image.")
+     video_out = video_in[:-4] + '_out.mp4'
+     video_in = imageio.get_reader(video_in)
+     writer = imageio.get_writer(video_out, mode='I', fps=video_in.get_meta_data()['fps'])
+     for image_in in video_in:
+         image_in = PIL.Image.fromarray(image_in)
+         edge_map = autolink(image_in)
+         writer.append_data(np.array(edge_map))
+     writer.close()
+     return video_out
+
+
+ with gr.Blocks() as blocks:
+     gr.Markdown("""
+     # AutoLink
+     ## Self-supervised Learning of Human Skeletons and Object Outlines by Linking Keypoints
+     * [Paper](https://arxiv.org/abs/2205.10636)
+     * [Project Page](https://xingzhehe.github.io/autolink/)
+     * [GitHub](https://github.com/xingzhehe/AutoLink-Self-supervised-Learning-of-Human-Skeletons-and-Object-Outlines-by-Linking-Keypoints)
+     """)
+
+     with gr.Tab("Image"):
+         with gr.Row():
+             with gr.Column():
+                 image_in = gr.Image(source="upload", type="pil", visible=True)
+             with gr.Column():
+                 image_out = gr.Image()
+         run_btn = gr.Button("Run")
+         run_btn.click(fn=predict_image, inputs=[image_in], outputs=[image_out])
+         gr.Examples(fn=predict_image, examples=[["assets/jackie_chan.jpg"]],
+                     inputs=[image_in], outputs=[image_out],
+                     cache_examples=False)
+
+     with gr.Tab("Video") as tab:
+         with gr.Row():
+             with gr.Column():
+                 video_in = gr.Video(source="upload", type="mp4")
+             with gr.Column():
+                 video_out = gr.Video()
+         run_btn = gr.Button("Run")
+         run_btn.click(fn=predict_video, inputs=[video_in], outputs=[video_out])
+         gr.Examples(fn=predict_video, examples=[["assets/00344.mp4"],],
+                     inputs=[video_in], outputs=[video_out],
+                     cache_examples=False)
+
+ blocks.launch()
+
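For reference, the demo's model wrapper can also be exercised outside Gradio. A minimal sketch, assuming the checkpoint directory referenced above has been downloaded into ./checkpoints (the output file name is hypothetical):

import os
import PIL.Image
import imageio
import torch
from models.model import Model as AutoLink

device = 'cuda' if torch.cuda.is_available() else 'cpu'
autolink = AutoLink.load_from_checkpoint(
    os.path.join("checkpoints", "celeba_wild_k32_m0.8_b16_t0.00075_sklr512", "model.ckpt"))
autolink.to(device)

image = PIL.Image.open("assets/jackie_chan.jpg").convert("RGB")
overlay = autolink(image)                          # RGB uint8 array: the image with edge map and keypoints drawn on it
imageio.imwrite("jackie_chan_edges.png", overlay)  # hypothetical output path
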
assets/00344.mp4 ADDED
Binary file (165 kB).
 
assets/jackie_chan.jpg ADDED
models/decoder.py ADDED
@@ -0,0 +1,170 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from typing import Union
+ import pytorch_lightning as pl
+
+
+ def gen_grid2d(grid_size: int, left_end: float=-1, right_end: float=1) -> torch.Tensor:
+     """
+     Generate a grid of size (grid_size, grid_size, 2) with coordinate values in the range [left_end, right_end]
+     """
+     x = torch.linspace(left_end, right_end, grid_size)
+     x, y = torch.meshgrid([x, x], indexing='ij')
+     grid = torch.cat((x.reshape(-1, 1), y.reshape(-1, 1)), dim=1).reshape(grid_size, grid_size, 2)
+     return grid
+
+
+ def draw_lines(paired_joints: torch.Tensor, heatmap_size: int=16, thick: Union[float, torch.Tensor]=1e-2) -> torch.Tensor:
+     """
+     Draw line segments on a grid as heatmaps with exponential falloff.
+     :param paired_joints: (batch_size, n_points, 2, 2) start and end point of each segment
+     :return: (batch_size, n_points, heatmap_size, heatmap_size), exp(-d^2 / thick) where d is the distance of each grid cell to the segment
+     """
+     bs, n_points, _, _ = paired_joints.shape
+     start = paired_joints[:, :, 0, :]  # (batch_size, n_points, 2)
+     end = paired_joints[:, :, 1, :]  # (batch_size, n_points, 2)
+     paired_diff = end - start  # (batch_size, n_points, 2)
+     grid = gen_grid2d(heatmap_size).to(paired_joints.device).reshape(1, 1, -1, 2)
+     diff_to_start = grid - start.unsqueeze(-2)  # (batch_size, n_points, heatmap_size**2, 2)
+     # (batch_size, n_points, heatmap_size**2)
+     t = (diff_to_start @ paired_diff.unsqueeze(-1)).squeeze(-1) / (1e-8 + paired_diff.square().sum(dim=-1, keepdim=True))
+
+     diff_to_end = grid - end.unsqueeze(-2)  # (batch_size, n_points, heatmap_size**2, 2)
+
+     before_start = (t <= 0).float() * diff_to_start.square().sum(dim=-1)
+     after_end = (t >= 1).float() * diff_to_end.square().sum(dim=-1)
+     between_start_end = (0 < t).float() * (t < 1).float() * (grid - (start.unsqueeze(-2) + t.unsqueeze(-1) * paired_diff.unsqueeze(-2))).square().sum(dim=-1)
+
+     squared_dist = (before_start + after_end + between_start_end).reshape(bs, n_points, heatmap_size, heatmap_size)
+     heatmaps = torch.exp(- squared_dist / thick)
+     return heatmaps
+
+
+ class DownBlock(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int) -> None:
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1),
+             nn.BatchNorm2d(out_channels),
+             nn.LeakyReLU(0.2, True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=1),
+             nn.BatchNorm2d(out_channels),
+             nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
+             nn.LeakyReLU(0.2, True),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.net(x)
+         return x
+
+
+ class UpBlock(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int) -> None:
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+             nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1),
+             nn.BatchNorm2d(out_channels),
+             nn.LeakyReLU(0.2, True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=1),
+             nn.BatchNorm2d(out_channels),
+             nn.LeakyReLU(0.2, True),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.net(x)
+         return x
+
+
+ class Decoder(nn.Module):
+     def __init__(self, hyper_paras: pl.LightningModule.hparams) -> None:
+         super().__init__()
+         self.n_parts = hyper_paras['n_parts']
+         self.thick = hyper_paras['thick']
+         self.sklr = hyper_paras['sklr']
+         self.skeleton_idx = torch.triu_indices(self.n_parts, self.n_parts, offset=1)
+         self.n_skeleton = len(self.skeleton_idx[0])
+
+         self.alpha = nn.Parameter(torch.tensor(1.0), requires_grad=True)
+
+         skeleton_scalar = (torch.randn(self.n_parts, self.n_parts) / 10 - 4) / self.sklr
+         self.skeleton_scalar = nn.Parameter(skeleton_scalar, requires_grad=True)
+
+         self.down0 = nn.Sequential(
+             nn.Conv2d(3 + 1, 64, kernel_size=(3, 3), padding=1),
+             nn.LeakyReLU(0.2, True),
+         )
+
+         self.down1 = DownBlock(64, 128)  # 64
+         self.down2 = DownBlock(128, 256)  # 32
+         self.down3 = DownBlock(256, 512)  # 16
+         self.down4 = DownBlock(512, 512)  # 8
+
+         self.up1 = UpBlock(512, 512)  # 16
+         self.up2 = UpBlock(512 + 512, 256)  # 32
+         self.up3 = UpBlock(256 + 256, 128)  # 64
+         self.up4 = UpBlock(128 + 128, 64)  # 128
+
+         self.conv = nn.Conv2d(64 + 64, 3, kernel_size=(3, 3), padding=1)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                 nn.init.kaiming_normal_(m.weight, a=0.2)
+                 if m.bias is not None:
+                     m.bias.data.zero_()
+
+     def skeleton_scalar_matrix(self) -> torch.Tensor:
+         """
+         Give the skeleton scalar matrix
+         :return: (n_parts, n_parts)
+         """
+         skeleton_scalar = F.softplus(self.skeleton_scalar * self.sklr)
+         skeleton_scalar = torch.triu(skeleton_scalar, diagonal=1)
+         skeleton_scalar = skeleton_scalar + skeleton_scalar.transpose(1, 0)
+         return skeleton_scalar
+
+     def rasterize(self, keypoints: torch.Tensor, output_size: int=128) -> torch.Tensor:
+         """
+         Generate edge heatmap from keypoints, where edges are weighted by the learned scalars.
+         :param keypoints: (batch_size, n_points, 2)
+         :return: (batch_size, 1, output_size, output_size)
+         """
+
+         paired_joints = torch.stack([keypoints[:, self.skeleton_idx[0], :2], keypoints[:, self.skeleton_idx[1], :2]], dim=2)
+
+         skeleton_scalar = F.softplus(self.skeleton_scalar * self.sklr)
+         skeleton_scalar = torch.triu(skeleton_scalar, diagonal=1)
+         skeleton_scalar = skeleton_scalar[self.skeleton_idx[0], self.skeleton_idx[1]].reshape(1, self.n_skeleton, 1, 1)
+
+         skeleton_heatmap_sep = draw_lines(paired_joints, heatmap_size=output_size, thick=self.thick)
+         skeleton_heatmap_sep = skeleton_heatmap_sep * skeleton_scalar.reshape(1, self.n_skeleton, 1, 1)
+         skeleton_heatmap = skeleton_heatmap_sep.max(dim=1, keepdim=True)[0]
+         return skeleton_heatmap
+
+     def forward(self, input_dict: dict) -> dict:
+         skeleton_heatmap = self.rasterize(input_dict['keypoints'])
+
+         x = torch.cat([input_dict['damaged_img'] * self.alpha, skeleton_heatmap], dim=1)
+
+         down_128 = self.down0(x)
+         down_64 = self.down1(down_128)
+         down_32 = self.down2(down_64)
+         down_16 = self.down3(down_32)
+         down_8 = self.down4(down_16)
+         up_8 = down_8
+         up_16 = torch.cat([self.up1(up_8), down_16], dim=1)
+         up_32 = torch.cat([self.up2(up_16), down_32], dim=1)
+         up_64 = torch.cat([self.up3(up_32), down_64], dim=1)
+         up_128 = torch.cat([self.up4(up_64), down_128], dim=1)
+         img = self.conv(up_128)
+
+         input_dict['heatmap'] = skeleton_heatmap
+         input_dict['img'] = img
+         return input_dict
+
+
+ if __name__ == '__main__':
+     # the constructor needs n_parts, thick and sklr; values here follow the released celeba_wild checkpoint name (k32, t0.00075, sklr512)
+     model = Decoder({'n_parts': 32, 'thick': 0.00075, 'sklr': 512})
+     print(sum(p.numel() for p in model.parameters() if p.requires_grad))
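As a sanity check on the rasterizer, draw_lines can be called directly on a single hand-made segment; a small sketch (shapes follow the docstring above):

import torch
from models.decoder import draw_lines

# one batch, one segment from (-0.5, -0.5) to (0.5, 0.5) on the [-1, 1] grid
segment = torch.tensor([[[[-0.5, -0.5], [0.5, 0.5]]]])      # (batch, n_segments, 2 endpoints, 2 coords)
heatmap = draw_lines(segment, heatmap_size=16, thick=1e-2)
print(heatmap.shape)         # torch.Size([1, 1, 16, 16])
print(heatmap.max().item())  # ~1.0 on the segment, decaying as exp(-d^2 / thick) away from it
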
models/encoder.py ADDED
@@ -0,0 +1,106 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ import pytorch_lightning as pl
+
+
+ def gen_grid2d(grid_size: int, left_end: float=-1, right_end: float=1) -> torch.Tensor:
+     """
+     Generate a grid of size (grid_size, grid_size, 2) with coordinate values in the range [left_end, right_end]
+     """
+     x = torch.linspace(left_end, right_end, grid_size)
+     x, y = torch.meshgrid([x, x], indexing='ij')
+     grid = torch.cat((x.reshape(-1, 1), y.reshape(-1, 1)), dim=1).reshape(grid_size, grid_size, 2)
+     return grid
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int) -> None:
+         super().__init__()
+         self.conv_res = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
+             nn.BatchNorm2d(out_channels)
+         )
+
+         self.net = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
+             nn.BatchNorm2d(out_channels),
+             nn.LeakyReLU(0.2, True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
+             nn.BatchNorm2d(out_channels)
+         )
+
+         self.relu = nn.LeakyReLU(0.2, True)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         res = self.conv_res(x)
+         x = self.net(x)
+         return self.relu(x + res)
+
+
+ class TransposedBlock(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int) -> None:
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.BatchNorm2d(out_channels),
+             nn.LeakyReLU(0.2, True),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.net(x)
+         return x
+
+
+ class Detector(nn.Module):
+     def __init__(self, hyper_paras: pl.utilities.parsing.AttributeDict) -> None:
+         super().__init__()
+         self.n_parts = hyper_paras.n_parts
+         self.output_size = 32
+
+         self.conv = nn.Sequential(
+             ResBlock(3, 64),  # 64
+             ResBlock(64, 128),  # 32
+             ResBlock(128, 256),  # 16
+             ResBlock(256, 512),  # 8
+             TransposedBlock(512, 256),  # 16
+             TransposedBlock(256, 128),  # 32
+             nn.Conv2d(128, self.n_parts, kernel_size=3, padding=1),
+         )
+
+         grid = gen_grid2d(self.output_size).reshape(1, 1, self.output_size ** 2, 2)
+         self.coord = nn.Parameter(grid, requires_grad=False)
+
+     def forward(self, input_dict: dict) -> dict:
+         img = F.interpolate(input_dict['img'], size=(128, 128), mode='bilinear', align_corners=False)
+         prob_map = self.conv(img).reshape(img.shape[0], self.n_parts, -1, 1)
+         prob_map = F.softmax(prob_map, dim=2)
+         keypoints = self.coord * prob_map
+         keypoints = keypoints.sum(dim=2)
+         prob_map = prob_map.reshape(keypoints.shape[0], self.n_parts, self.output_size, self.output_size)
+         return {'keypoints': keypoints, 'prob_map': prob_map}
+
+
+ class Encoder(nn.Module):
+     def __init__(self, hyper_paras: pl.utilities.parsing.AttributeDict) -> None:
+         super().__init__()
+         self.detector = Detector(hyper_paras)
+         self.missing = 0.8  # hyper_paras.missing
+         self.block = 16  # hyper_paras.block
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                 nn.init.kaiming_normal_(m.weight, a=0.2)
+                 if m.bias is not None:
+                     m.bias.data.zero_()
+
+     def forward(self, input_dict: dict, need_masked_img: bool=False) -> dict:
+         mask_batch = self.detector(input_dict)
+         if need_masked_img:
+             damage_mask = torch.zeros(input_dict['img'].shape[0], 1, self.block, self.block, device=input_dict['img'].device).uniform_() > self.missing
+             damage_mask = F.interpolate(damage_mask.to(input_dict['img']), size=input_dict['img'].shape[-1], mode='nearest')
+             mask_batch['damaged_img'] = input_dict['img'] * damage_mask
+         return mask_batch
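The Detector head is a soft-argmax: a softmax over the 32x32 probability map, contracted against the fixed coordinate grid. A quick shape-check sketch with a dummy config (n_parts=32 matches the checkpoint used by app.py; AttributeDict is the pytorch_lightning container the constructor expects):

import torch
from pytorch_lightning.utilities.parsing import AttributeDict
from models.encoder import Encoder

encoder = Encoder(AttributeDict(n_parts=32))
out = encoder({'img': torch.randn(2, 3, 128, 128)}, need_masked_img=True)
print(out['keypoints'].shape)    # torch.Size([2, 32, 2]), coordinates in [-1, 1]
print(out['prob_map'].shape)     # torch.Size([2, 32, 32, 32])
print(out['damaged_img'].shape)  # torch.Size([2, 3, 128, 128]), image with random 16x16 blocks zeroed out
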
models/model.py ADDED
@@ -0,0 +1,58 @@
+ import importlib
+ import PIL
+ import pytorch_lightning as pl
+ import torch.utils.data
+ import wandb
+ from typing import Union
+ from torchvision import transforms
+ from utils_.loss import VGGPerceptualLoss
+ from utils_.visualization import *
+ import torch.nn.functional as F
+ import matplotlib.pyplot as plt
+
+
+ class Model(pl.LightningModule):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.save_hyperparameters()
+         self.encoder = importlib.import_module('models.' + self.hparams.encoder).Encoder(self.hparams)
+         self.decoder = importlib.import_module('models.' + self.hparams.decoder).Decoder(self.hparams)
+         self.batch_size = self.hparams.batch_size
+
+         self.vgg_loss = VGGPerceptualLoss()
+
+         self.transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize(0.5, 0.5)
+         ])
+
+     def forward(self, x: PIL.Image.Image) -> np.ndarray:
+         """
+         :param x: a PIL image
+         :return: an RGB array of the same size as x: the input at half brightness with the max-normalized edge map and keypoints overlaid
+         """
+         w, h = x.size
+         x = self.transform(x).unsqueeze(0)
+         x = x.to(self.device)
+         kp = self.encoder({'img': x})['keypoints']
+         edge_map = self.decoder.rasterize(kp, output_size=64)
+         bs = edge_map.shape[0]
+         edge_map = edge_map / (1e-8 + edge_map.reshape(bs, 1, -1).max(dim=2, keepdim=True)[0].reshape(bs, 1, 1, 1))
+         edge_map = torch.cat([edge_map] * 3, dim=1)
+         edge_map = F.interpolate(edge_map, size=(h, w), mode='bilinear', align_corners=False)
+         x = torch.clamp(edge_map + (x * 0.5 + 0.5) * 0.5, min=0, max=1)
+         x = transforms.ToPILImage()(x[0].detach().cpu())
+
+         fig = plt.figure(figsize=(1, h/w), dpi=w)
+         fig.tight_layout(pad=0)
+         plt.axis('off')
+         plt.imshow(x)
+         kp = kp[0].detach().cpu() * 0.5 + 0.5
+         kp[:, 1] *= w
+         kp[:, 0] *= h
+         plt.scatter(kp[:, 1], kp[:, 0], s=min(w/h, min(1, h/w)), marker='o')
+         ncols, nrows = fig.canvas.get_width_height()
+         fig.canvas.draw()
+         plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3)
+         plt.close(fig)
+         return plot

requirements.txt ADDED
@@ -0,0 +1,15 @@
+ numpy
+ torch==1.13.1
+ torchvision
+ matplotlib
+ scipy
+ h5py
+ pandas
+ kornia
+ wandb
+ pytorch-lightning==1.5.10
+ seaborn
+ scikit-learn
+ imageio
+ imageio-ffmpeg
+ gradio

utils_/loss.py ADDED
@@ -0,0 +1,48 @@
+ import os
+
+ import torch
+ import torch.nn.functional as F
+ import torchvision
+
+
+ class VGGPerceptualLoss(torch.nn.Module):
+     def __init__(self):
+         super(VGGPerceptualLoss, self).__init__()
+         os.environ['TORCH_HOME'] = os.path.abspath(os.getcwd())
+         blocks = [torchvision.models.vgg16().features[:4].eval(),
+                   torchvision.models.vgg16().features[4:9].eval(),
+                   torchvision.models.vgg16().features[9:16].eval(),
+                   torchvision.models.vgg16().features[16:23].eval()]
+         for bl in blocks:
+             for p in bl.parameters():
+                 p.requires_grad = False
+         self.blocks = torch.nn.ModuleList(blocks)
+
+         self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+         self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+     def forward(self, x, y):
+         x = x * 0.5 + 0.5
+         y = y * 0.5 + 0.5
+         x = (x - self.mean) / self.std
+         y = (y - self.mean) / self.std
+
+         x = F.interpolate(x, mode='bilinear', size=(224, 224), align_corners=False)
+         y = F.interpolate(y, mode='bilinear', size=(224, 224), align_corners=False)
+         perceptual_loss = 0.0
+         style_loss = 0.0
+
+         for i, block in enumerate(self.blocks):
+             x = block(x)
+             y = block(y)
+
+             perceptual_loss += torch.nn.functional.l1_loss(x, y)
+
+             # b, ch, h, w = x.shape
+             # act_x = x.reshape(x.shape[0], x.shape[1], -1)
+             # act_y = y.reshape(y.shape[0], y.shape[1], -1)
+             # gram_x = act_x @ act_x.permute(0, 2, 1) / (ch * h * w)
+             # gram_y = act_y @ act_y.permute(0, 2, 1) / (ch * h * w)
+             # style_loss += torch.nn.functional.l1_loss(gram_x, gram_y)
+
+         return perceptual_loss  # , style_loss
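Usage is a plain forward call on a pair of images in [-1, 1]. A minimal sketch with random tensors; note the constructor above builds the VGG16 blocks without downloading pretrained weights, so this only exercises shapes and the forward pass:

import torch
from utils_.loss import VGGPerceptualLoss

loss_fn = VGGPerceptualLoss()
pred = torch.rand(2, 3, 128, 128) * 2 - 1    # stand-in reconstruction in [-1, 1]
target = torch.rand(2, 3, 128, 128) * 2 - 1
loss = loss_fn(pred, target)                 # scalar: L1 distance summed over four VGG16 feature blocks
print(loss.item())
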
utils_/visualization.py ADDED
@@ -0,0 +1,98 @@
+ import matplotlib.gridspec as gridspec
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import seaborn as sns
+ import torch
+ import torchvision
+ from matplotlib import colors
+
+
+ def get_part_color(n_parts):
+     colormap = ('red', 'blue', 'yellow', 'magenta', 'green', 'indigo', 'darkorange', 'cyan', 'pink', 'yellowgreen',
+                 'rosybrown', 'coral', 'chocolate', 'bisque', 'gold', 'yellowgreen', 'aquamarine', 'deepskyblue', 'navy', 'orchid',
+                 'maroon', 'sienna', 'olive', 'lightgreen', 'teal', 'steelblue', 'slateblue', 'darkviolet', 'fuchsia', 'crimson',
+                 'honeydew', 'thistle',
+                 'red', 'blue', 'yellow', 'magenta', 'green', 'indigo', 'darkorange', 'cyan', 'pink', 'yellowgreen',
+                 'rosybrown', 'coral', 'chocolate', 'bisque', 'gold', 'yellowgreen', 'aquamarine', 'deepskyblue', 'navy', 'orchid',
+                 'maroon', 'sienna', 'olive', 'lightgreen', 'teal', 'steelblue', 'slateblue', 'darkviolet', 'fuchsia', 'crimson',
+                 'honeydew', 'thistle')[:n_parts]
+     part_color = []
+     for i in range(n_parts):
+         part_color.append(colors.to_rgb(colormap[i]))
+     part_color = np.array(part_color)
+
+     return part_color
+
+
+ def denormalize(img):
+     mean = torch.tensor((0.5, 0.5, 0.5), device=img.device).reshape(1, 3, 1, 1)
+     std = torch.tensor((0.5, 0.5, 0.5), device=img.device).reshape(1, 3, 1, 1)
+     img = img * std + mean
+     img = torch.clamp(img, min=0, max=1)
+     return img
+
+
+ def draw_matrix(mat):
+     fig = plt.figure()
+     sns.heatmap(mat, annot=True, fmt='.2f', cmap="YlGnBu")
+
+     ncols, nrows = fig.canvas.get_width_height()
+     fig.canvas.draw()
+     plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3)
+     plt.close(fig)
+     return plot
+
+
+ def draw_kp_grid(img, kp):
+     kp_color = get_part_color(kp.shape[1])
+     img = img[:64].permute(0, 2, 3, 1).detach().cpu()
+     kp = kp.detach().cpu()[:64]
+
+     fig = plt.figure(figsize=(8, 8))
+     gs = gridspec.GridSpec(8, 8)
+     gs.update(wspace=0, hspace=0)
+
+     for i, sample in enumerate(img):
+         ax = plt.subplot(gs[i])
+         plt.axis('off')
+         ax.set_xticklabels([])
+         ax.set_yticklabels([])
+         ax.imshow(sample, vmin=0, vmax=1)
+         ax.scatter(kp[i, :, 1], kp[i, :, 0], c=kp_color, s=20, marker='+')
+
+     ncols, nrows = fig.canvas.get_width_height()
+     fig.canvas.draw()
+     plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3)
+     plt.close(fig)
+     return plot
+
+
+ def draw_kp_grid_unnorm(img, kp):
+     kp_color = get_part_color(kp.shape[1])
+     img = img[:64].permute(0, 2, 3, 1).detach().cpu()
+     kp = kp.detach().cpu()[:64]
+
+     fig = plt.figure(figsize=(8, 8))
+     gs = gridspec.GridSpec(8, 8)
+     gs.update(wspace=0, hspace=0)
+
+     for i, sample in enumerate(img):
+         ax = plt.subplot(gs[i])
+         plt.axis('off')
+         ax.set_xticklabels([])
+         ax.set_yticklabels([])
+         ax.imshow(sample)
+         ax.scatter(kp[i, :, 1], kp[i, :, 0], c=kp_color, s=20, marker='+')
+
+     ncols, nrows = fig.canvas.get_width_height()
+     fig.canvas.draw()
+     plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(nrows, ncols, 3)
+     plt.close(fig)
+     return plot
+
+
+ def draw_img_grid(img):
+     img = img[:64].detach().cpu()
+     nrow = min(8, img.shape[0])
+     img = torchvision.utils.make_grid(img[:64], nrow=nrow).permute(1, 2, 0)
+     return torch.clamp(img * 255, min=0, max=255).numpy().astype(np.uint8)
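These helpers produce numpy arrays suitable for logging. A short sketch of draw_img_grid on a stand-in batch, assuming values are normalized to [-1, 1] as elsewhere in the repo:

import torch
from utils_.visualization import denormalize, draw_img_grid

imgs = torch.rand(16, 3, 128, 128) * 2 - 1   # stand-in batch in [-1, 1]
grid = draw_img_grid(denormalize(imgs))      # uint8 array (H, W, 3): 2 rows of 8 images
print(grid.shape, grid.dtype)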