Spaces:
Sleeping
Sleeping
# Adapted from Open-Sora-Plan
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------
import functools

import torch
from einops import rearrange
def video_to_image(func):
    """Let a per-image method also accept 5-D video tensors.

    The decorated method is called as ``func(self, x, *args, **kwargs)``.

    When ``x`` has 5 dimensions it is treated as ``(b, c, t, h, w)``. The
    time axis is folded into the batch, ``func`` runs on the resulting
    ``(b*t, c, h, w)`` tensor, and the output is unfolded back to
    ``(b, c', t, h', w')``. Any other input is passed to ``func`` as-is.

    Args:
        func: Method with signature ``(self, x, *args, **kwargs)``.

    Returns:
        The wrapped method, with ``func``'s name and docstring preserved.
    """

    @functools.wraps(func)
    def wrapper(self, x, *args, **kwargs):
        if x.dim() != 5:
            # Not a video tensor: the wrapped method must still run on it
            # rather than being silently skipped.
            return func(self, x, *args, **kwargs)
        t = x.shape[2]
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = func(self, x, *args, **kwargs)
        return rearrange(x, "(b t) c h w -> b c t h w", t=t)

    return wrapper
def nonlinearity(x):
    """Apply the SiLU (swish) activation ``x * sigmoid(x)`` elementwise."""
    gate = torch.sigmoid(x)
    return x * gate
def cast_tuple(t, length=1):
    """Return ``t`` if it is already a tuple, else ``t`` repeated ``length`` times.

    Tuples are returned unchanged regardless of ``length``. Any other value,
    including lists, becomes a single repeated element.
    """
    if isinstance(t, tuple):
        return t
    return (t,) * length
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """Move dimension ``src_dim`` of ``x`` to position ``dest_dim``.

    The relative order of all other dimensions is preserved. For example,
    for ``x`` of shape ``(a, b, c, d)``, ``shift_dim(x, 1, -1)`` has shape
    ``(a, c, d, b)``.

    Args:
        x: Tensor to permute. It must support ``.shape`` and ``.permute``,
            and also ``.contiguous`` when ``make_contiguous`` is true.
        src_dim: Dimension to move; negative values count from the end.
        dest_dim: Index the moved dimension occupies in the result; negative
            values count from the end.
        make_contiguous: If true, call ``.contiguous()`` on the permuted
            result before returning it.

    Returns:
        The permuted tensor.

    Raises:
        IndexError: If ``src_dim`` or ``dest_dim`` is out of range for ``x``.
    """
    n_dims = len(x.shape)
    src = src_dim + n_dims if src_dim < 0 else src_dim
    dest = dest_dim + n_dims if dest_dim < 0 else dest_dim
    # Explicit check instead of ``assert``: under ``python -O`` the assert is
    # stripped, and an out-of-range negative index would then silently pick
    # the wrong axis through Python's negative list indexing.
    if not (0 <= src < n_dims and 0 <= dest < n_dims):
        raise IndexError(
            f"shift_dim: src_dim={src_dim}, dest_dim={dest_dim} out of range "
            f"for a tensor with {n_dims} dimensions"
        )
    # The remaining axes stay in order; the moved axis is inserted at ``dest``.
    permutation = [d for d in range(n_dims) if d != src]
    permutation.insert(dest, src)
    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x