LinB203
m
a220803
raw
history blame
3.16 kB
import torch
import torch.nn as nn
from einops import rearrange, pack, unpack
from .normalize import Normalize
from .ops import nonlinearity, video_to_image
from .conv import CausalConv3d
from .block import Block
class ResnetBlock2D(Block):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
else:
self.nin_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
@video_to_image
def forward(self, x):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
x = x + h
return x
class ResnetBlock3D(Block):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1)
else:
self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h