import torch
import torch.nn as nn
from einops import rearrange, pack, unpack
from .normalize import Normalize
from .ops import nonlinearity, video_to_image
from .conv import CausalConv3d
from .block import Block


class ResnetBlock2D(Block):
    """Pre-activation residual block built from 2D convolutions.

    The residual branch is
    ``norm1 -> nonlinearity -> conv1 (3x3) -> norm2 -> nonlinearity -> dropout
    -> conv2 (3x3)``. Its output is added to the skip path. The skip path is
    the identity when the channel count is unchanged. Otherwise it is a 3x3
    convolution (``conv_shortcut=True``) or a 1x1 projection
    (``conv_shortcut=False``).

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels. When ``None``, it defaults
            to ``in_channels``.
        conv_shortcut: When the channel counts differ, use a 3x3 convolution
            on the skip path instead of a 1x1 projection.
        dropout: Dropout probability applied before ``conv2``.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Build every layer from the *resolved* channel count. Using the raw
        # argument would pass ``None`` to the layers when the default is used.
        out_channels = self.out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        # A skip-path projection is needed only when the residual branch
        # changes the channel count. Otherwise ``x`` is added unchanged.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    # NOTE(review): ``video_to_image`` is defined in ``.ops``. It presumably
    # lets this 2D block accept video tensors by applying it frame-by-frame.
    # Confirm there.
    @video_to_image
    def forward(self, x):
        """Return ``shortcut(x) + residual(x)`` with ``self.out_channels`` channels."""
        h = self.norm1(x)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h


class ResnetBlock3D(Block):
    """Pre-activation residual block built from ``CausalConv3d`` layers.

    This block mirrors ``ResnetBlock2D``. The two 3x3 2D convolutions are
    replaced by ``CausalConv3d`` layers with kernel size 3 and ``padding=1``.
    The channel-changing skip path is a kernel-3 ``CausalConv3d``
    (``conv_shortcut=True``) or a kernel-1 ``CausalConv3d``
    (``conv_shortcut=False``).

    NOTE(review): the exact causal padding behaviour of ``CausalConv3d`` lives
    in ``.conv``. This block assumes it preserves the spatio-temporal size, so
    the residual add lines up. Confirm there.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels. When ``None``, it defaults
            to ``in_channels``.
        conv_shortcut: When the channel counts differ, use a kernel-3 causal
            convolution on the skip path instead of a kernel-1 projection.
        dropout: Dropout probability applied before ``conv2``.
    """

    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Build every layer from the *resolved* channel count. Using the raw
        # argument would pass ``None`` to the layers when the default is used.
        out_channels = self.out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)

        # A skip-path projection is needed only when the residual branch
        # changes the channel count. Otherwise ``x`` is added unchanged.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1)
            else:
                self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0)

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)`` with ``self.out_channels`` channels.

        Unlike ``ResnetBlock2D.forward``, this method is not wrapped by
        ``video_to_image``. ``x`` is passed straight to ``CausalConv3d``
        (presumably a 5D video tensor; confirm against ``.conv``).
        """
        h = self.norm1(x)
        h = nonlinearity(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x + h