Spaces:
Sleeping
Sleeping
File size: 1,754 Bytes
c20a1af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import math
import torch
import torch.nn.functional as F
def tile_features(features, num_pieces):
    """Cut a feature map into a square grid of patches stacked on dim 0.

    The map is divided into ``k x k`` tiles, where
    ``k = int(math.sqrt(num_pieces))``. Tiles are collected row by row
    (top-left first) and concatenated along dimension 0::

        +-----+-----+
        |  1  |  2  |
        +-----+-----+  ->  +-----+-----+-----+-----+
        |  3  |  4  |      |  1  |  2  |  3  |  4  |
        +-----+-----+      +-----+-----+-----+-----+

    Args:
        features: 4-D tensor. Dimension 2 is split into ``k`` row strips
            and dimension 3 into ``k`` column strips.
        num_pieces: Total number of tiles. Only the integer part of its
            square root is used as the grid side ``k``.

    Returns:
        A tensor whose dimension 0 is ``k * k`` times that of ``features``
        and whose dimensions 2 and 3 are each ``1 / k`` of the original.

    Raises:
        ValueError: If ``k`` is smaller than 1, or if the size of
            dimension 2 or 3 is not divisible by ``k``.
    """
    _, _, h, w = features.size()
    num_pieces_per_line = int(math.sqrt(num_pieces))

    if num_pieces_per_line < 1:
        raise ValueError(
            f"num_pieces must be at least 1, got {num_pieces}"
        )
    # torch.split emits an extra, smaller remainder chunk when the size is
    # not a multiple of the split size. That would produce more than k*k
    # patches (or patches of unequal size), so reject it explicitly.
    if h % num_pieces_per_line or w % num_pieces_per_line:
        raise ValueError(
            f"feature map of size {h}x{w} cannot be evenly divided into "
            f"a {num_pieces_per_line}x{num_pieces_per_line} grid"
        )

    h_per_patch = h // num_pieces_per_line
    w_per_patch = w // num_pieces_per_line

    patches = []
    for row_strip in torch.split(features, h_per_patch, dim=2):
        # Each row strip contributes its tiles left to right.
        patches.extend(torch.split(row_strip, w_per_patch, dim=3))
    return torch.cat(patches, dim=0)
def merge_features(features, num_pieces, batch_size):
    """Reassemble patches stacked on dim 0 into a square grid.

    ``features`` is split along dimension 0 into chunks of ``batch_size``.
    The first ``k * k`` chunks, with ``k = int(math.sqrt(num_pieces))``,
    are laid out row by row: each row of ``k`` chunks is joined along
    dimension 3, then the rows are joined along dimension 2::

        +-----+-----+-----+-----+      +-----+-----+
        |  1  |  2  |  3  |  4  |  ->  |  1  |  2  |
        +-----+-----+-----+-----+      +-----+-----+
                                       |  3  |  4  |
                                       +-----+-----+

    Args:
        features: Tensor holding the tiles, concatenated along dimension 0.
        num_pieces: Total number of tiles; only the integer part of its
            square root is used as the grid side ``k``.
        batch_size: Size along dimension 0 of each individual tile chunk.

    Returns:
        The tensor obtained by arranging the tiles in a ``k x k`` grid.
    """
    tiles = torch.split(features, batch_size)
    side = int(math.sqrt(num_pieces))

    # Tile (r, c) sits at flat position r * side + c in the split result.
    grid_rows = [
        torch.cat([tiles[r * side + c] for c in range(side)], dim=3)
        for r in range(side)
    ]
    return torch.cat(grid_rows, dim=2)
def puzzle_module(x, func_list, num_pieces):
    """Run a chain of callables on the tiles of ``x`` and stitch the result.

    ``x`` is tiled with ``tile_features``, every callable in ``func_list``
    is applied to the tiled tensor in order, and the output is reassembled
    with ``merge_features`` using the size of ``x`` along dimension 0 as
    the per-tile batch size.

    Args:
        x: Input tensor passed to ``tile_features``.
        func_list: Iterable of callables, each taking and returning a
            tensor. NOTE(review): presumably each must keep the tile
            ordering along dimension 0 for the merge to be meaningful;
            confirm against callers.
        num_pieces: Number of tiles, forwarded to both the tiling and the
            merging step.

    Returns:
        The merged tensor produced by ``merge_features``.
    """
    pieces = tile_features(x, num_pieces)
    for stage in func_list:
        pieces = stage(pieces)
    return merge_features(pieces, num_pieces, x.size(0))
|