"""Small helper utilities: conditional defaults, tensor item gathering, and
local/global patch rearrangement for (b, c, h, w) image tensors."""
from torch.nn import Identity
from einops import rearrange

def exist(item):
    """Return True when item is not None."""
    return item is not None

def set_default_item(condition, item_1, item_2=None):
    """Return item_1 when the condition holds, otherwise item_2 (None by default)."""
    if condition:
        return item_1
    else:
        return item_2

def set_default_layer(condition, layer_1, args_1=(), kwargs_1=None, layer_2=Identity, args_2=(), kwargs_2=None):
    """Build layer_1(*args_1, **kwargs_1) when the condition holds, else layer_2(*args_2, **kwargs_2)."""
    if condition:
        return layer_1(*args_1, **(kwargs_1 or {}))
    else:
        return layer_2(*args_2, **(kwargs_2 or {}))

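# Hedged usage sketch (not from the original file): set_default_layer lets a model
# toggle an optional sub-layer in a single expression. The GroupNorm arguments
# below are hypothetical placeholders.
#
#   norm = set_default_layer(use_norm, torch.nn.GroupNorm, (32, 256))
#
# With use_norm truthy this builds torch.nn.GroupNorm(32, 256); otherwise it
# falls back to the default layer_2 and returns Identity().
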
def get_tensor_items(x, pos, broadcast_shape):
    """Gather x[pos] and reshape it to (batch, 1, ..., 1) so it broadcasts over broadcast_shape.

    The lookup is done on CPU and the result is moved back to pos's device.
    """
    device = pos.device
    bs = pos.shape[0]
    ndims = len(broadcast_shape[1:])
    x = x.cpu()[pos.cpu()]
    return x.reshape(bs, *((1,) * ndims)).to(device)

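# Illustrative sketch of one assumed use, common in diffusion-style code: pick one
# schedule coefficient per sample and broadcast it over an image batch. The
# schedule, timesteps, and shapes are hypothetical.
#
#   alphas = torch.linspace(1.0, 0.1, steps=1000)       # hypothetical schedule
#   t = torch.randint(0, 1000, (8,))                    # one timestep per sample
#   coef = get_tensor_items(alphas, t, (8, 3, 64, 64))  # -> shape (8, 1, 1, 1)
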
def local_patching(x, height, width, group_size):
    """Split a (b, c, height, width) image into non-overlapping group_size x group_size patches.

    Returns (b, num_patches, group_size**2, c). With group_size == 0 the image is
    instead flattened to a plain (b, height*width, c) token sequence.
    """
    if group_size > 0:
        x = rearrange(
            x, 'b c (h g1) (w g2) -> b (h w) (g1 g2) c',
            h=height // group_size, w=width // group_size, g1=group_size, g2=group_size
        )
    else:
        x = rearrange(x, 'b c h w -> b (h w) c', h=height, w=width)
    return x

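# Shape sketch (illustrative numbers, not from the original file): for x of shape
# (b, c, 32, 32) and group_size=4, local_patching returns (b, 64, 16, c):
# 64 = (32/4) * (32/4) patches, each holding 16 = 4 * 4 pixels.
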
def local_merge(x, height, width, group_size):
    """Inverse of local_patching: reassemble patches into a (b, c, height, width) image."""
    if group_size > 0:
        x = rearrange(
            x, 'b (h w) (g1 g2) c -> b c (h g1) (w g2)',
            h=height // group_size, w=width // group_size, g1=group_size, g2=group_size
        )
    else:
        x = rearrange(x, 'b (h w) c -> b c h w', h=height, w=width)
    return x

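# Because both rearrange patterns are pure permutations of the same elements,
# local_merge inverts local_patching exactly (illustrative check):
#
#   x = torch.randn(2, 8, 32, 32)
#   assert torch.equal(local_merge(local_patching(x, 32, 32, 4), 32, 32, 4), x)
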
def global_patching(x, height, width, group_size):
    """Group pixels by their position within cells of a group_size x group_size grid.

    Implemented as local_patching with patch size height // group_size followed by
    a transpose, so each group gathers one pixel from every cell (a dilated,
    image-wide pattern rather than a local neighborhood).
    """
    x = local_patching(x, height, width, height // group_size)
    x = x.transpose(-2, -3)
    return x

def global_merge(x, height, width, group_size):
    """Inverse of global_patching: undo the transpose, then merge the cells back."""
    x = x.transpose(-2, -3)
    x = local_merge(x, height, width, height // group_size)
    return x

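if __name__ == "__main__":
    # Hedged smoke test (not part of the original file): with x of shape
    # (2, 8, 32, 32) and group_size=4, local_patching is applied with patch size
    # 32 // 4 = 8, giving (2, 16, 64, 8); the transpose then swaps the patch and
    # within-patch axes, so global_patching yields (2, 64, 16, 8) and
    # global_merge restores the original tensor exactly.
    import torch

    x = torch.randn(2, 8, 32, 32)
    g = global_patching(x, 32, 32, 4)
    assert g.shape == (2, 64, 16, 8)
    assert torch.equal(global_merge(g, 32, 32, 4), x)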