zxl
first commit
07c6a04
raw
history blame
3.4 kB
# Adapted from Open-Sora-Plan
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------
import warnings
from typing import Tuple, Union

import torch
import torch.nn as nn

from .ops import cast_tuple, video_to_image
class Conv2d(nn.Conv2d):
    """``torch.nn.Conv2d`` whose ``forward`` is wrapped by ``video_to_image``.

    The constructor takes exactly the arguments of ``nn.Conv2d`` (with a default
    ``kernel_size`` of 3) and forwards them unchanged.
    NOTE(review): ``video_to_image`` comes from ``.ops``; it presumably folds a
    (b, c, t, h, w) video into a batch of 2-D images around the convolution —
    confirm against ``.ops``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int]] = 3,
        stride: Union[int, Tuple[int]] = 1,
        padding: Union[str, int, Tuple[int]] = 0,
        dilation: Union[int, Tuple[int]] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
    ) -> None:
        # Pass everything by name so the mapping onto nn.Conv2d is explicit.
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )

    @video_to_image
    def forward(self, x):
        """Run the plain 2-D convolution; input/output reshaping is left to the decorator."""
        return super().forward(x)
class CausalConv3d(nn.Module):
    """3-D convolution that is causal along the time (third) axis.

    The wrapped ``nn.Conv3d`` never pads time (its temporal padding is forced to 0).
    Instead ``forward`` prepends ``time_kernel_size - 1`` copies of the first frame,
    so the temporal window never reaches past the current frame.

    Args:
        chan_in: number of input channels.
        chan_out: number of output channels.
        kernel_size: int or (t, h, w) kernel, normalized via ``cast_tuple(kernel_size, 3)``.
        init_method: ``"random"`` keeps ``nn.Conv3d``'s default weight init;
            ``"avg"`` makes each output channel the temporal mean of the same input
            channel (needs ``chan_in == chan_out`` and a (t, 1, 1) kernel);
            ``"zero"`` zeroes all weights. Any other value behaves like ``"random"``.
            The bias is zeroed in every case.
        **kwargs: only ``stride`` and ``padding`` are used (each normalized with
            ``cast_tuple(..., 3)``). Any other keyword is ignored and a warning is emitted.
    """

    def __init__(
        self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
    ):
        super().__init__()
        self.kernel_size = cast_tuple(kernel_size, 3)
        self.time_kernel_size = self.kernel_size[0]
        self.chan_in = chan_in
        self.chan_out = chan_out
        stride = kwargs.pop("stride", 1)
        padding = kwargs.pop("padding", 0)
        padding = list(cast_tuple(padding, 3))
        # Time is padded causally in forward(); the conv itself must not pad it again.
        padding[0] = 0
        stride = cast_tuple(stride, 3)
        if kwargs:
            # Other nn.Conv3d options (e.g. bias, dilation) are not forwarded; make
            # that visible instead of dropping them silently.
            warnings.warn(
                f"CausalConv3d ignores unsupported keyword arguments: {sorted(kwargs)}",
                stacklevel=2,
            )
        self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
        self._init_weights(init_method)

    def _init_weights(self, init_method):
        """Initialize ``self.conv`` in place according to ``init_method``.

        ``"avg"``: every temporal tap is ``I / t``, where ``I`` is the channel identity
        and ``t`` is the temporal kernel size. Works for any ``t``; the bit pattern for
        ``t == 3`` matches the former hard-coded three ``I / 3`` taps.
        ``"zero"``: all weights are zero. Anything else keeps the default weight init.
        The bias, when present, is zeroed for every ``init_method``.
        """
        weight = self.conv.weight
        with torch.no_grad():
            if init_method == "avg":
                assert self.kernel_size[1] == 1 and self.kernel_size[2] == 1, "only support temporal up/down sample"
                assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
                eye = torch.eye(self.chan_in, dtype=weight.dtype, device=weight.device)
                # weight is (out, in, t, 1, 1); broadcast I / t across all t temporal taps.
                weight[:, :, :, 0, 0] = (eye / self.time_kernel_size).unsqueeze(-1)
            elif init_method == "zero":
                weight.zero_()
            if self.conv.bias is not None:
                self.conv.bias.zero_()

    def forward(self, x):
        """Causally pad the time axis, then apply the wrapped ``nn.Conv3d``.

        ``x`` is assumed to be (b, c, t, h, w), as the original shape comment states.
        The first frame is repeated ``time_kernel_size - 1`` times in front. With a
        temporal stride of 1 the number of frames is therefore preserved.
        NOTE(review): the original comment "1 + 16" suggests inputs of one leading
        image frame followed by 16 video frames — confirm.
        """
        pad = self.time_kernel_size - 1
        if pad > 0:
            first_frame_pad = x[:, :, :1].repeat(1, 1, pad, 1, 1)
            x = torch.cat((first_frame_pad, x), dim=2)
        return self.conv(x)