import math

import torch
import torch.nn as nn
from einops import rearrange

from ..attention import ImgToTriplaneTransformer


class ImgToTriplaneModel(nn.Module):
    """
    Image-to-triplane decoder.

    A set of learned positional-embedding tokens (one 2D grid of queries per
    plane, three planes in total) cross-attends to the input image features
    through an ImgToTriplaneTransformer, and the refined tokens are then
    upsampled to the output triplane resolution.

    :param pos_emb_size: spatial size of each low-resolution plane (the plane
        queries form a pos_emb_size x pos_emb_size grid).
    :param pos_emb_dim: channel width of the plane query tokens.
    :param cam_cond_dim: camera-conditioning dimension (accepted but unused here).
    :param n_heads: number of attention heads per transformer block.
    :param d_head: channel width per attention head.
    :param depth: number of transformer blocks.
    :param context_dim: channel width of the image-feature context tokens.
    :param triplane_dim: channel width of the output triplane features.
    :param upsample_time: number of 2x ConvTranspose2d stages (conv path only).
    :param use_fp16: run in float16 (accepted but unused here).
    :param use_bf16: run in bfloat16 (accepted but unused here).
    """

    def __init__(
        self,
        pos_emb_size=32,
        pos_emb_dim=1024,
        cam_cond_dim=20,
        n_heads=16,
        d_head=64,
        depth=16,
        context_dim=768,
        triplane_dim=80,
        upsample_time=1,
        use_fp16=False,
        use_bf16=True,
    ):
        super().__init__()

        self.pos_emb_size = pos_emb_size
        self.pos_emb_dim = pos_emb_dim

        # Learned queries for the three planes: 3 * S * S tokens of width pos_emb_dim.
        self.pos_emb = nn.Parameter(torch.zeros(1, 3 * pos_emb_size * pos_emb_size, pos_emb_dim))

        self.img_to_triplane_decoder = ImgToTriplaneTransformer(
            query_dim=pos_emb_dim, n_heads=n_heads,
            d_head=d_head, depth=depth, context_dim=context_dim,
            triplane_size=pos_emb_size,
        )

        self.is_conv_upsampler = False

        self.triplane_dim = triplane_dim
        if self.is_conv_upsampler:
            # Stack of 2x ConvTranspose2d stages; the first stage also projects
            # pos_emb_dim down to triplane_dim.
            upsamplers = []
            for i in range(upsample_time):
                in_channels = pos_emb_dim if i == 0 else triplane_dim
                upsamplers.append(nn.ConvTranspose2d(in_channels=in_channels, out_channels=triplane_dim,
                                                     kernel_size=2, stride=2,
                                                     padding=0, output_padding=0))
            if upsamplers:
                self.upsampler = nn.Sequential(*upsamplers)
            else:
                # upsample_time == 0: project channels only, keep the spatial size.
                self.upsampler = nn.Conv2d(in_channels=pos_emb_dim, out_channels=triplane_dim,
                                           kernel_size=3, stride=1, padding=1)
        else:
            # Linear pixel-shuffle upsampler: every token predicts an
            # upsample_ratio x upsample_ratio patch of triplane features,
            # so each plane is upsampled 4x in forward().
            self.upsample_ratio = 4
            self.upsampler = nn.Linear(in_features=pos_emb_dim,
                                       out_features=triplane_dim * self.upsample_ratio**2)

    def forward(self, x, cam_cond=None, **kwargs):
        """
        Apply the model to an input batch.

        :param x: an [N x L x context_dim] Tensor of image-feature tokens, used
            as the cross-attention context.
        :param cam_cond: optional camera conditioning (unused here).
        :return: an [N x 3 x triplane_dim x H x W] Tensor of triplane features.
        """
        B = x.shape[0]
        # Broadcast the learned plane queries over the batch and let them
        # cross-attend to the image features.
        h = self.pos_emb.expand(B, -1, -1)
        context = x

        h = self.img_to_triplane_decoder(h, context=context)

        # (B, 3*S*S, D) -> (B*3, S, S, D): fold the three planes into the batch.
        h = h.view(B * 3, self.pos_emb_size, self.pos_emb_size, self.pos_emb_dim)
        if self.is_conv_upsampler:
            h = rearrange(h, 'b h w c -> b c h w')
            h = self.upsampler(h)
            h = rearrange(h, '(b d) c h w -> b d c h w', d=3)
        else:
            # Linear projection followed by a manual pixel shuffle (depth-to-space):
            # each (h, w) token expands into an upsample_ratio x upsample_ratio patch.
            h = self.upsampler(h)
            b, height, width, _ = h.shape
            h = h.view(b, height, width, self.triplane_dim, self.upsample_ratio, self.upsample_ratio)
            h = h.permute(0, 3, 1, 4, 2, 5).contiguous()
            h = h.view(b, self.triplane_dim, height * self.upsample_ratio, width * self.upsample_ratio)
            h = rearrange(h, '(b d) c h w -> b d c h w', d=3)
        h = h.type(x.dtype)
        return h
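
# A minimal shape-check sketch, not part of the original module. It assumes the
# file is run inside its package (so the relative ImgToTriplaneTransformer
# import resolves) and uses a hypothetical batch of 257 CLIP-like context
# tokens of width 768; with the default linear pixel-shuffle path, each 32x32
# plane is upsampled to 32 * 4 = 128.
if __name__ == "__main__":
    model = ImgToTriplaneModel(pos_emb_size=32, pos_emb_dim=1024,
                               context_dim=768, triplane_dim=80)
    feats = torch.randn(2, 257, 768)              # (N, L, context_dim)
    planes = model(feats)
    assert planes.shape == (2, 3, 80, 128, 128)   # (N, planes, C, H, W)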