File size: 5,503 Bytes
e3e5f9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import math
import torch
import torch.nn as nn
from ..attention import ImgToTriplaneTransformer
import math
from einops import rearrange
class ImgToTriplaneModel(nn.Module):
    """
    Turn image features into a triplane feature volume.

    A learned bank of ``3 * pos_emb_size ** 2`` query tokens (``pos_emb``) is
    run through an ``ImgToTriplaneTransformer`` with the image features passed
    as ``context``. The resulting tokens are laid out as three square planes
    and upsampled to ``triplane_dim`` channels.

    :param pos_emb_size: side length of the token grid of each plane.
    :param pos_emb_dim: channel width of every query token.
    :param cam_cond_dim: accepted for configuration compatibility; unused here.
    :param n_heads: forwarded to ``ImgToTriplaneTransformer``.
    :param d_head: forwarded to ``ImgToTriplaneTransformer``.
    :param depth: forwarded to ``ImgToTriplaneTransformer``.
    :param context_dim: forwarded to ``ImgToTriplaneTransformer``
        (presumably the channel width of the image features -- confirm
        against the transformer).
    :param triplane_dim: number of channels of each output plane.
    :param upsample_time: number of 2x transposed-convolution stages; only
        read when ``is_conv_upsampler`` is True.
    :param use_fp16: accepted for configuration compatibility; unused here.
    :param use_bf16: accepted for configuration compatibility; unused here.
    """

    def __init__(
        self,
        pos_emb_size=32,
        pos_emb_dim=1024,
        cam_cond_dim=20,
        n_heads=16,
        d_head=64,
        depth=16,
        context_dim=768,
        triplane_dim=80,
        upsample_time=1,
        use_fp16=False,
        use_bf16=True,
    ):
        super().__init__()
        self.pos_emb_size = pos_emb_size
        self.pos_emb_dim = pos_emb_dim
        self.triplane_dim = triplane_dim

        # Learned query tokens, one per cell of each of the three planes.
        # Drawn from a zero-mean Gaussian with std 1/sqrt(pos_emb_dim) so the
        # queries are distinct from the first step (an all-zero bank makes
        # every token identical until gradients separate them).
        self.pos_emb = nn.Parameter(
            torch.empty(1, 3 * pos_emb_size * pos_emb_size, pos_emb_dim)
        )
        nn.init.normal_(self.pos_emb, mean=0.0, std=pos_emb_dim ** -0.5)

        # Image-to-triplane decoder.
        self.img_to_triplane_decoder = ImgToTriplaneTransformer(
            query_dim=pos_emb_dim,
            n_heads=n_heads,
            d_head=d_head,
            depth=depth,
            context_dim=context_dim,
            triplane_size=pos_emb_size,
        )

        # Upsampler: the linear "pixel shuffle" variant is the one in use; the
        # convolutional variant is kept behind this switch.
        self.is_conv_upsampler = False
        if self.is_conv_upsampler:
            # The first stage maps pos_emb_dim -> triplane_dim, later stages
            # keep the channel count; each stage doubles the resolution.
            stages = [
                nn.ConvTranspose2d(
                    in_channels=pos_emb_dim if i == 0 else triplane_dim,
                    out_channels=triplane_dim,
                    kernel_size=2,
                    stride=2,
                    padding=0,
                    output_padding=0,
                )
                for i in range(upsample_time)
            ]
            if stages:
                self.upsampler = nn.Sequential(*stages)
            else:
                # No upsampling requested: only project the channels.
                self.upsampler = nn.Conv2d(
                    in_channels=pos_emb_dim,
                    out_channels=triplane_dim,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
        else:
            self.upsample_ratio = 4
            # Each token is expanded into an upsample_ratio x upsample_ratio
            # patch of triplane_dim-channel pixels.
            self.upsampler = nn.Linear(
                in_features=pos_emb_dim,
                out_features=triplane_dim * (self.upsample_ratio ** 2),
            )

    def forward(self, x, cam_cond=None, **kwargs):
        """
        Decode a batch of image features into triplanes.

        :param x: image features; the first dimension is the batch and the
            tensor is handed to the decoder unchanged as ``context``.
        :param cam_cond: accepted for interface compatibility; unused here.
        :param kwargs: ignored.
        :return: a ``[B, 3, triplane_dim, H, W]`` tensor cast to ``x.dtype``,
            where ``H`` and ``W`` are the upsampled plane sizes.
        """
        batch = x.shape[0]
        queries = self.pos_emb.expand(batch, -1, -1)
        h = self.img_to_triplane_decoder(queries, context=x)

        # [B, 3*S*S, D] -> [B*3, S, S, D]. reshape (not view) so that a
        # non-contiguous decoder output does not raise.
        h = h.reshape(
            batch * 3, self.pos_emb_size, self.pos_emb_size, self.pos_emb_dim
        )

        if self.is_conv_upsampler:
            h = h.permute(0, 3, 1, 2)  # [B*3, D, S, S]
            h = self.upsampler(h)  # [B*3, C, H, W]
        else:
            ratio = self.upsample_ratio
            h = self.upsampler(h)  # [B*3, S, S, C*r*r]
            n, height, width, _ = h.shape
            # Split the channel axis into (C, r, r) and interleave the two r
            # axes with the spatial axes (a channels-last pixel shuffle).
            h = h.reshape(n, height, width, self.triplane_dim, ratio, ratio)
            h = h.permute(0, 3, 1, 4, 2, 5).contiguous()  # [B*3, C, S, r, S, r]
            h = h.reshape(n, self.triplane_dim, height * ratio, width * ratio)

        # Separate the plane axis from the batch axis: [B*3, ...] -> [B, 3, ...].
        h = h.reshape(batch, 3, *h.shape[1:])
        return h.to(x.dtype)
|