""" --- title: Autoencoder for Stable Diffusion summary: > Annotated PyTorch implementation/tutorial of the autoencoder for stable diffusion. --- # Autoencoder for [Stable Diffusion](../index.html) This implements the auto-encoder model used to map between image space and latent space. We have kept to the model definition and naming unchanged from [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion) so that we can load the checkpoints directly. """ from typing import List import torch import torch.nn.functional as F from torch import nn import spaces class Autoencoder(nn.Module): """ ## Autoencoder This consists of the encoder and decoder modules. """ def __init__( self, encoder: "Encoder", decoder: "Decoder", emb_channels: int, z_channels: int ): """ :param encoder: is the encoder :param decoder: is the decoder :param emb_channels: is the number of dimensions in the quantized embedding space :param z_channels: is the number of channels in the embedding space """ super().__init__() self.encoder = encoder self.decoder = decoder # Convolution to map from embedding space to # quantized embedding space moments (mean and log variance) self.quant_conv = nn.Conv2d(2 * z_channels, 2 * emb_channels, 1) # Convolution to map from quantized embedding space back to # embedding space self.post_quant_conv = nn.Conv2d(emb_channels, z_channels, 1) def encode(self, img: torch.Tensor) -> "GaussianDistribution": """ ### Encode images to latent representation :param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]` """ # Get embeddings with shape `[batch_size, z_channels * 2, z_height, z_height]` print(f"encoder parameters max: {max([p.max() for p in self.encoder.parameters()])}") print(f"encoder parameters min: {min([p.min() for p in self.encoder.parameters()])}") print(f"img.dtype: {img.dtype}") z = self.encoder(img) print(f"z.max(): {z.max()}, z.min(): {z.min()}") # Get the moments in the quantized embedding space moments = self.quant_conv(z) print(f"moments.max(): {moments.max()}, moments.min(): {moments.min()}") # Return the distribution return GaussianDistribution(moments) def decode(self, z: torch.Tensor): """ ### Decode images from latent representation :param z: is the latent representation with shape `[batch_size, emb_channels, z_height, z_height]` """ # Map to embedding space from the quantized representation z = self.post_quant_conv(z) # Decode the image of shape `[batch_size, channels, height, width]` return self.decoder(z) def forward(self, x): posterior = self.encode(x) z = posterior.sample() dec = self.decode(z) return dec, posterior class Encoder(nn.Module): """ ## Encoder module """ def __init__( self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int, in_channels: int, z_channels: int ): """ :param channels: is the number of channels in the first convolution layer :param channel_multipliers: are the multiplicative factors for the number of channels in the subsequent blocks :param n_resnet_blocks: is the number of resnet layers at each resolution :param in_channels: is the number of channels in the image :param z_channels: is the number of channels in the embedding space """ super().__init__() # Number of blocks of different resolutions. # The resolution is halved at the end each top level block n_resolutions = len(channel_multipliers) # Initial $3 \times 3$ convolution layer that maps the image to `channels` self.conv_in = nn.Conv2d(in_channels, channels, 3, stride=1, padding=1) # Number of channels in each top level block channels_list = [m * channels for m in [1] + channel_multipliers] # List of top-level blocks self.down = nn.ModuleList() # Create top-level blocks for i in range(n_resolutions): # Each top level block consists of multiple ResNet Blocks and down-sampling resnet_blocks = nn.ModuleList() # Add ResNet Blocks for _ in range(n_resnet_blocks): resnet_blocks.append(ResnetBlock(channels, channels_list[i + 1])) channels = channels_list[i + 1] # Top-level block down = nn.Module() down.block = resnet_blocks # Down-sampling at the end of each top level block except the last if i != n_resolutions - 1: down.downsample = DownSample(channels) else: down.downsample = nn.Identity() # self.down.append(down) # Final ResNet blocks with attention self.mid = nn.Module() self.mid.block_1 = ResnetBlock(channels, channels) self.mid.attn_1 = AttnBlock(channels) self.mid.block_2 = ResnetBlock(channels, channels) # Map to embedding space with a $3 \times 3$ convolution self.norm_out = normalization(channels) self.conv_out = nn.Conv2d(channels, 2 * z_channels, 3, stride=1, padding=1) def forward(self, img: torch.Tensor): """ :param img: is the image tensor with shape `[batch_size, img_channels, img_height, img_width]` """ # Map to `channels` with the initial convolution x = self.conv_in(img) # Top-level blocks for down in self.down: # ResNet Blocks for block in down.block: x = block(x) # Down-sampling x = down.downsample(x) # Final ResNet blocks with attention x = self.mid.block_1(x) x = self.mid.attn_1(x) x = self.mid.block_2(x) # Normalize and map to embedding space x = self.norm_out(x) x = swish(x) x = self.conv_out(x) # return x class Decoder(nn.Module): """ ## Decoder module """ def __init__( self, *, channels: int, channel_multipliers: List[int], n_resnet_blocks: int, out_channels: int, z_channels: int ): """ :param channels: is the number of channels in the final convolution layer :param channel_multipliers: are the multiplicative factors for the number of channels in the previous blocks, in reverse order :param n_resnet_blocks: is the number of resnet layers at each resolution :param out_channels: is the number of channels in the image :param z_channels: is the number of channels in the embedding space """ super().__init__() # Number of blocks of different resolutions. # The resolution is halved at the end each top level block num_resolutions = len(channel_multipliers) # Number of channels in each top level block, in the reverse order channels_list = [m * channels for m in channel_multipliers] # Number of channels in the top-level block channels = channels_list[-1] # Initial $3 \times 3$ convolution layer that maps the embedding space to `channels` self.conv_in = nn.Conv2d(z_channels, channels, 3, stride=1, padding=1) # ResNet blocks with attention self.mid = nn.Module() self.mid.block_1 = ResnetBlock(channels, channels) self.mid.attn_1 = AttnBlock(channels) self.mid.block_2 = ResnetBlock(channels, channels) # List of top-level blocks self.up = nn.ModuleList() # Create top-level blocks for i in reversed(range(num_resolutions)): # Each top level block consists of multiple ResNet Blocks and up-sampling resnet_blocks = nn.ModuleList() # Add ResNet Blocks for _ in range(n_resnet_blocks + 1): resnet_blocks.append(ResnetBlock(channels, channels_list[i])) channels = channels_list[i] # Top-level block up = nn.Module() up.block = resnet_blocks # Up-sampling at the end of each top level block except the first if i != 0: up.upsample = UpSample(channels) else: up.upsample = nn.Identity() # Prepend to be consistent with the checkpoint self.up.insert(0, up) # Map to image space with a $3 \times 3$ convolution self.norm_out = normalization(channels) self.conv_out = nn.Conv2d(channels, out_channels, 3, stride=1, padding=1) def forward(self, z: torch.Tensor): """ :param z: is the embedding tensor with shape `[batch_size, z_channels, z_height, z_height]` """ # Map to `channels` with the initial convolution h = self.conv_in(z) # ResNet blocks with attention h = self.mid.block_1(h) h = self.mid.attn_1(h) h = self.mid.block_2(h) # Top-level blocks for up in reversed(self.up): # ResNet Blocks for block in up.block: h = block(h) # Up-sampling h = up.upsample(h) # Normalize and map to image space h = self.norm_out(h) h = swish(h) img = self.conv_out(h) # return img class GaussianDistribution: """ ## Gaussian Distribution """ def __init__(self, parameters: torch.Tensor): """ :param parameters: are the means and log of variances of the embedding of shape `[batch_size, z_channels * 2, z_height, z_height]` """ # Split mean and log of variance print(f"parameters.max(): {parameters.max()}, parameters.min(): {parameters.min()}") self.mean, log_var = torch.chunk(parameters, 2, dim=1) # Clamp the log of variances self.log_var = torch.clamp(log_var, -30.0, 20.0) # Calculate standard deviation self.std = torch.exp(0.5 * self.log_var) self.var = torch.exp(self.log_var) def sample(self): # Sample from the distribution print(f"self.mean.max(): {self.mean.max()}, self.mean.min(): {self.mean.min()}") print(f"self.std.max(): {self.std.max()}, self.std.min(): {self.std.min()}") return self.mean + self.std * torch.randn_like(self.std) def kl(self): return 0.5 * torch.sum( torch.pow(self.mean, 2) + self.var - 1.0 - self.log_var, dim=[1, 2, 3] ) class AttnBlock(nn.Module): """ ## Attention block """ def __init__(self, channels: int): """ :param channels: is the number of channels """ super().__init__() # Group normalization self.norm = normalization(channels) # Query, key and value mappings self.q = nn.Conv2d(channels, channels, 1) self.k = nn.Conv2d(channels, channels, 1) self.v = nn.Conv2d(channels, channels, 1) # Final $1 \times 1$ convolution layer self.proj_out = nn.Conv2d(channels, channels, 1) # Attention scaling factor self.scale = channels**-0.5 def forward(self, x: torch.Tensor): """ :param x: is the tensor of shape `[batch_size, channels, height, width]` """ # Normalize `x` x_norm = self.norm(x) # Get query, key and vector embeddings q = self.q(x_norm) k = self.k(x_norm) v = self.v(x_norm) # Reshape to query, key and vector embeedings from # `[batch_size, channels, height, width]` to # `[batch_size, channels, height * width]` b, c, h, w = q.shape q = q.view(b, c, h * w) k = k.view(b, c, h * w) v = v.view(b, c, h * w) # Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$ attn = torch.einsum("bci,bcj->bij", q, k) * self.scale attn = F.softmax(attn, dim=2) # Compute $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$ out = torch.einsum("bij,bcj->bci", attn, v) # Reshape back to `[batch_size, channels, height, width]` out = out.view(b, c, h, w) # Final $1 \times 1$ convolution layer out = self.proj_out(out) # Add residual connection return x + out class UpSample(nn.Module): """ ## Up-sampling layer """ def __init__(self, channels: int): """ :param channels: is the number of channels """ super().__init__() # $3 \times 3$ convolution mapping self.conv = nn.Conv2d(channels, channels, 3, padding=1) def forward(self, x: torch.Tensor): """ :param x: is the input feature map with shape `[batch_size, channels, height, width]` """ # Up-sample by a factor of $2$ x = F.interpolate(x, scale_factor=2.0, mode="nearest") # Apply convolution return self.conv(x) class DownSample(nn.Module): """ ## Down-sampling layer """ def __init__(self, channels: int): """ :param channels: is the number of channels """ super().__init__() # $3 \times 3$ convolution with stride length of $2$ to down-sample by a factor of $2$ self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=0) def forward(self, x: torch.Tensor): """ :param x: is the input feature map with shape `[batch_size, channels, height, width]` """ # Add padding x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0) # Apply convolution return self.conv(x) class ResnetBlock(nn.Module): """ ## ResNet Block """ def __init__(self, in_channels: int, out_channels: int): """ :param in_channels: is the number of channels in the input :param out_channels: is the number of channels in the output """ super().__init__() # First normalization and convolution layer self.norm1 = normalization(in_channels) self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1) # Second normalization and convolution layer self.norm2 = normalization(out_channels) self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1) # `in_channels` to `out_channels` mapping layer for residual connection if in_channels != out_channels: self.nin_shortcut = nn.Conv2d( in_channels, out_channels, 1, stride=1, padding=0 ) else: self.nin_shortcut = nn.Identity() def forward(self, x: torch.Tensor): """ :param x: is the input feature map with shape `[batch_size, channels, height, width]` """ h = x # First normalization and convolution layer h = self.norm1(h) h = swish(h) h = self.conv1(h) # Second normalization and convolution layer h = self.norm2(h) h = swish(h) h = self.conv2(h) # Map and add residual return self.nin_shortcut(x) + h def swish(x: torch.Tensor): """ ### Swish activation """ return x * torch.sigmoid(x) def normalization(channels: int): """ ### Group normalization This is a helper function, with fixed number of groups and `eps`. """ return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6) def restore_ae_from_sd(model, path): def remove_prefix(text, prefix): if text.startswith(prefix): return text[len(prefix) :] return text checkpoint = torch.load(path) # checkpoint = torch.load(path, map_location="cpu") ckpt_state_dict = checkpoint["state_dict"] new_ckpt_state_dict = {} for k, v in ckpt_state_dict.items(): new_k = remove_prefix(k, "first_stage_model.") new_ckpt_state_dict[new_k] = v missing_keys, extra_keys = model.load_state_dict(new_ckpt_state_dict, strict=False) assert len(missing_keys) == 0 def create_model(in_channels, out_channels, latent_dim=4): encoder = Encoder( z_channels=latent_dim, in_channels=in_channels, channels=128, channel_multipliers=[1, 2, 4, 4], n_resnet_blocks=2, ) decoder = Decoder( out_channels=out_channels, z_channels=latent_dim, channels=128, channel_multipliers=[1, 2, 4, 4], n_resnet_blocks=2, ) autoencoder = Autoencoder( emb_channels=latent_dim, encoder=encoder, decoder=decoder, z_channels=latent_dim ) return autoencoder