Spaces:
Sleeping
Sleeping
import math | |
import torch | |
import torch.nn.functional as F | |
def tile_features(features, num_pieces):
    """Cut the spatial plane of ``features`` into a square grid of patches
    and stack the patches along dim 0.

    The last two dims (H, W) are divided into ``n = sqrt(num_pieces)`` rows
    and ``n`` columns. Patches are taken in row-major order and concatenated
    along dim 0. For ``num_pieces == 4``::

        +-----+-----+
        |  1  |  2  |
        +-----+-----+      ->      +-----+-----+-----+-----+
        |  3  |  4  |              |  1  |  2  |  3  |  4  |
        +-----+-----+              +-----+-----+-----+-----+

    Args:
        features: 4-D tensor of shape ``(B, D, H, W)``. Dims 0 and 1 are
            carried through unchanged.
        num_pieces: Total number of patches. Must be a positive perfect
            square (1, 4, 9, ...).

    Returns:
        Tensor of shape ``(num_pieces * B, D, H // n, W // n)``. Rows
        ``k*B:(k+1)*B`` hold patch ``k`` (row-major grid index) for every
        sample. ``merge_features`` performs the inverse operation.

    Raises:
        ValueError: If ``num_pieces`` is not a positive perfect square, or
            if H or W is not divisible by ``n``.
    """
    _, _, h, w = features.size()
    if num_pieces < 1:
        raise ValueError(
            f"num_pieces must be a positive perfect square, got {num_pieces}"
        )
    # round() absorbs float error in sqrt(); the product check then rejects
    # non-squares instead of silently truncating them (e.g. 5 -> 2x2 grid).
    num_pieces_per_line = round(math.sqrt(num_pieces))
    if num_pieces_per_line * num_pieces_per_line != num_pieces:
        raise ValueError(
            f"num_pieces must be a positive perfect square, got {num_pieces}"
        )
    # An uneven split makes torch.split emit extra or shorter chunks. That
    # either breaks torch.cat or yields more than num_pieces patches.
    if h % num_pieces_per_line or w % num_pieces_per_line:
        raise ValueError(
            f"spatial size ({h}, {w}) is not divisible by the "
            f"{num_pieces_per_line}x{num_pieces_per_line} patch grid"
        )
    h_per_patch = h // num_pieces_per_line
    w_per_patch = w // num_pieces_per_line

    # Outer loop over grid rows (dim 2), inner over columns (dim 3), giving
    # row-major patch order.
    patches = [
        patch
        for row in torch.split(features, h_per_patch, dim=2)
        for patch in torch.split(row, w_per_patch, dim=3)
    ]
    return torch.cat(patches, dim=0)
def merge_features(features, num_pieces, batch_size):
    """Reassemble row-major patch chunks into full feature maps.

    This is the inverse of ``tile_features``. ``features`` holds
    ``num_pieces`` consecutive chunks of ``batch_size`` rows along dim 0.
    The chunks are laid back out on an ``n x n`` grid
    (``n = sqrt(num_pieces)``). Each grid row is joined along dim 3 and the
    rows are then joined along dim 2::

        +-----+-----+-----+-----+              +-----+-----+
        |  1  |  2  |  3  |  4  |      ->      |  1  |  2  |
        +-----+-----+-----+-----+              +-----+-----+
                                               |  3  |  4  |
                                               +-----+-----+

    Args:
        features: 4-D tensor of shape ``(num_pieces * batch_size, D, h, w)``.
        num_pieces: Number of patches. Must be a positive perfect square.
        batch_size: Number of rows along dim 0 belonging to each patch.

    Returns:
        Tensor of shape ``(batch_size, D, n * h, n * w)``.

    Raises:
        ValueError: If ``num_pieces`` is not a positive perfect square, or
            if ``features.size(0) != num_pieces * batch_size``.
    """
    if num_pieces < 1:
        raise ValueError(
            f"num_pieces must be a positive perfect square, got {num_pieces}"
        )
    # round() absorbs float error in sqrt(); the product check rejects
    # non-squares that int(sqrt()) would silently truncate.
    num_pieces_per_line = round(math.sqrt(num_pieces))
    if num_pieces_per_line * num_pieces_per_line != num_pieces:
        raise ValueError(
            f"num_pieces must be a positive perfect square, got {num_pieces}"
        )
    # Without this check, a mismatch surfaces as an IndexError or, worse,
    # extra chunks are silently dropped.
    if features.size(0) != num_pieces * batch_size:
        raise ValueError(
            f"expected {num_pieces} * {batch_size} = "
            f"{num_pieces * batch_size} rows along dim 0, "
            f"got {features.size(0)}"
        )

    pieces = torch.split(features, batch_size, dim=0)
    rows = [
        torch.cat(
            pieces[r * num_pieces_per_line:(r + 1) * num_pieces_per_line],
            dim=3,
        )
        for r in range(num_pieces_per_line)
    ]
    return torch.cat(rows, dim=2)
def puzzle_module(x, func_list, num_pieces):
    """Run ``func_list`` on ``x`` patch-by-patch, then stitch the result.

    ``x`` is tiled with ``tile_features`` into ``num_pieces`` patches stacked
    along dim 0. Each callable in ``func_list`` is applied in order to that
    stacked tensor. The output is reassembled with ``merge_features`` using
    the original dim-0 size of ``x`` as the batch size.

    Args:
        x: Input tensor accepted by ``tile_features``.
        func_list: Iterable of callables, each mapping a tensor to a tensor.
        num_pieces: Number of patches to tile ``x`` into.

    Returns:
        The merged tensor produced by ``merge_features``.
    """
    batch = x.size(0)
    out = tile_features(x, num_pieces)
    for fn in func_list:
        out = fn(out)
    return merge_features(out, num_pieces, batch)