|
import math |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from einops import repeat |
|
from torch import nn |
|
|
|
from .transformer import MultiviewTransformer |
|
|
|
|
|
def timestep_embedding(
    timesteps: torch.Tensor,
    dim: int,
    max_period: int = 10000,
    repeat_only: bool = False,
) -> torch.Tensor:
    """Create sinusoidal embeddings for a batch of (possibly fractional) timesteps.

    Args:
        timesteps: 1-D tensor of shape ``(batch,)``.
        dim: width of the returned embedding.
        max_period: controls the lowest frequency of the sinusoids.
        repeat_only: if True, skip the sinusoids and just tile the raw
            timestep value across ``dim`` columns.

    Returns:
        Tensor of shape ``(batch, dim)``; float32 on the sinusoidal path.
    """
    if repeat_only:
        return repeat(timesteps, "b -> b d", d=dim)

    half = dim // 2
    # Geometric frequency ladder from 1 down to ~1/max_period.
    exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(-math.log(max_period) * exponents).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: pad a zero column so the output width is exactly `dim`.
        embedding = torch.cat(
            [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
        )
    return embedding
|
|
|
|
|
class Upsample(nn.Module): |
|
def __init__(self, channels: int, out_channels: int | None = None): |
|
super().__init__() |
|
self.channels = channels |
|
self.out_channels = out_channels or channels |
|
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, 1, 1) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
assert x.shape[1] == self.channels |
|
x = F.interpolate(x, scale_factor=2, mode="nearest") |
|
x = self.conv(x) |
|
return x |
|
|
|
|
|
class Downsample(nn.Module): |
|
def __init__(self, channels: int, out_channels: int | None = None): |
|
super().__init__() |
|
self.channels = channels |
|
self.out_channels = out_channels or channels |
|
self.op = nn.Conv2d(self.channels, self.out_channels, 3, 2, 1) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
assert x.shape[1] == self.channels |
|
return self.op(x) |
|
|
|
|
|
class GroupNorm32(nn.GroupNorm):
    """GroupNorm evaluated in float32 for stability, result cast back to the input dtype."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        normalized = super().forward(input.float())
        return normalized.type(input.dtype)
|
|
|
|
|
class TimestepEmbedSequential(nn.Sequential):
    """A ``nn.Sequential`` whose forward routes extra conditioning to the layers that use it.

    ``MultiviewTransformer`` layers receive the cross-attention context and the
    frame count; ``ResBlock`` layers receive the timestep and dense embeddings;
    every other layer is called on the features alone.
    """

    def forward(
        self,
        x: torch.Tensor,
        emb: torch.Tensor,
        context: torch.Tensor,
        dense_emb: torch.Tensor,
        num_frames: int,
    ) -> torch.Tensor:
        for module in self:
            if isinstance(module, MultiviewTransformer):
                assert num_frames is not None
                x = module(x, context, num_frames)
            elif isinstance(module, ResBlock):
                x = module(x, emb, dense_emb)
            else:
                x = module(x)
        return x
|
|
|
|
|
class ResBlock(nn.Module):
    """Residual block conditioned on a timestep embedding and a dense spatial embedding.

    The dense embedding is resized to the feature map and converted into a
    per-channel scale/shift (FiLM-style) applied between the normalization and
    the first convolution; the timestep embedding is added as a per-channel
    bias before the output layers.
    """

    def __init__(
        self,
        channels: int,
        emb_channels: int,
        out_channels: int | None,
        dense_in_channels: int,
        dropout: float,
    ):
        super().__init__()
        out_channels = out_channels or channels

        # Norm -> activation -> conv; the conv is peeled off in forward() so the
        # dense modulation can act on the normalized features first.
        self.in_layers = nn.Sequential(
            GroupNorm32(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, out_channels, 3, 1, 1),
        )
        self.emb_layers = nn.Sequential(
            nn.SiLU(), nn.Linear(emb_channels, out_channels)
        )
        # Produces 2 * channels maps: one scale and one shift per input channel.
        self.dense_emb_layers = nn.Sequential(
            nn.Conv2d(dense_in_channels, 2 * channels, 1, 1, 0)
        )
        self.out_layers = nn.Sequential(
            GroupNorm32(32, out_channels),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        )
        # 1x1 conv matches channel counts on the skip path when they differ.
        self.skip_connection = (
            nn.Identity()
            if out_channels == channels
            else nn.Conv2d(channels, out_channels, 1, 1, 0)
        )

    def forward(
        self, x: torch.Tensor, emb: torch.Tensor, dense_emb: torch.Tensor
    ) -> torch.Tensor:
        norm_act, conv_in = self.in_layers[:-1], self.in_layers[-1]
        h = norm_act(x)

        # Resize the dense conditioning to the feature map, then split it into
        # a multiplicative scale and an additive shift.
        resized = F.interpolate(
            dense_emb, size=h.shape[2:], mode="bilinear", align_corners=True
        )
        modulation = self.dense_emb_layers(resized).type(h.dtype)
        scale, shift = torch.chunk(modulation, 2, dim=1)
        h = h * (1 + scale) + shift

        h = conv_in(h)

        # Timestep bias, broadcast over the spatial dimensions.
        bias = self.emb_layers(emb).type(h.dtype)
        bias = bias.reshape(*bias.shape, *([1] * (h.ndim - bias.ndim)))
        h = self.out_layers(h + bias)

        return self.skip_connection(x) + h
|
|