# WSSS_ResNet50/core/puzzle_utils.py
# Uploaded by kittendev in commit c20a1af ("Upload 176 files").
import math
import torch
import torch.nn.functional as F
def tile_features(features, num_pieces):
    """Cut each feature map into a square grid of tiles stacked on the batch dim.

    The spatial dims (2 and 3, i.e. H and W) are divided into a
    ``sqrt(num_pieces) x sqrt(num_pieces)`` grid. Tiles are taken in row-major
    order (top-left first) and concatenated along dim 0, so each tile
    contributes one contiguous block of ``B`` samples::

        +-----+-----+
        |  1  |  2  |
        +-----+-----+          +-----+-----+-----+-----+
        |  3  |  4  |   --->   |  1  |  2  |  3  |  4  |
        +-----+-----+          +-----+-----+-----+-----+

    Args:
        features: 4-D tensor of shape ``(B, C, H, W)``.
        num_pieces: Total number of tiles; must be a positive perfect square.

    Returns:
        Tensor of shape ``(num_pieces * B, C, H // s, W // s)`` where
        ``s = sqrt(num_pieces)``.

    Raises:
        ValueError: If ``num_pieces`` is not a positive perfect square, or if
            H or W is not divisible by ``sqrt(num_pieces)``.
    """
    _, _, height, width = features.size()

    # math.isqrt is exact, unlike int(math.sqrt(...)), and the square check
    # rejects values such as 2 that used to be silently truncated to a 1x1 grid.
    side = math.isqrt(int(num_pieces)) if num_pieces >= 1 else 0
    if side < 1 or side * side != num_pieces:
        raise ValueError(
            f"num_pieces must be a positive perfect square, got {num_pieces}"
        )

    # Without exact divisibility torch.split emits an extra remainder chunk,
    # which yields either a cryptic torch.cat error or, worse, the wrong
    # number of equally-sized tiles.
    if height % side or width % side:
        raise ValueError(
            f"spatial size {height}x{width} is not divisible by the "
            f"{side}x{side} grid implied by num_pieces={num_pieces}"
        )

    tile_h = height // side
    tile_w = width // side

    # Outer loop walks horizontal bands top to bottom; inner loop walks tiles
    # left to right, giving row-major tile order.
    patches = [
        patch
        for band in torch.split(features, tile_h, dim=2)
        for patch in torch.split(band, tile_w, dim=3)
    ]
    return torch.cat(patches, dim=0)
def merge_features(features, num_pieces, batch_size):
    """Reassemble batch-stacked tiles back into full spatial feature maps.

    This is the inverse of the tiling layout: ``features`` holds
    ``num_pieces`` consecutive blocks of ``batch_size`` samples each, in
    row-major grid order. Each grid row is concatenated along dim 3 (W) and
    the rows are then concatenated along dim 2 (H)::

        +-----+-----+-----+-----+          +-----+-----+
        |  1  |  2  |  3  |  4  |   --->   |  1  |  2  |
        +-----+-----+-----+-----+          +-----+-----+
                                           |  3  |  4  |
                                           +-----+-----+

    Args:
        features: 4-D tensor whose dim 0 has size ``num_pieces * batch_size``.
        num_pieces: Total number of tiles; must be a positive perfect square.
        batch_size: Number of samples per tile block (the original batch size).

    Returns:
        Tensor of shape ``(batch_size, C, s * h, s * w)`` where
        ``s = sqrt(num_pieces)`` and ``h, w`` are the tile's spatial sizes.

    Raises:
        ValueError: If ``num_pieces`` is not a positive perfect square, if
            ``batch_size`` is not positive, or if dim 0 of ``features`` is not
            exactly ``num_pieces * batch_size``.
    """
    side = math.isqrt(int(num_pieces)) if num_pieces >= 1 else 0
    if side < 1 or side * side != num_pieces:
        raise ValueError(
            f"num_pieces must be a positive perfect square, got {num_pieces}"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # A mismatch here used to surface as an IndexError (too few chunks) or be
    # silently ignored (extra chunks dropped).
    expected = num_pieces * batch_size
    if features.size(0) != expected:
        raise ValueError(
            f"expected dim 0 of size num_pieces * batch_size = {expected}, "
            f"got {features.size(0)}"
        )

    pieces = torch.split(features, batch_size, dim=0)

    # Slice the row-major piece sequence into grid rows, join each row along
    # the width, then stack the rows along the height.
    rows = [
        torch.cat(pieces[row * side:(row + 1) * side], dim=3)
        for row in range(side)
    ]
    return torch.cat(rows, dim=2)
def puzzle_module(x, func_list, num_pieces):
    """Run a chain of callables on the tiled version of ``x``, then re-merge.

    ``x`` is split into ``num_pieces`` tiles stacked along the batch dim by
    ``tile_features``. Each callable in ``func_list`` is then applied in order,
    each one receiving the previous one's output. The result is stitched back
    together by ``merge_features``, using ``x``'s leading dimension as the
    batch size.

    NOTE(review): every callable must presumably keep dim 0 at
    ``num_pieces * x.size(0)``; otherwise the merge cannot reassemble the
    grid. Confirm this against the callers.
    """
    out = tile_features(x, num_pieces)
    for fn in func_list:
        out = fn(out)
    batch_size = x.size(0)
    return merge_features(out, num_pieces, batch_size)