|
import torch |
|
from einops import rearrange |
|
|
|
def video_to_image(func):
    """Let a method written for image batches also accept video batches.

    The decorated method must have the signature ``(self, x, *args, **kwargs)``.

    When ``x`` has 5 dimensions it is read as ``(b, c, t, h, w)``:

    1. The time axis is folded into the batch axis, giving ``(b * t, c, h, w)``.
    2. ``func`` is applied once to that batch of frames.
    3. Its output ``(b * t, c', h', w')`` is unfolded back to
       ``(b, c', t, h', w')``.

    The fold is equivalent to ``einops.rearrange(x, "b c t h w -> (b t) c h w")``,
    and the unfold to ``"(b t) c h w -> b c t h w"``.

    For any other rank, ``func`` is called on ``x`` directly.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, x, *args, **kwargs):
        if x.dim() != 5:
            # Non-video input: apply the op as-is instead of silently skipping it.
            return func(self, x, *args, **kwargs)

        b, c, t, h, w = x.shape
        # (b, c, t, h, w) -> (b, t, c, h, w) -> (b * t, c, h, w); frames of
        # sample 0 come first, in time order.
        frames = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        out = func(self, frames, *args, **kwargs)
        # Infer the leading size from t (like einops does), then move time
        # back behind the (possibly changed) channel axis.
        return out.reshape(-1, t, *out.shape[1:]).permute(0, 2, 1, 3, 4)

    return wrapper
|
|
|
def nonlinearity(x):
    """Apply the swish (SiLU) activation elementwise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
|
|
|
def cast_tuple(t, length=1):
    """Coerce ``t`` into a tuple.

    A tuple is returned unchanged, whatever ``length`` is. Any other value
    (lists included) is wrapped and repeated ``length`` times.
    """
    if isinstance(t, tuple):
        return t
    repeated = (t,) * length
    return repeated
|
|
|
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """Move dimension ``src_dim`` of ``x`` so it ends up at index ``dest_dim``.

    All other dimensions keep their relative order, which is the same
    single-dimension behaviour as ``torch.movedim``. Negative indices count
    from the end.

    Args:
        x: Tensor-like object exposing ``.shape`` and ``.permute``.
        src_dim: Dimension to move.
        dest_dim: Index the moved dimension occupies in the result.
        make_contiguous: If True, call ``.contiguous()`` on the permuted
            result. Otherwise the permuted view is returned as-is.

    Returns:
        The permuted tensor.

    Raises:
        IndexError: If ``src_dim`` or ``dest_dim`` is out of range for ``x``.
    """
    n_dims = len(x.shape)

    def _normalize(dim, name):
        # Resolve a possibly-negative index and validate it explicitly.
        # An ``assert`` would be stripped under ``python -O``; then e.g.
        # src_dim=-(n_dims + 1) would silently produce a wrong permutation.
        resolved = dim + n_dims if dim < 0 else dim
        if not 0 <= resolved < n_dims:
            raise IndexError(
                f"{name}={dim} out of range for a tensor with {n_dims} dims"
            )
        return resolved

    src_dim = _normalize(src_dim, "src_dim")
    dest_dim = _normalize(dest_dim, "dest_dim")

    # Remaining dims in their original order, with src_dim spliced in at dest_dim.
    permutation = [d for d in range(n_dims) if d != src_dim]
    permutation.insert(dest_dim, src_dim)

    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x