Spaces (Sleeping)

hasibzunair committed
Commit 46fdf2a · 1 Parent(s): 93775f8

initial files

Browse files
- 000001.jpg +0 -0
- 000006.jpg +0 -0
- 000009.jpg +0 -0
- app.py +97 -0
- pipeline/csra.py +55 -0
- pipeline/dataset.py +255 -0
- pipeline/losses.py +56 -0
- pipeline/models/tresnet/layers/anti_aliasing.py +60 -0
- pipeline/models/tresnet/layers/avg_pool.py +19 -0
- pipeline/models/tresnet/layers/general_layers.py +93 -0
- pipeline/models/tresnet/tresnet.py +268 -0
- pipeline/models/utils/__init__.py +2 -0
- pipeline/models/utils/factory.py +25 -0
- pipeline/resnet_csra.py +94 -0
- pipeline/timm_utils/__init__.py +4 -0
- pipeline/timm_utils/drop.py +168 -0
- pipeline/timm_utils/tuple.py +27 -0
- pipeline/timm_utils/weight_init.py +60 -0
- pipeline/vit_csra.py +303 -0
- requirements.txt +3 -0
- utils/demo_images/000001.jpg +0 -0
- utils/demo_images/000002.jpg +0 -0
- utils/demo_images/000004.jpg +0 -0
- utils/demo_images/000006.jpg +0 -0
- utils/demo_images/000007.jpg +0 -0
- utils/demo_images/000009.jpg +0 -0
- utils/evaluation/cal_PR.py +79 -0
- utils/evaluation/cal_mAP.py +56 -0
- utils/evaluation/eval.py +64 -0
- utils/evaluation/warmUpLR.py +11 -0
- utils/prepare/prepare_coco.py +67 -0
- utils/prepare/prepare_voc.py +149 -0
- utils/prepare/prepare_wider.py +58 -0
- utils/visualize.py +42 -0
000001.jpg
ADDED
000006.jpg
ADDED
000009.jpg
ADDED
app.py
ADDED
@@ -0,0 +1,97 @@
import os
import torch
import gradio as gr
import argparse
import time
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import transforms

from pipeline.resnet_csra import ResNet_CSRA
from pipeline.vit_csra import VIT_B16_224_CSRA, VIT_L16_224_CSRA, VIT_CSRA
from pipeline.dataset import DataSet
from torchvision.transforms import transforms
from utils.evaluation.eval import voc_classes, wider_classes, coco_classes, class_dict

torch.manual_seed(0)

if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

# Device
DEVICE = "cpu"
print(DEVICE)

# Make directories
os.system("mkdir ./models")

# Get model weights
if not os.path.exists("./models/msl_c_voc.pth"):
    os.system(
        "wget -O ./models/msl_c_voc.pth https://github.com/hasibzunair/msl-recognition/releases/download/v1.0-models/msl_c_voc.pth"
    )

# Load model
model = ResNet_CSRA(num_heads=1, lam=0.1, num_classes=20)
normalize = transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
model.to(DEVICE)
print("Loading weights from {}".format("./models/msl_c_voc.pth"))
model.load_state_dict(torch.load("./models/msl_c_voc.pth"))

# Inference!
def inference(img_path):
    # read image
    image = Image.open(img_path).convert("RGB")

    # image pre-process
    transforms_image = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        normalize
    ])

    image = transforms_image(image)
    image = image.unsqueeze(0)

    # Predict
    result = []
    model.eval()
    with torch.no_grad():
        image = image.to(DEVICE)
        logit = model(image).squeeze(0)
        logit = nn.Sigmoid()(logit)

        pos = torch.where(logit > 0.5)[0].cpu().numpy()
        for k in pos:
            result.append(str(class_dict["voc07"][k]))
    return result


# Define ins outs placeholders
inputs = gr.inputs.Image(type="filepath", label="Input Image")

# Define style
title = "Learning to Recognize Occluded and Small Objects with Partial Inputs"
description = "TBA."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1512.03385' target='_blank'>Learning to Recognize Occluded and Small Objects with Partial Inputs</a> | <a href='https://github.com/hasibzunair/msl-recognition' target='_blank'>Github Repo</a></p>"

voc_classes = ("aeroplane", "bicycle", "bird", "boat", "bottle",
               "bus", "car", "cat", "chair", "cow", "diningtable",
               "dog", "horse", "motorbike", "person", "pottedplant",
               "sheep", "sofa", "train", "tvmonitor")

# Run inference
gr.Interface(inference,
             inputs,
             outputs="text",
             examples=["demo_images/000001.jpg", "demo_images/000006.jpg", "demo_images/000009.jpg"],
             title=title,
             description=description,
             article=article,
             analytics_enabled=False).launch()
pipeline/csra.py
ADDED
@@ -0,0 +1,55 @@
import torch
import torch.nn as nn


class CSRA(nn.Module):  # one basic block
    def __init__(self, input_dim, num_classes, T, lam):
        super(CSRA, self).__init__()
        self.T = T      # temperature
        self.lam = lam  # Lambda
        self.head = nn.Conv2d(input_dim, num_classes, 1, bias=False)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        # x (B d H W)
        # normalize classifier
        # score (B C HxW)
        score = self.head(x) / torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0, 1)
        score = score.flatten(2)
        base_logit = torch.mean(score, dim=2)

        if self.T == 99:  # max-pooling
            att_logit = torch.max(score, dim=2)[0]
        else:
            score_soft = self.softmax(score * self.T)
            # https://github.com/Kevinz-code/CSRA/issues/5
            att_logit = torch.sum(score * score_soft, dim=2)

        return base_logit + self.lam * att_logit


class MHA(nn.Module):  # multi-head attention
    temp_settings = {  # softmax temperature settings
        1: [1],
        2: [1, 99],
        4: [1, 2, 4, 99],
        6: [1, 2, 3, 4, 5, 99],
        8: [1, 2, 3, 4, 5, 6, 7, 99]
    }

    def __init__(self, num_heads, lam, input_dim, num_classes):
        super(MHA, self).__init__()
        self.temp_list = self.temp_settings[num_heads]
        self.multi_head = nn.ModuleList([
            CSRA(input_dim, num_classes, self.temp_list[i], lam)
            for i in range(num_heads)
        ])

    def forward(self, x):
        logit = 0.
        for head in self.multi_head:
            logit += head(x)
        return logit
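
A minimal usage sketch (not part of the commit): how the MHA/CSRA head consumes a backbone feature map. The batch size, spatial size, and class count below are illustrative assumptions matching a ResNet-101 backbone with 20 VOC classes.

import torch
from pipeline.csra import MHA

# one CSRA head with temperature 1 and lambda 0.1
head = MHA(num_heads=1, lam=0.1, input_dim=2048, num_classes=20)
feat = torch.randn(2, 2048, 14, 14)   # (B, d, H, W) feature map from the backbone
logits = head(feat)                   # (B, num_classes) = (2, 20), one logit per class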
pipeline/dataset.py
ADDED
@@ -0,0 +1,255 @@
import json
import glob
import random

from torch.utils.data import Dataset
from PIL import Image
from torchvision.transforms import transforms
import torch
import numpy as np

try:
    from torchvision.transforms import InterpolationMode

    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


# modify for transformation for vit
# modfify wider crop-person images


###### Base data loader ######
class DataSet(Dataset):
    def __init__(
        self,
        ann_files,
        augs,
        img_size,
        dataset,
    ):
        self.dataset = dataset
        self.ann_files = ann_files
        self.augment = self.augs_function(augs, img_size)
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])]
            # In this paper, we normalize the image data to [0, 1]
            # You can also use the so called 'ImageNet' Normalization method
        )
        self.anns = []
        self.load_anns()
        print(self.augment)

        # in wider dataset we use vit models
        # so transformation has been changed
        if self.dataset == "wider":
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                ]
            )

    def augs_function(self, augs, img_size):
        t = []
        if "randomflip" in augs:
            t.append(transforms.RandomHorizontalFlip())
        if "ColorJitter" in augs:
            t.append(
                transforms.ColorJitter(
                    brightness=0.5, contrast=0.5, saturation=0.5, hue=0
                )
            )
        if "resizedcrop" in augs:
            t.append(transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)))
        if "RandAugment" in augs:
            t.append(RandAugment())

        t.append(transforms.Resize((img_size, img_size)))

        return transforms.Compose(t)

    def load_anns(self):
        self.anns = []
        for ann_file in self.ann_files:
            json_data = json.load(open(ann_file, "r"))
            self.anns += json_data

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, idx):
        idx = idx % len(self)
        ann = self.anns[idx]
        img = Image.open(ann["img_path"]).convert("RGB")

        if self.dataset == "wider":
            x, y, w, h = ann["bbox"]
            img_area = img.crop([x, y, x + w, y + h])
            img_area = self.augment(img_area)
            img_area = self.transform(img_area)
            message = {
                "img_path": ann["img_path"],
                "target": torch.Tensor(ann["target"]),
                "img": img_area,
            }
        else:  # voc and coco
            img = self.augment(img)
            img = self.transform(img)
            message = {
                "img_path": ann["img_path"],
                "target": torch.Tensor(ann["target"]),
                "img": img,
            }

        return message
        # finally, if we use dataloader to get the data, we will get
        # {
        #     "img_path": list, # length = batch_size
        #     "target": Tensor, # shape: batch_size * num_classes
        #     "img": Tensor, # shape: batch_size * 3 * 224 * 224
        # }


def preprocess_scribble(img, img_size):
    transform = transforms.Compose(
        [
            transforms.Resize(img_size, BICUBIC),
            transforms.CenterCrop(img_size),
            # _convert_image_to_rgb,
            transforms.ToTensor(),
        ]
    )
    return transform(img)


class DataSetMaskSup(Dataset):
    """
    Data loader with scribbles.
    """
    def __init__(
        self,
        ann_files,
        augs,
        img_size,
        dataset,
    ):
        self.dataset = dataset
        self.ann_files = ann_files
        self.img_size = img_size
        self.augment = self.augs_function(augs, img_size)
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])]
            # In this paper, we normalize the image data to [0, 1]
            # You can also use the so called 'ImageNet' Normalization method
        )
        self.anns = []
        self.load_anns()
        print(self.augment)

        # scribbles
        self._scribbles_folder = "./datasets/SCRIBBLES"

        # Type of masks to use, this is hardcoded since we find that high masks
        # work better in MSL. See paper for details.

        # for low masks
        # self._scribbles = sorted(glob.glob(self._scribbles_folder + "/*.png"))[
        #     :1000
        # ]

        # for high masks
        self._scribbles = sorted(glob.glob(self._scribbles_folder + "/*.png"))[::-1][
            :1000
        ]

        # in wider dataset we use vit models
        # so transformation has been changed
        if self.dataset == "wider":
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                ]
            )

    def augs_function(self, augs, img_size):
        t = []
        if "randomflip" in augs:
            t.append(transforms.RandomHorizontalFlip())
        if "ColorJitter" in augs:
            t.append(
                transforms.ColorJitter(
                    brightness=0.5, contrast=0.5, saturation=0.5, hue=0
                )
            )
        if "resizedcrop" in augs:
            t.append(transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)))
        if "RandAugment" in augs:
            t.append(RandAugment())

        t.append(transforms.Resize((img_size, img_size)))

        return transforms.Compose(t)

    def load_anns(self):
        self.anns = []
        for ann_file in self.ann_files:
            json_data = json.load(open(ann_file, "r"))
            self.anns += json_data

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, idx):
        idx = idx % len(self)
        ann = self.anns[idx]
        img = Image.open(ann["img_path"]).convert("RGB")

        # get scribble
        scribble_path = self._scribbles[
            random.randint(0, 950)
        ]
        scribble = Image.open(scribble_path).convert('P')
        scribble = preprocess_scribble(scribble, self.img_size)

        scribble_t = (scribble > 0).float()  # threshold to [0,1]
        inv_scribble = (torch.max(scribble_t) - scribble_t)  # inverted scribble

        if self.dataset == "wider":
            x, y, w, h = ann["bbox"]
            img_area = img.crop([x, y, x + w, y + h])
            img_area = self.augment(img_area)
            img_area = self.transform(img_area)

            # masked image
            masked_image = img_area * inv_scribble
            message = {
                "img_path": ann["img_path"],
                "target": torch.Tensor(ann["target"]),
                "img": img_area,
                "masked_img": masked_image,
                # "scribble": inv_scribble,
            }
        else:  # voc and coco
            img = self.augment(img)
            img = self.transform(img)
            # masked image
            masked_image = img * inv_scribble
            message = {
                "img_path": ann["img_path"],
                "target": torch.Tensor(ann["target"]),
                "img": img,
                "masked_img": masked_image,
                # "scribble": inv_scribble,
            }

        return message
        # finally, if we use dataloader to get the data, we will get
        # {
        #     "img_path": list, # length = batch_size
        #     "target": Tensor, # shape: batch_size * num_classes
        #     "img": Tensor, # shape: batch_size * 3 * 224 * 224
        #     "masked_img": Tensor, # shape: batch_size * 3 * 224 * 224
        # }
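
A minimal sketch of constructing the base loader (not part of the commit). The annotation path is hypothetical; it assumes a JSON list of records with "img_path" and "target" keys, which is the format load_anns and __getitem__ expect above.

from torch.utils.data import DataLoader
from pipeline.dataset import DataSet

train_set = DataSet(
    ann_files=["data/voc07/trainval.json"],   # hypothetical annotation file
    augs=["randomflip", "resizedcrop"],       # keys checked in augs_function
    img_size=448,
    dataset="voc07",
)
loader = DataLoader(train_set, batch_size=8, shuffle=True)
batch = next(iter(loader))   # dict with "img_path" (list), "target" and "img" tensors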
pipeline/losses.py
ADDED
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn

"""ASL taken from https://github.com/Alibaba-MIIL/ASL"""

# Usage
# global criterion_asl
# criterion_asl = AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05, disable_torch_grad_focal_loss=True)
# loss3 = criterion_asl(pred1, pred2)

class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True):
        super(AsymmetricLoss, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

    def forward(self, x, y):
        """"
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        # Calculating Probabilities
        x_sigmoid = torch.sigmoid(x)
        xs_pos = x_sigmoid
        xs_neg = 1 - x_sigmoid

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            xs_neg = (xs_neg + self.clip).clamp(max=1)

        # Basic CE calculation
        los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
        los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
        loss = los_pos + los_neg

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            pt0 = xs_pos * y
            pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
            pt = pt0 + pt1
            one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
            one_sided_w = torch.pow(1 - pt, one_sided_gamma)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            loss *= one_sided_w

        return -loss.sum()
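
A minimal sketch of the usage pattern mentioned in the comment block above (not part of the commit); the batch size and class count are illustrative assumptions.

import torch
from pipeline.losses import AsymmetricLoss

criterion_asl = AsymmetricLoss(gamma_neg=4, gamma_pos=0, clip=0.05, disable_torch_grad_focal_loss=True)
logits = torch.randn(4, 20)                      # raw multi-label logits
targets = torch.randint(0, 2, (4, 20)).float()   # binarized target vector per sample
loss = criterion_asl(logits, targets)            # scalar loss, summed over the batch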
pipeline/models/tresnet/layers/anti_aliasing.py
ADDED
@@ -0,0 +1,60 @@
import torch
import torch.nn.parallel
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class AntiAliasDownsampleLayer(nn.Module):
    def __init__(self, remove_model_jit: bool = False, filt_size: int = 3, stride: int = 2,
                 channels: int = 0):
        super(AntiAliasDownsampleLayer, self).__init__()
        if not remove_model_jit:
            self.op = DownsampleJIT(filt_size, stride, channels)
        else:
            self.op = Downsample(filt_size, stride, channels)

    def forward(self, x):
        return self.op(x)


@torch.jit.script
class DownsampleJIT(object):
    def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0):
        self.stride = stride
        self.filt_size = filt_size
        self.channels = channels

        assert self.filt_size == 3
        assert stride == 2
        a = torch.tensor([1., 2., 1.])

        filt = (a[:, None] * a[None, :]).clone().detach()
        filt = filt / torch.sum(filt)
        self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half()

    def __call__(self, input: torch.Tensor):
        if input.dtype != self.filt.dtype:
            self.filt = self.filt.float()
        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
        return F.conv2d(input_pad, self.filt, stride=2, padding=0, groups=input.shape[1])


class Downsample(nn.Module):
    def __init__(self, filt_size=3, stride=2, channels=None):
        super(Downsample, self).__init__()
        self.filt_size = filt_size
        self.stride = stride
        self.channels = channels

        assert self.filt_size == 3
        a = torch.tensor([1., 2., 1.])

        filt = (a[:, None] * a[None, :]).clone().detach()
        filt = filt / torch.sum(filt)
        self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))

    def forward(self, input):
        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
        return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])
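
A minimal sketch of the blur-pool layer (not part of the commit). Passing remove_model_jit=True selects the plain nn.Module path, which also runs on CPU; the JIT variant builds its filter directly on the GPU in half precision. Shapes are illustrative assumptions.

import torch
from pipeline.models.tresnet.layers.anti_aliasing import AntiAliasDownsampleLayer

blur_pool = AntiAliasDownsampleLayer(remove_model_jit=True, filt_size=3, stride=2, channels=64)
x = torch.randn(1, 64, 56, 56)
y = blur_pool(x)   # (1, 64, 28, 28): low-pass filtered, then strided by 2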
pipeline/models/tresnet/layers/avg_pool.py
ADDED
@@ -0,0 +1,19 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class FastAvgPool2d(nn.Module):
    def __init__(self, flatten=False):
        super(FastAvgPool2d, self).__init__()
        self.flatten = flatten

    def forward(self, x):
        if self.flatten:
            in_size = x.size()
            return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
        else:
            return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
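
A minimal sketch (not part of the commit), with illustrative shapes: the flatten=True form is the global pooling used before the classifier head.

import torch
from pipeline.models.tresnet.layers.avg_pool import FastAvgPool2d

pool = FastAvgPool2d(flatten=True)
x = torch.randn(2, 2048, 7, 7)
v = pool(x)   # (2, 2048): spatial mean per channel, already flattened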
pipeline/models/tresnet/layers/general_layers.py
ADDED
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from pipeline.models.tresnet.layers.avg_pool import FastAvgPool2d


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class DepthToSpace(nn.Module):

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W)  # (N, bs, bs, C//bs^2, H, W)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)
        x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs)  # (N, C//bs^2, H * bs, W * bs)
        return x


class SpaceToDepthModule(nn.Module):
    def __init__(self, remove_model_jit=False):
        super().__init__()
        if not remove_model_jit:
            self.op = SpaceToDepthJit()
        else:
            self.op = SpaceToDepth()

    def forward(self, x):
        return self.op(x)


class SpaceToDepth(nn.Module):
    def __init__(self, block_size=4):
        super().__init__()
        assert block_size == 4
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs)  # (N, C, H//bs, bs, W//bs, bs)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
        x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)
        return x


@torch.jit.script
class SpaceToDepthJit(object):
    def __call__(self, x: torch.Tensor):
        # assuming hard-coded that block_size==4 for acceleration
        N, C, H, W = x.size()
        x = x.view(N, C, H // 4, 4, W // 4, 4)  # (N, C, H//bs, bs, W//bs, bs)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
        x = x.view(N, C * 16, H // 4, W // 4)  # (N, C*bs^2, H//bs, W//bs)
        return x


class hard_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(hard_sigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            return x.add_(3.).clamp_(0., 6.).div_(6.)
        else:
            return F.relu6(x + 3.) / 6.


class SEModule(nn.Module):

    def __init__(self, channels, reduction_channels, inplace=True):
        super(SEModule, self).__init__()
        self.avg_pool = FastAvgPool2d()
        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True)
        self.relu = nn.ReLU(inplace=inplace)
        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True)
        # self.activation = hard_sigmoid(inplace=inplace)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        x_se = self.avg_pool(x)
        x_se2 = self.fc1(x_se)
        x_se2 = self.relu(x_se2)
        x_se = self.fc2(x_se2)
        x_se = self.activation(x_se)
        return x * x_se
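
A minimal shape sketch of the two main building blocks above (not part of the commit); tensor sizes are illustrative assumptions.

import torch
from pipeline.models.tresnet.layers.general_layers import SpaceToDepth, SEModule

s2d = SpaceToDepth(block_size=4)
x = torch.randn(1, 3, 224, 224)
y = s2d(x)                      # (1, 48, 56, 56): each 4x4 spatial block folded into channels

se = SEModule(channels=64, reduction_channels=16)
feat = torch.randn(1, 64, 56, 56)
out = se(feat)                  # same shape, channels re-weighted by the squeeze-excite gate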
pipeline/models/tresnet/tresnet.py
ADDED
@@ -0,0 +1,268 @@
import torch
import torch.nn as nn
from torch.nn import Module as Module
from collections import OrderedDict
from pipeline.models.tresnet.layers.anti_aliasing import AntiAliasDownsampleLayer
from .layers.avg_pool import FastAvgPool2d
from .layers.general_layers import SEModule, SpaceToDepthModule
from inplace_abn import InPlaceABN, ABN
import torch.nn.functional as F

def InplacABN_to_ABN(module: nn.Module) -> nn.Module:
    # convert all InplaceABN layer to bit-accurate ABN layers.
    if isinstance(module, InPlaceABN):
        module_new = ABN(module.num_features, activation=module.activation,
                         activation_param=module.activation_param)
        for key in module.state_dict():
            module_new.state_dict()[key].copy_(module.state_dict()[key])
        module_new.training = module.training
        module_new.weight.data = module_new.weight.abs() + module_new.eps
        return module_new
    for name, child in reversed(module._modules.items()):
        new_child = InplacABN_to_ABN(child)
        if new_child != child:
            module._modules[name] = new_child
    return module

class bottleneck_head(nn.Module):
    def __init__(self, num_features, num_classes, bottleneck_features=200):
        super(bottleneck_head, self).__init__()
        self.embedding_generator = nn.ModuleList()
        self.embedding_generator.append(nn.Linear(num_features, bottleneck_features))
        self.embedding_generator = nn.Sequential(*self.embedding_generator)
        self.FC = nn.Linear(bottleneck_features, num_classes)

    def forward(self, x):
        self.embedding = self.embedding_generator(x)
        logits = self.FC(self.embedding)
        return logits


def conv2d(ni, nf, stride):
    return nn.Sequential(
        nn.Conv2d(ni, nf, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(nf),
        nn.ReLU(inplace=True)
    )


def conv2d_ABN(ni, nf, stride, activation="leaky_relu", kernel_size=3, activation_param=1e-2, groups=1):
    return nn.Sequential(
        nn.Conv2d(ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups,
                  bias=False),
        InPlaceABN(num_features=nf, activation=activation, activation_param=activation_param)
    )


class BasicBlock(Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None):
        super(BasicBlock, self).__init__()
        if stride == 1:
            self.conv1 = conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3)
        else:
            if anti_alias_layer is None:
                self.conv1 = conv2d_ABN(inplanes, planes, stride=2, activation_param=1e-3)
            else:
                self.conv1 = nn.Sequential(conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3),
                                           anti_alias_layer(channels=planes, filt_size=3, stride=2))

        self.conv2 = conv2d_ABN(planes, planes, stride=1, activation="identity")
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        reduce_layer_planes = max(planes * self.expansion // 4, 64)
        self.se = SEModule(planes * self.expansion, reduce_layer_planes) if use_se else None

    def forward(self, x):
        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x

        out = self.conv1(x)
        out = self.conv2(out)

        if self.se is not None: out = self.se(out)

        out += residual

        out = self.relu(out)

        return out


class Bottleneck(Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv2d_ABN(inplanes, planes, kernel_size=1, stride=1, activation="leaky_relu",
                                activation_param=1e-3)
        if stride == 1:
            self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=1, activation="leaky_relu",
                                    activation_param=1e-3)
        else:
            if anti_alias_layer is None:
                self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=2, activation="leaky_relu",
                                        activation_param=1e-3)
            else:
                self.conv2 = nn.Sequential(conv2d_ABN(planes, planes, kernel_size=3, stride=1,
                                                      activation="leaky_relu", activation_param=1e-3),
                                           anti_alias_layer(channels=planes, filt_size=3, stride=2))

        self.conv3 = conv2d_ABN(planes, planes * self.expansion, kernel_size=1, stride=1,
                                activation="identity")

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        reduce_layer_planes = max(planes * self.expansion // 8, 64)
        self.se = SEModule(planes, reduce_layer_planes) if use_se else None

    def forward(self, x):
        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        if self.se is not None: out = self.se(out)

        out = self.conv3(out)
        out = out + residual  # no inplace
        out = self.relu(out)

        return out


class TResNet(Module):

    def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0,
                 do_bottleneck_head=False, bottleneck_features=512):
        super(TResNet, self).__init__()

        # Loss function
        self.loss_func = F.binary_cross_entropy_with_logits

        # JIT layers
        space_to_depth = SpaceToDepthModule()
        anti_alias_layer = AntiAliasDownsampleLayer
        global_pool_layer = FastAvgPool2d(flatten=True)

        # TResnet stages
        self.inplanes = int(64 * width_factor)
        self.planes = int(64 * width_factor)
        conv1 = conv2d_ABN(in_chans * 16, self.planes, stride=1, kernel_size=3)
        layer1 = self._make_layer(BasicBlock, self.planes, layers[0], stride=1, use_se=True,
                                  anti_alias_layer=anti_alias_layer)  # 56x56
        layer2 = self._make_layer(BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True,
                                  anti_alias_layer=anti_alias_layer)  # 28x28
        layer3 = self._make_layer(Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True,
                                  anti_alias_layer=anti_alias_layer)  # 14x14
        layer4 = self._make_layer(Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False,
                                  anti_alias_layer=anti_alias_layer)  # 7x7

        # body
        self.body = nn.Sequential(OrderedDict([
            ('SpaceToDepth', space_to_depth),
            ('conv1', conv1),
            ('layer1', layer1),
            ('layer2', layer2),
            ('layer3', layer3),
            ('layer4', layer4)]))

        # head
        self.embeddings = []
        self.global_pool = nn.Sequential(OrderedDict([('global_pool_layer', global_pool_layer)]))
        self.num_features = (self.planes * 8) * Bottleneck.expansion
        if do_bottleneck_head:
            fc = bottleneck_head(self.num_features, num_classes,
                                 bottleneck_features=bottleneck_features)
        else:
            fc = nn.Linear(self.num_features, num_classes)

        self.head = nn.Sequential(OrderedDict([('fc', fc)]))

        # model initilization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InPlaceABN):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # residual connections special initialization
        for m in self.modules():
            if isinstance(m, BasicBlock):
                m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight))  # BN to zero
            if isinstance(m, Bottleneck):
                m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight))  # BN to zero
            if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01)

    def _make_layer(self, block, planes, blocks, stride=1, use_se=True, anti_alias_layer=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            layers = []
            if stride == 2:
                # avg pooling before 1x1 conv
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
            layers += [conv2d_ABN(self.inplanes, planes * block.expansion, kernel_size=1, stride=1,
                                  activation="identity")]
            downsample = nn.Sequential(*layers)

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=use_se,
                            anti_alias_layer=anti_alias_layer))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks): layers.append(
            block(self.inplanes, planes, use_se=use_se, anti_alias_layer=anti_alias_layer))
        return nn.Sequential(*layers)

    def forward_train(self, x, target):
        x = self.body(x)
        self.embeddings = self.global_pool(x)
        logits = self.head(self.embeddings)
        loss = self.loss_func(logits, target, reduction="mean")
        return logits, loss

    def forward_test(self, x):
        x = self.body(x)
        self.embeddings = self.global_pool(x)
        logits = self.head(self.embeddings)
        return logits

    def forward(self, x, target=None):
        if target is not None:
            return self.forward_train(x, target)
        else:
            return self.forward_test(x)


def TResnetM(num_classes):
    """Constructs a medium TResnet model.
    """
    in_chans = 3
    model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans)
    return model


def TResnetL(num_classes):
    """Constructs a large TResnet model.
    """
    in_chans = 3
    do_bottleneck_head = False
    model = TResNet(layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2,
                    do_bottleneck_head=do_bottleneck_head)
    return model


def TResnetXL(num_classes):
    """Constructs a xlarge TResnet model.
    """
    in_chans = 3
    model = TResNet(layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3)
    return model
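
A minimal construction sketch (not part of the commit). It assumes the inplace_abn package is installed and a CUDA device is available, since the default anti-aliased downsampling builds its blur filter directly on the GPU; class count and input size are illustrative.

import torch
from pipeline.models.tresnet.tresnet import TResnetM

model = TResnetM(num_classes=80).cuda().eval()
x = torch.randn(1, 3, 224, 224, device="cuda")
with torch.no_grad():
    logits = model(x)   # no target given, so the forward_test path runs: (1, 80) logits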
pipeline/models/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .factory import create_model
__all__ = ['create_model']
pipeline/models/utils/factory.py
ADDED
@@ -0,0 +1,25 @@
import logging

logger = logging.getLogger(__name__)

from ..tresnet import TResnetM, TResnetL, TResnetXL


def create_model(args):
    """Create a model
    """
    model_params = {'args': args, 'num_classes': args.num_classes}
    args = model_params['args']
    args.model_name = args.model_name.lower()

    if args.model_name=='tresnet_m':
        model = TResnetM(model_params)
    elif args.model_name=='tresnet_l':
        model = TResnetL(model_params)
    elif args.model_name=='tresnet_xl':
        model = TResnetXL(model_params)
    else:
        print("model: {} not found !!".format(args.model_name))
        exit(-1)

    return model
pipeline/resnet_csra.py
ADDED
@@ -0,0 +1,94 @@
from torchvision.models import ResNet
from torchvision.models.resnet import Bottleneck, BasicBlock
from .csra import CSRA, MHA
import torch.utils.model_zoo as model_zoo
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F


model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}


class ResNet_CSRA(ResNet):
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
    }

    def __init__(
        self, num_heads, lam, num_classes, depth=101, input_dim=2048, cutmix=None
    ):
        self.block, self.layers = self.arch_settings[depth]
        self.depth = depth
        super(ResNet_CSRA, self).__init__(self.block, self.layers)
        self.init_weights(pretrained=True, cutmix=cutmix)

        self.classifier = MHA(num_heads, lam, input_dim, num_classes)
        self.loss_func = F.binary_cross_entropy_with_logits
        # todo
        # criterion = nn.BCEWithLogitsLoss() # loss combines a Sigmoid layer and the BCELoss in one single class

    def backbone(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    def forward_train(self, x, target):
        x = self.backbone(x)
        logit = self.classifier(x)
        loss = self.loss_func(logit, target, reduction="mean")
        return logit, loss

    def forward_test(self, x):
        x = self.backbone(x)
        x = self.classifier(x)
        return x

    def forward(self, x, target=None):
        if target is not None:
            return self.forward_train(x, target)
        else:
            return self.forward_test(x)

    def init_weights(self, pretrained=True, cutmix=None):
        if cutmix is not None:
            print("backbone params inited by CutMix pretrained model")
            state_dict = torch.load(cutmix)
        elif pretrained:
            print("backbone params inited by Pytorch official model")
            model_url = model_urls["resnet{}".format(self.depth)]
            state_dict = model_zoo.load_url(model_url)

        model_dict = self.state_dict()
        try:
            pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
            self.load_state_dict(pretrained_dict)
        except:
            logger = logging.getLogger()
            logger.info(
                "the keys in pretrained model is not equal to the keys in the ResNet you choose, trying to fix..."
            )
            state_dict = self._keysFix(model_dict, state_dict)
            self.load_state_dict(state_dict)

        # remove the original 1000-class fc
        self.fc = nn.Sequential()
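
A minimal sketch of the two forward paths (not part of the commit). Constructing the model downloads the torchvision ResNet-101 weights on first use; batch size, image size, and class count are illustrative assumptions.

import torch
from pipeline.resnet_csra import ResNet_CSRA

model = ResNet_CSRA(num_heads=1, lam=0.1, num_classes=20)
images = torch.randn(2, 3, 448, 448)

model.eval()
with torch.no_grad():
    logits = model(images)                    # no target: forward_test, (2, 20) logits

targets = torch.randint(0, 2, (2, 20)).float()
logits, loss = model(images, targets)         # with target: forward_train returns (logits, loss)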
pipeline/timm_utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .tuple import to_ntuple, to_2tuple, to_3tuple, to_4tuple
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .weight_init import trunc_normal_
pipeline/timm_utils/drop.py
ADDED
@@ -0,0 +1,168 @@
""" DropBlock, DropPath

PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.

Papers:
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)

Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)

Code:
DropBlock impl inspired by two Tensorflow impl that I liked:
 - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
 - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def drop_block_2d(
        x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
        with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
    runs with success, but needs further validation and possibly optimization for lower runtime impact.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # seed_drop_rate, the gamma parameter
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
            (W - block_size + 1) * (H - block_size + 1))

    # Forces the block to be inside the feature map.
    w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
                  ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)

    if batchwise:
        # one mask for whole batch, quite a bit faster
        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
    else:
        uniform_noise = torch.rand_like(x)
    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
    block_mask = -F.max_pool2d(
        -block_mask,
        kernel_size=clipped_block_size,  # block_size,
        stride=1,
        padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
        else:
            x = x * block_mask + normal_noise * (1 - block_mask)
    else:
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x


def drop_block_fast_2d(
        x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
        gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
    block mask at edges.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
            (W - block_size + 1) * (H - block_size + 1))

    if batchwise:
        # one mask for whole batch, quite a bit faster
        block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
    else:
        # mask per batch element
        block_mask = torch.rand_like(x) < gamma
    block_mask = F.max_pool2d(
        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(1. - block_mask).add_(normal_noise * block_mask)
        else:
            x = x * (1. - block_mask) + normal_noise * block_mask
    else:
        block_mask = 1 - block_mask
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x


class DropBlock2d(nn.Module):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
    """
    def __init__(self,
                 drop_prob=0.1,
                 block_size=7,
                 gamma_scale=1.0,
                 with_noise=False,
                 inplace=False,
                 batchwise=False,
                 fast=True):
        super(DropBlock2d, self).__init__()
        self.drop_prob = drop_prob
        self.gamma_scale = gamma_scale
        self.block_size = block_size
        self.with_noise = with_noise
        self.inplace = inplace
        self.batchwise = batchwise
        self.fast = fast  # FIXME finish comparisons of fast vs not

    def forward(self, x):
        if not self.training or not self.drop_prob:
            return x
        if self.fast:
            return drop_block_fast_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
        else:
            return drop_block_2d(
                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
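
A minimal sketch of stochastic depth in use (not part of the commit); the token-shaped input is an illustrative assumption matching a ViT block.

import torch
from pipeline.timm_utils.drop import DropPath

drop = DropPath(drop_prob=0.1)
drop.train()                    # drop_path is a no-op in eval mode
x = torch.randn(8, 197, 768)    # (batch, tokens, embed_dim)
out = drop(x)                   # each sample kept with prob 0.9 and rescaled by 1/0.9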
pipeline/timm_utils/tuple.py
ADDED
@@ -0,0 +1,27 @@
""" Layer/Module Helpers

Hacked together by / Copyright 2020 Ross Wightman
"""
from itertools import repeat
from torch._six import container_abcs


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
pipeline/timm_utils/weight_init.py
ADDED
@@ -0,0 +1,60 @@
import torch
import math
import warnings


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
pipeline/vit_csra.py
ADDED
@@ -0,0 +1,303 @@
""" Vision Transformer (ViT) in PyTorch

A PyTorch implement of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929

The official jax code is released and available at https://github.com/google-research/vision_transformer

Status/TODO:
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.

Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert

Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from functools import partial
from .timm_utils import DropPath, to_2tuple, trunc_normal_
from .csra import MHA, CSRA


default_cfgs = {
    'vit_base_patch16_224': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
    'vit_large_patch16_224': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth'
}


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads  # 64
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # qkv (3, B, 12, N, C/12)
        # q (B, 12, N, C/12)
        # k (B, 12, N, C/12)
        # v (B, 12, N, C/12)
        # attn (B, 12, N, N)
        # x (B, 12, N, C/12)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """
    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
                # map for all networks, the feature metadata has reliable channel and stride info, but using
                # stride to calc feature dim requires info about padding of each stage that isn't captured.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        x = self.backbone(x)[-1]
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class VIT_CSRA(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, cls_num_heads=1, cls_num_cls=80, lam=0.3):
        super().__init__()
        self.add_w = 0.
        self.normalize = False
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.HW = int(math.sqrt(num_patches))

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
        # self.repr = nn.Linear(embed_dim, representation_size)
        # self.repr_act = nn.Tanh()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

        # We add our MHA (CSRA) beside the original ViT structure below
        self.head = nn.Sequential()  # delete original classifier
        self.classifier = MHA(input_dim=embed_dim, num_heads=cls_num_heads, num_classes=cls_num_cls, lam=lam)

        self.loss_func = F.binary_cross_entropy_with_logits

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def backbone(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        # (B, 1+HW, C)
        # we use all the features to form a tensor like B C H W
        x = x[:, 1:]
        b, hw, c = x.shape
        x = x.transpose(1, 2)
        x = x.reshape(b, c, self.HW, self.HW)

        return x

    def forward_train(self, x, target):
        x = self.backbone(x)
        logit = self.classifier(x)
        loss = self.loss_func(logit, target, reduction="mean")
        return logit, loss

    def forward_test(self, x):
        x = self.backbone(x)
        x = self.classifier(x)
        return x

    def forward(self, x, target=None):
        if target is not None:
            return self.forward_train(x, target)
        else:
            return self.forward_test(x)


def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v
    return out_dict


def VIT_B16_224_CSRA(pretrained=True, cls_num_heads=1, cls_num_cls=80, lam=0.3):
    model = VIT_CSRA(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), cls_num_heads=cls_num_heads, cls_num_cls=cls_num_cls, lam=lam)

    model_url = default_cfgs['vit_base_patch16_224']
    if pretrained:
        state_dict = model_zoo.load_url(model_url)
        model.load_state_dict(state_dict, strict=False)
    return model


def VIT_L16_224_CSRA(pretrained=True, cls_num_heads=1, cls_num_cls=80, lam=0.3):
    model = VIT_CSRA(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), cls_num_heads=cls_num_heads, cls_num_cls=cls_num_cls, lam=lam)

    model_url = default_cfgs['vit_large_patch16_224']
    if pretrained:
        state_dict = model_zoo.load_url(model_url)
        model.load_state_dict(state_dict, strict=False)
        # load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model
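A minimal forward-pass sketch for this file (not part of the commit; pretrained=False avoids the weight download, and the MHA head from pipeline/csra.py is assumed to return one logit per class):

import torch
from pipeline.vit_csra import VIT_B16_224_CSRA

model = VIT_B16_224_CSRA(pretrained=False, cls_num_heads=1, cls_num_cls=20, lam=0.3)
model.eval()

x = torch.randn(2, 3, 224, 224)            # PatchEmbed asserts 224x224 inputs
with torch.no_grad():
    logits = model(x)                      # no target -> forward_test, shape (2, 20)
print(logits.shape)

# passing a multi-hot float target instead returns (logits, BCE-with-logits loss)
target = torch.randint(0, 2, (2, 20)).float()
logits, loss = model(x, target)
print(loss.item())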
requirements.txt
ADDED
@@ -0,0 +1,3 @@
torch
torchvision
Pillow
utils/demo_images/000001.jpg
ADDED
utils/demo_images/000002.jpg
ADDED
utils/demo_images/000004.jpg
ADDED
utils/demo_images/000006.jpg
ADDED
utils/demo_images/000007.jpg
ADDED
utils/demo_images/000009.jpg
ADDED
utils/evaluation/cal_PR.py
ADDED
@@ -0,0 +1,79 @@
import json
import numpy as np


def json_metric(score_json, target_json, num_classes, types):
    assert len(score_json) == len(target_json)
    scores = np.zeros((len(score_json), num_classes))
    targets = np.zeros((len(target_json), num_classes))
    for index in range(len(score_json)):
        scores[index] = score_json[index]["scores"]
        targets[index] = target_json[index]["target"]

    return metric(scores, targets, types)


def json_metric_top3(score_json, target_json, num_classes, types):
    assert len(score_json) == len(target_json)
    scores = np.zeros((len(score_json), num_classes))
    targets = np.zeros((len(target_json), num_classes))
    for index in range(len(score_json)):
        tmp = np.array(score_json[index]['scores'])
        idx = np.argsort(-tmp)
        idx_after_3 = idx[3:]
        tmp[idx_after_3] = 0.

        scores[index] = tmp
        # scores[index] = score_json[index]["scores"]
        targets[index] = target_json[index]["target"]

    return metric(scores, targets, types)


def metric(scores, targets, types):
    """
    :param scores: the scores predicted by the model
    :param targets: the ground-truth labels
    :return: OP, OR, OF1, CP, CR, CF1
    Precision per class: TP / (TP + FP), i.e. TP / total predictions
    Recall per class: TP / total ground truth
    """
    num, num_class = scores.shape
    gt_num = np.zeros(num_class)
    tp_num = np.zeros(num_class)
    predict_num = np.zeros(num_class)

    for index in range(num_class):
        score = scores[:, index]
        target = targets[:, index]
        if types == 'wider':
            tmp = np.where(target == 99)[0]
            # score[tmp] = 0
            target[tmp] = 0

        if types == 'voc07':
            tmp = np.where(target != 0)[0]
            score = score[tmp]
            target = target[tmp]
            neg_id = np.where(target == -1)[0]
            target[neg_id] = 0

        gt_num[index] = np.sum(target == 1)
        predict_num[index] = np.sum(score >= 0.5)
        tp_num[index] = np.sum(target * (score >= 0.5))

    predict_num[predict_num == 0] = 1  # avoid dividing by 0
    OP = np.sum(tp_num) / np.sum(predict_num)
    OR = np.sum(tp_num) / np.sum(gt_num)
    OF1 = (2 * OP * OR) / (OP + OR)

    # print(tp_num / predict_num)
    # print(tp_num / gt_num)
    CP = np.sum(tp_num / predict_num) / num_class
    CR = np.sum(tp_num / gt_num) / num_class
    CF1 = (2 * CP * CR) / (CP + CR)

    return OP, OR, OF1, CP, CR, CF1
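A toy check of metric() (a sketch, not part of the commit): two samples, three classes, scores thresholded at 0.5 as in the code above.

import numpy as np
from utils.evaluation.cal_PR import metric

scores = np.array([[0.9, 0.2, 0.7],
                   [0.4, 0.8, 0.1]])
targets = np.array([[1., 0., 1.],
                    [0., 1., 0.]])

# every prediction >= 0.5 is a true positive here, so all six metrics are 1.0
OP, OR, OF1, CP, CR, CF1 = metric(scores, targets, types="coco")
print(OP, OR, OF1, CP, CR, CF1)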
utils/evaluation/cal_mAP.py
ADDED
@@ -0,0 +1,56 @@
import os
import numpy as np
import torch
import json


def json_map(cls_id, pred_json, ann_json, types):
    assert len(ann_json) == len(pred_json)
    num = len(ann_json)
    predict = np.zeros((num), dtype=np.float64)
    target = np.zeros((num), dtype=np.float64)

    for i in range(num):
        predict[i] = pred_json[i]["scores"][cls_id]
        target[i] = ann_json[i]["target"][cls_id]

    if types == 'wider':
        tmp = np.where(target != 99)[0]
        predict = predict[tmp]
        target = target[tmp]
        num = len(tmp)

    if types == 'voc07':
        tmp = np.where(target != 0)[0]
        predict = predict[tmp]
        target = target[tmp]
        neg_id = np.where(target == -1)[0]
        target[neg_id] = 0
        num = len(tmp)

    tmp = np.argsort(-predict)
    target = target[tmp]
    predict = predict[tmp]

    pre, obj = 0, 0
    for i in range(num):
        if target[i] == 1:
            obj += 1.0
            pre += obj / (i + 1)
    pre /= obj
    return pre
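A toy check of json_map() (a sketch, not part of the commit): three ranked samples for a single class give AP = (1/1 + 2/3) / 2.

from utils.evaluation.cal_mAP import json_map

pred_json = [{"scores": [0.9]}, {"scores": [0.6]}, {"scores": [0.2]}]
ann_json = [{"target": [1]}, {"target": [0]}, {"target": [1]}]

ap = json_map(0, pred_json, ann_json, types="coco")
print(ap)   # 0.8333...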
utils/evaluation/eval.py
ADDED
@@ -0,0 +1,64 @@
import argparse
import torch
import numpy as np
import json
from tqdm import tqdm
from .cal_mAP import json_map
from .cal_PR import json_metric, metric, json_metric_top3


voc_classes = ("aeroplane", "bicycle", "bird", "boat", "bottle",
               "bus", "car", "cat", "chair", "cow", "diningtable",
               "dog", "horse", "motorbike", "person", "pottedplant",
               "sheep", "sofa", "train", "tvmonitor")
coco_classes = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
                'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
                'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
                'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
                'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
                'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
                'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
                'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
                'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
                'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
                'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
                'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
                'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush')
wider_classes = (
    "Male", "longHair", "sunglass", "Hat", "Tshiirt", "longSleeve", "formal",
    "shorts", "jeans", "longPants", "skirt", "faceMask", "logo", "stripe")

class_dict = {
    "voc07": voc_classes,
    "coco": coco_classes,
    "wider": wider_classes,
}


def evaluation(result, types, ann_path):
    print("Evaluation")
    classes = class_dict[types]
    aps = np.zeros(len(classes), dtype=np.float64)

    ann_json = json.load(open(ann_path, "r"))
    pred_json = result

    for i, _ in enumerate(tqdm(classes)):
        ap = json_map(i, pred_json, ann_json, types)
        aps[i] = ap
    OP, OR, OF1, CP, CR, CF1 = json_metric(pred_json, ann_json, len(classes), types)
    print("mAP: {:4f}".format(np.mean(aps)))
    print("CP: {:4f}, CR: {:4f}, CF1: {:4f}".format(CP, CR, CF1))
    print("OP: {:4f}, OR: {:4f}, OF1: {:4f}".format(OP, OR, OF1))


# I added it here: same WarmUpLR scheduler as in utils/evaluation/warmUpLR.py
class WarmUpLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]
utils/evaluation/warmUpLR.py
ADDED
@@ -0,0 +1,11 @@
import torch


class WarmUpLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]
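Usage sketch for the scheduler (not part of the commit): the learning rate ramps linearly from roughly 0 up to the base value over total_iters calls to step().

import torch
from utils.evaluation.warmUpLR import WarmUpLR

net = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
warmup = WarmUpLR(optimizer, total_iters=100)

for it in range(100):
    optimizer.step()        # normally preceded by loss.backward()
    warmup.step()           # lr = 0.1 * (it + 1) / 100
print(optimizer.param_groups[0]["lr"])   # 0.1 after the warm-up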
utils/prepare/prepare_coco.py
ADDED
@@ -0,0 +1,67 @@
import os
import json
import argparse
import numpy as np
from pycocotools.coco import COCO


def make_data(data_path=None, tag="train"):
    annFile = os.path.join(data_path, "annotations/instances_{}2014.json".format(tag))
    coco = COCO(annFile)

    img_id = coco.getImgIds()
    cat_id = coco.getCatIds()
    img_id = list(sorted(img_id))
    cat_trans = {}
    for i in range(len(cat_id)):
        cat_trans[cat_id[i]] = i

    message = []

    for i in img_id:
        data = {}
        target = [0] * 80
        path = ""
        img_info = coco.loadImgs(i)[0]
        ann_ids = coco.getAnnIds(imgIds=i)
        anns = coco.loadAnns(ann_ids)
        if len(anns) == 0:
            continue
        else:
            for j in range(len(anns)):
                cls = anns[j]['category_id']
                cls = cat_trans[cls]
                target[cls] = 1
            path = img_info['file_name']
            data['target'] = target
            data['img_path'] = os.path.join(os.path.join(data_path, "images/{}2014/".format(tag)), path)
            message.append(data)

    with open('data/coco/{}_coco2014.json'.format(tag), 'w') as f:
        json.dump(message, f)


# The final json files are train_coco2014.json & val_coco2014.json,
# which have the following format:
# [item1, item2, item3, ......,]
# item1 = {
#     "target":
#     "img_path":
# }
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Usage: --data_path /your/dataset/path/COCO2014
    parser.add_argument("--data_path", default="Dataset/COCO2014/", type=str, help="The absolute path of COCO2014")
    args = parser.parse_args()

    if not os.path.exists("data/coco"):
        os.makedirs("data/coco")

    make_data(data_path=args.data_path, tag="train")
    make_data(data_path=args.data_path, tag="val")

    print("COCO data ready!")
    print("data/coco/train_coco2014.json, data/coco/val_coco2014.json")
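For reference, one entry of the resulting JSON looks like the sketch below (the file name is illustrative, not taken from the commit):

import json

target = [0] * 80
target[0] = 1        # 'person' present
target[16] = 1       # 'dog' present

item = {
    "target": target,
    "img_path": "Dataset/COCO2014/images/train2014/COCO_train2014_000000000009.jpg",
}
print(json.dumps(item)[:100])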
utils/prepare/prepare_voc.py
ADDED
@@ -0,0 +1,149 @@
import os
import json
import argparse
import numpy as np
import xml.dom.minidom as XML


voc_cls_id = {"aeroplane": 0, "bicycle": 1, "bird": 2, "boat": 3, "bottle": 4,
              "bus": 5, "car": 6, "cat": 7, "chair": 8, "cow": 9,
              "diningtable": 10, "dog": 11, "horse": 12, "motorbike": 13, "person": 14,
              "pottedplant": 15, "sheep": 16, "sofa": 17, "train": 18, "tvmonitor": 19}


def get_label(data_path):
    print("generating labels for VOC07 dataset")
    xml_paths = os.path.join(data_path, "VOC2007/Annotations/")
    save_dir = "data/voc07/labels"

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    for i in os.listdir(xml_paths):
        if not i.endswith(".xml"):
            continue
        s_name = i.split('.')[0] + ".txt"
        s_dir = os.path.join(save_dir, s_name)
        xml_path = os.path.join(xml_paths, i)
        DomTree = XML.parse(xml_path)
        Root = DomTree.documentElement

        obj_all = Root.getElementsByTagName("object")
        leng = len(obj_all)
        cls = []
        difi_tag = []
        for obj in obj_all:
            # get the classes
            obj_name = obj.getElementsByTagName('name')[0]
            one_class = obj_name.childNodes[0].data
            cls.append(voc_cls_id[one_class])

            difficult = obj.getElementsByTagName('difficult')[0]
            difi_tag.append(difficult.childNodes[0].data)

        for i, c in enumerate(cls):
            with open(s_dir, "a") as f:
                f.writelines("%s,%s\n" % (c, difi_tag[i]))


def transdifi(data_path):
    print("generating final json file for VOC07 dataset")
    label_dir = "data/voc07/labels/"
    img_dir = os.path.join(data_path, "VOC2007/JPEGImages/")

    # get trainval test id
    id_dirs = os.path.join(data_path, "VOC2007/ImageSets/Main/")
    f_train = open(os.path.join(id_dirs, "train.txt"), "r").readlines()
    f_val = open(os.path.join(id_dirs, "val.txt"), "r").readlines()
    f_trainval = f_train + f_val
    f_test = open(os.path.join(id_dirs, "test.txt"), "r")

    trainval_id = np.sort([int(line.strip()) for line in f_trainval]).tolist()
    test_id = [int(line.strip()) for line in f_test]
    trainval_data = []
    test_data = []

    # ternary label
    # -1 means negative
    # 0 means difficult
    # +1 means positive

    # binary label
    # 0 means negative
    # +1 means positive

    # we use binary labels in our implementation

    for item in sorted(os.listdir(label_dir)):
        with open(os.path.join(label_dir, item), "r") as f:

            target = np.array([-1] * 20)
            classes = []
            diffi_tag = []

            for line in f.readlines():
                cls, tag = map(int, line.strip().split(','))
                classes.append(cls)
                diffi_tag.append(tag)

            classes = np.array(classes)
            diffi_tag = np.array(diffi_tag)
            for i in range(20):
                if i in classes:
                    i_index = np.where(classes == i)[0]
                    if len(i_index) == 1:
                        target[i] = 1 - diffi_tag[i_index]
                    else:
                        if len(i_index) == sum(diffi_tag[i_index]):
                            target[i] = 0
                        else:
                            target[i] = 1
                else:
                    continue
            img_path = os.path.join(img_dir, item.split('.')[0] + ".jpg")

            if int(item.split('.')[0]) in trainval_id:
                target[target == -1] = 0  # from ternary to binary by treating difficult as negatives
                data = {"target": target.tolist(), "img_path": img_path}
                trainval_data.append(data)
            if int(item.split('.')[0]) in test_id:
                data = {"target": target.tolist(), "img_path": img_path}
                test_data.append(data)

    json.dump(trainval_data, open("data/voc07/trainval_voc07.json", "w"))
    json.dump(test_data, open("data/voc07/test_voc07.json", "w"))
    print("VOC07 data preparing finished!")
    print("data/voc07/trainval_voc07.json data/voc07/test_voc07.json")

    # remove label cache
    for item in os.listdir(label_dir):
        os.remove(os.path.join(label_dir, item))
    os.rmdir(label_dir)


# We treat difficult classes in trainval_data as negative, while they are ignored in test_data.
# The ignoring operation is done automatically during evaluation (testing).
# The final json files are trainval_voc07.json & test_voc07.json,
# which have the following format:
# [item1, item2, item3, ......,]
# item1 = {
#     "target":
#     "img_path":
# }

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Usage: --data_path /your/dataset/path/VOCdevkit
    parser.add_argument("--data_path", default="Dataset/VOCdevkit/", type=str, help="The absolute path of VOCdevkit")
    args = parser.parse_args()

    if not os.path.exists("data/voc07"):
        os.makedirs("data/voc07")

    if 'VOCdevkit' not in args.data_path:
        print("WARNING: please include 'VOCdevkit' str in your args.data_path")
        # exit()

    get_label(args.data_path)
    transdifi(args.data_path)
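The ternary-to-binary rule above in isolation (a sketch, not part of the commit): per class, a single difficult object gives 0, a non-difficult object gives 1, and absent classes stay -1 until the trainval conversion sets them to 0.

import numpy as np

# lines of one intermediate label file, "class_id,difficult_flag"
lines = ["11,0", "14,1", "14,1"]      # one easy dog, two difficult persons
classes = np.array([int(l.split(',')[0]) for l in lines])
diffi = np.array([int(l.split(',')[1]) for l in lines])

target = np.array([-1] * 20)
for i in range(20):
    if i in classes:
        idx = np.where(classes == i)[0]
        if len(idx) == 1:
            target[i] = 1 - diffi[idx][0]
        else:
            target[i] = 0 if len(idx) == sum(diffi[idx]) else 1

print(target[11], target[14])   # 1 (positive) and 0 (difficult)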
utils/prepare/prepare_wider.py
ADDED
@@ -0,0 +1,58 @@
import os
import json
import random
import argparse


def make_wider(tag, value, data_path):
    img_path = os.path.join(data_path, "Image")
    ann_path = os.path.join(data_path, "Annotations")
    ann_file = os.path.join(ann_path, "wider_attribute_{}.json".format(tag))

    data = json.load(open(ann_file, "r"))

    final = []
    image_list = data['images']
    for image in image_list:
        for person in image["targets"]:  # iterate over each person
            tmp = {}
            tmp['img_path'] = os.path.join(img_path, image['file_name'])
            tmp['bbox'] = person['bbox']
            attr = person["attribute"]
            for i, item in enumerate(attr):
                if item == -1:
                    attr[i] = 0
                if item == 0:
                    attr[i] = value  # pad un-specified samples
                if item == 1:
                    attr[i] = 1
            tmp["target"] = attr
            final.append(tmp)

    json.dump(final, open("data/wider/{}_wider.json".format(tag), "w"))
    print("data/wider/{}_wider.json".format(tag))


# The final json files have the following format:
# [item1, item2, item3, ......,]
# item1 = {
#     "target":
#     "img_path":
# }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", default="Dataset/WIDER_ATTRIBUTE", type=str)
    args = parser.parse_args()

    if not os.path.exists("data/wider"):
        os.makedirs("data/wider")

    # 0 (zero) means negative, we treat un-specified attributes as negative in the trainval set
    make_wider(tag='trainval', value=0, data_path=args.data_path)

    # 99 means we ignore un-specified attributes in the test set, following previous work
    # the number 99 can be properly identified when evaluating mAP
    make_wider(tag='test', value=99, data_path=args.data_path)
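The attribute relabeling above, spelled out (a sketch, not part of the commit): -1 (negative) becomes 0, 0 (unspecified) becomes the pad value, 1 stays 1.

attr = [-1, 0, 1, 0, -1]
value = 99    # 99 for the test split, 0 for trainval
relabeled = [0 if a == -1 else (value if a == 0 else 1) for a in attr]
print(relabeled)    # [0, 99, 1, 99, 0]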
utils/visualize.py
ADDED
@@ -0,0 +1,42 @@
from PIL import Image
import json
import torch
from torchvision import transforms
import cv2
import numpy as np
import os
import torch.nn as nn


def show_cam_on_img(img, mask, img_path_save):
    heat_map = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heat_map = np.float32(heat_map) / 255

    cam = heat_map + np.float32(img)
    cam = cam / np.max(cam)
    cv2.imwrite(img_path_save, np.uint8(255 * cam))


img_path_read = ""
img_path_save = ""


def main():
    img = cv2.imread(img_path_read, flags=1)

    img = np.float32(cv2.resize(img, (224, 224))) / 255

    # cam_all is the score tensor of shape (B, C, H, W), similar to y_raw in our Figure 1;
    # it must be computed beforehand (e.g. saved from the model) and is not defined in this script.
    # cls_idx specifies the i-th class out of the C classes.
    # visualize the 0-th class heatmap
    cls_idx = 0
    cam = cam_all[cls_idx]

    # cam = nn.ReLU()(cam)
    cam = cam / torch.max(cam)

    cam = cv2.resize(np.array(cam), (224, 224))
    show_cam_on_img(img, cam, img_path_save)
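A self-contained sketch of show_cam_on_img (not part of the commit): it overlays a synthetic, normalized score map on one of the demo images; the output path is a placeholder.

import cv2
import numpy as np
from utils.visualize import show_cam_on_img

img = cv2.imread("utils/demo_images/000001.jpg", flags=1)
img = np.float32(cv2.resize(img, (224, 224))) / 255

# stand-in for a class score map normalized to [0, 1]
yy, xx = np.mgrid[0:224, 0:224]
mask = (xx + yy) / float(xx.max() + yy.max())

show_cam_on_img(img, mask, "cam_overlay.jpg")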