diff --git a/WavLM.py b/WavLM.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e1adf5dfeaa610bd5b1b15838ac426f5e3dcf5c
--- /dev/null
+++ b/WavLM.py
@@ -0,0 +1,854 @@
+# --------------------------------------------------------
+# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
+# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+import logging
+from typing import List, Optional, Tuple
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import LayerNorm
+from einops import rearrange
+import requests
+from clint.textui import progress
+import os
+from WavLM_modules import (
+ Fp32GroupNorm,
+ Fp32LayerNorm,
+ GradMultiply,
+ MultiheadAttention,
+ SamePad,
+ init_bert_params,
+ get_activation_fn,
+ TransposeLast,
+ GLU_Linear,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class WavLM_wrapper(nn.Module):
+ def __init__(
+ self, model_size="Base+", feed_as_frames=True, merge_type="cat", model_path=None
+ ):
+ super().__init__()
+ assert model_size in ["Base+", "Large"]
+ if model_path is None:
+ model_path = os.path.join(
+ os.path.dirname(__file__), f"WavLM-{model_size}.pt"
+ )
+ if not os.path.exists(model_path):
+ self.download_model(model_path, model_size)
+ checkpoint = torch.load(model_path)
+ cfg = WavLMConfig(checkpoint["cfg"])
+ self.cfg = cfg
+ self.model = WavLM(cfg)
+ self.model.load_state_dict(checkpoint["model"])
+ self.model.eval()
+ for param in self.model.parameters():
+ param.requires_grad = False
+ self.code_size = 768 * 2 if merge_type == "cat" else 768
+ self.merge_type = merge_type
+ self.feed_as_frames = feed_as_frames
+
+ def download_model(self, out_path, size: str = "Base+"):
+ print("Downloading model...")
+ if size == "Base+":
+ url = "https://valle.blob.core.windows.net/share/wavlm/WavLM-Base+.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D"
+ else:
+ url = "https://valle.blob.core.windows.net/share/wavlm/WavLM-Large.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D"
+ r = requests.get(url, allow_redirects=True, stream=True)
+ with open(out_path, "wb") as f:
+ total_length = int(r.headers.get("content-length"))
+ for chunk in progress.bar(
+ r.iter_content(chunk_size=1024), expected_size=(total_length / 1024) + 1
+ ):
+ if chunk:
+ f.write(chunk)
+ f.flush()
+ print("Model downloaded to %s" % out_path)
+
+ def forward(self, x):
+ """
+ Args:
+ x: (batch, n_frames, audio_features)
+ """
+ T = x.shape[1]
+
+ if self.feed_as_frames:
+ x = rearrange(x, "b f d -> (b f) d")
+ else:
+ x = rearrange(x, "b ... -> b (...)")
+
+ if self.cfg.normalize:
+ x = torch.nn.functional.layer_norm(x, x.shape)
+
+ x = self.model.extract_features(x)[0] # B, new_features, C
+ if self.feed_as_frames:
+ x = rearrange(x, "(b f) d c -> b f d c", f=T)
+ else:
+ x = torch.nn.functional.interpolate(
+ x.permute(0, 2, 1), T * 2, mode="nearest"
+ )
+ x = rearrange(x, "b c (f d) -> b f d c", d=2)
+
+ if self.merge_type == "cat":
+ if x.dim() == 3:
+ return rearrange(x, "b d c -> b (d c)")
+ return rearrange(x, "b f d c -> b f (d c)")
+ elif self.merge_type == "sum":
+ return x.sum(dim=-2)
+ elif self.merge_type == "mean":
+ return x.mean(dim=-2)
+ elif self.merge_type == "None":
+ return x
+ else:
+ raise NotImplementedError
+
+
+def compute_mask_indices(
+ shape: Tuple[int, int],
+ padding_mask: Optional[torch.Tensor],
+ mask_prob: float,
+ mask_length: int,
+ mask_type: str = "static",
+ mask_other: float = 0.0,
+ min_masks: int = 0,
+ no_overlap: bool = False,
+ min_space: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape
+ Args:
+ shape: the the shape for which to compute masks.
+ should be of size 2 where first element is batch size and 2nd is timesteps
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
+ mask_type: how to compute mask lengths
+ static = fixed size
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
+ poisson = sample from possion distribution with lambda = mask length
+ min_masks: minimum number of masked spans
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
+ """
+
+ bsz, all_sz = shape
+ mask = np.full((bsz, all_sz), False)
+
+ all_num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * all_sz / float(mask_length) + np.random.rand()
+ )
+
+ all_num_mask = max(min_masks, all_num_mask)
+
+ mask_idcs = []
+ for i in range(bsz):
+ if padding_mask is not None:
+ sz = all_sz - padding_mask[i].long().sum().item()
+ num_mask = int(
+ # add a random number for probabilistic rounding
+ mask_prob * sz / float(mask_length) + np.random.rand()
+ )
+ num_mask = max(min_masks, num_mask)
+ else:
+ sz = all_sz
+ num_mask = all_num_mask
+
+ if mask_type == "static":
+ lengths = np.full(num_mask, mask_length)
+ elif mask_type == "uniform":
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
+ elif mask_type == "normal":
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
+ lengths = [max(1, int(round(x))) for x in lengths]
+ elif mask_type == "poisson":
+ lengths = np.random.poisson(mask_length, size=num_mask)
+ lengths = [int(round(x)) for x in lengths]
+ else:
+ raise Exception("unknown mask selection " + mask_type)
+
+ if sum(lengths) == 0:
+ lengths[0] = min(mask_length, sz - 1)
+
+ if no_overlap:
+ mask_idc = []
+
+ def arrange(s, e, length, keep_length):
+ span_start = np.random.randint(s, e - length)
+ mask_idc.extend(span_start + i for i in range(length))
+
+ new_parts = []
+ if span_start - s - min_space >= keep_length:
+ new_parts.append((s, span_start - min_space + 1))
+ if e - span_start - keep_length - min_space > keep_length:
+ new_parts.append((span_start + length + min_space, e))
+ return new_parts
+
+ parts = [(0, sz)]
+ min_length = min(lengths)
+ for length in sorted(lengths, reverse=True):
+ lens = np.fromiter(
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
+ np.int,
+ )
+ l_sum = np.sum(lens)
+ if l_sum == 0:
+ break
+ probs = lens / np.sum(lens)
+ c = np.random.choice(len(parts), p=probs)
+ s, e = parts.pop(c)
+ parts.extend(arrange(s, e, length, min_length))
+ mask_idc = np.asarray(mask_idc)
+ else:
+ min_len = min(lengths)
+ if sz - min_len <= num_mask:
+ min_len = sz - num_mask - 1
+
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
+
+ mask_idc = np.asarray(
+ [
+ mask_idc[j] + offset
+ for j in range(len(mask_idc))
+ for offset in range(lengths[j])
+ ]
+ )
+
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
+
+ min_len = min([len(m) for m in mask_idcs])
+ for i, mask_idc in enumerate(mask_idcs):
+ if len(mask_idc) > min_len:
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
+ mask[i, mask_idc] = True
+
+ return mask
+
+
+class WavLMConfig:
+ def __init__(self, cfg=None):
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
+
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
+ self.activation_fn: str = "gelu" # activation function to use
+
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
+ self.conv_bias: bool = False # include bias in conv encoder
+ self.feature_grad_mult: float = (
+ 1.0 # multiply feature extractor var grads by this
+ )
+
+ self.normalize: bool = (
+ False # normalize input to have 0 mean and unit variance during training
+ )
+
+ # dropouts
+ self.dropout: float = 0.1 # dropout probability for the transformer
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
+ self.activation_dropout: float = (
+ 0.0 # dropout probability after activation in FFN
+ )
+ self.encoder_layerdrop: float = (
+ 0.0 # probability of dropping a tarnsformer layer
+ )
+ self.dropout_input: float = (
+ 0.0 # dropout to apply to the input (after feat extr)
+ )
+ self.dropout_features: float = (
+ 0.0 # dropout to apply to the features (after feat extr)
+ )
+
+ # masking
+ self.mask_length: int = 10 # mask length
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
+ self.mask_selection: str = "static" # how to choose mask length
+ self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
+ self.mask_min_space: int = (
+ 1 # min space between spans (if no overlap is enabled)
+ )
+
+ # channel masking
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
+ self.mask_channel_selection: str = (
+ "static" # how to choose mask length for channel masking
+ )
+ self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
+ self.no_mask_channel_overlap: bool = (
+ False # whether to allow channel masks to overlap
+ )
+ self.mask_channel_min_space: int = (
+ 1 # min space between spans (if no overlap is enabled)
+ )
+
+ # positional embeddings
+ self.conv_pos: int = (
+ 128 # number of filters for convolutional positional embeddings
+ )
+ self.conv_pos_groups: int = (
+ 16 # number of groups for convolutional positional embedding
+ )
+
+ # relative position embedding
+ self.relative_position_embedding: bool = (
+ False # apply relative position embedding
+ )
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
+ self.max_distance: int = (
+ 1280 # maximum distance for relative position embedding
+ )
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
+
+ if cfg is not None:
+ self.update(cfg)
+
+ def update(self, cfg: dict):
+ self.__dict__.update(cfg)
+
+
+class WavLM(nn.Module):
+ def __init__(
+ self,
+ cfg: WavLMConfig,
+ ) -> None:
+ super().__init__()
+ logger.info(f"WavLM Config: {cfg.__dict__}")
+
+ self.cfg = cfg
+ feature_enc_layers = eval(cfg.conv_feature_layers)
+ self.embed = feature_enc_layers[-1][0]
+
+ self.feature_extractor = ConvFeatureExtractionModel(
+ conv_layers=feature_enc_layers,
+ dropout=0.0,
+ mode=cfg.extractor_mode,
+ conv_bias=cfg.conv_bias,
+ )
+
+ self.post_extract_proj = (
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
+ if self.embed != cfg.encoder_embed_dim
+ else None
+ )
+
+ self.mask_prob = cfg.mask_prob
+ self.mask_selection = cfg.mask_selection
+ self.mask_other = cfg.mask_other
+ self.mask_length = cfg.mask_length
+ self.no_mask_overlap = cfg.no_mask_overlap
+ self.mask_min_space = cfg.mask_min_space
+
+ self.mask_channel_prob = cfg.mask_channel_prob
+ self.mask_channel_selection = cfg.mask_channel_selection
+ self.mask_channel_other = cfg.mask_channel_other
+ self.mask_channel_length = cfg.mask_channel_length
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
+ self.mask_channel_min_space = cfg.mask_channel_min_space
+
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
+
+ self.feature_grad_mult = cfg.feature_grad_mult
+
+ self.mask_emb = nn.Parameter(
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
+ )
+
+ self.encoder = TransformerEncoder(cfg)
+ self.layer_norm = LayerNorm(self.embed)
+
+ def apply_mask(self, x, padding_mask):
+ B, T, C = x.shape
+ if self.mask_prob > 0:
+ mask_indices = compute_mask_indices(
+ (B, T),
+ padding_mask,
+ self.mask_prob,
+ self.mask_length,
+ self.mask_selection,
+ self.mask_other,
+ min_masks=2,
+ no_overlap=self.no_mask_overlap,
+ min_space=self.mask_min_space,
+ )
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
+ x[mask_indices] = self.mask_emb
+ else:
+ mask_indices = None
+
+ if self.mask_channel_prob > 0:
+ mask_channel_indices = compute_mask_indices(
+ (B, C),
+ None,
+ self.mask_channel_prob,
+ self.mask_channel_length,
+ self.mask_channel_selection,
+ self.mask_channel_other,
+ no_overlap=self.no_mask_channel_overlap,
+ min_space=self.mask_channel_min_space,
+ )
+ mask_channel_indices = (
+ torch.from_numpy(mask_channel_indices)
+ .to(x.device)
+ .unsqueeze(1)
+ .expand(-1, T, -1)
+ )
+ x[mask_channel_indices] = 0
+
+ return x, mask_indices
+
+ def forward_padding_mask(
+ self,
+ features: torch.Tensor,
+ padding_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ extra = padding_mask.size(1) % features.size(1)
+ if extra > 0:
+ padding_mask = padding_mask[:, :-extra]
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
+ padding_mask = padding_mask.all(-1)
+ return padding_mask
+
+ def extract_features(
+ self,
+ source: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ mask: bool = False,
+ ret_conv: bool = False,
+ output_layer: Optional[int] = None,
+ ret_layer_results: bool = False,
+ ):
+ if self.feature_grad_mult > 0:
+ features = self.feature_extractor(source)
+ if self.feature_grad_mult != 1.0:
+ features = GradMultiply.apply(features, self.feature_grad_mult)
+ else:
+ with torch.no_grad():
+ features = self.feature_extractor(source)
+
+ features = features.transpose(1, 2)
+ features = self.layer_norm(features)
+
+ if padding_mask is not None:
+ padding_mask = self.forward_padding_mask(features, padding_mask)
+
+ if self.post_extract_proj is not None:
+ features = self.post_extract_proj(features)
+
+ features = self.dropout_input(features)
+
+ if mask:
+ x, mask_indices = self.apply_mask(features, padding_mask)
+ else:
+ x = features
+
+ # feature: (B, T, D), float
+ # target: (B, T), long
+ # x: (B, T, D), float
+ # padding_mask: (B, T), bool
+ # mask_indices: (B, T), bool
+ x, layer_results = self.encoder(
+ x,
+ padding_mask=padding_mask,
+ layer=None if output_layer is None else output_layer - 1,
+ )
+
+ res = {
+ "x": x,
+ "padding_mask": padding_mask,
+ "features": features,
+ "layer_results": layer_results,
+ }
+
+ feature = res["features"] if ret_conv else res["x"]
+ if ret_layer_results:
+ feature = (feature, res["layer_results"])
+ return feature, res["padding_mask"]
+
+
+class ConvFeatureExtractionModel(nn.Module):
+ def __init__(
+ self,
+ conv_layers: List[Tuple[int, int, int]],
+ dropout: float = 0.0,
+ mode: str = "default",
+ conv_bias: bool = False,
+ conv_type: str = "default",
+ ):
+ super().__init__()
+
+ assert mode in {"default", "layer_norm"}
+
+ def block(
+ n_in,
+ n_out,
+ k,
+ stride,
+ is_layer_norm=False,
+ is_group_norm=False,
+ conv_bias=False,
+ ):
+ def make_conv():
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
+ nn.init.kaiming_normal_(conv.weight)
+ return conv
+
+ assert not (is_layer_norm and is_group_norm), (
+ "layer norm and group norm are exclusive"
+ )
+
+ if is_layer_norm:
+ return nn.Sequential(
+ make_conv(),
+ nn.Dropout(p=dropout),
+ nn.Sequential(
+ TransposeLast(),
+ Fp32LayerNorm(dim, elementwise_affine=True),
+ TransposeLast(),
+ ),
+ nn.GELU(),
+ )
+ elif is_group_norm:
+ return nn.Sequential(
+ make_conv(),
+ nn.Dropout(p=dropout),
+ Fp32GroupNorm(dim, dim, affine=True),
+ nn.GELU(),
+ )
+ else:
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
+
+ self.conv_type = conv_type
+ if self.conv_type == "default":
+ in_d = 1
+ self.conv_layers = nn.ModuleList()
+ for i, cl in enumerate(conv_layers):
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
+ (dim, k, stride) = cl
+
+ self.conv_layers.append(
+ block(
+ in_d,
+ dim,
+ k,
+ stride,
+ is_layer_norm=mode == "layer_norm",
+ is_group_norm=mode == "default" and i == 0,
+ conv_bias=conv_bias,
+ )
+ )
+ in_d = dim
+ elif self.conv_type == "conv2d":
+ in_d = 1
+ self.conv_layers = nn.ModuleList()
+ for i, cl in enumerate(conv_layers):
+ assert len(cl) == 3
+ (dim, k, stride) = cl
+
+ self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride))
+ self.conv_layers.append(torch.nn.ReLU())
+ in_d = dim
+ elif self.conv_type == "custom":
+ in_d = 1
+ idim = 80
+ self.conv_layers = nn.ModuleList()
+ for i, cl in enumerate(conv_layers):
+ assert len(cl) == 3
+ (dim, k, stride) = cl
+ self.conv_layers.append(
+ torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
+ )
+ self.conv_layers.append(torch.nn.LayerNorm([dim, idim]))
+ self.conv_layers.append(torch.nn.ReLU())
+ in_d = dim
+ if (i + 1) % 2 == 0:
+ self.conv_layers.append(
+ torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ )
+ idim = int(math.ceil(idim / 2))
+ else:
+ pass
+
+ def forward(self, x, mask=None):
+ # BxT -> BxCxT
+ x = x.unsqueeze(1)
+ if self.conv_type == "custom":
+ for conv in self.conv_layers:
+ if isinstance(conv, nn.LayerNorm):
+ x = x.transpose(1, 2)
+ x = conv(x).transpose(1, 2)
+ else:
+ x = conv(x)
+ x = x.transpose(2, 3).contiguous()
+ x = x.view(x.size(0), -1, x.size(-1))
+ else:
+ for conv in self.conv_layers:
+ x = conv(x)
+ if self.conv_type == "conv2d":
+ b, c, t, f = x.size()
+ x = x.transpose(2, 3).contiguous().view(b, c * f, t)
+ return x
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+
+ self.dropout = args.dropout
+ self.embedding_dim = args.encoder_embed_dim
+
+ self.pos_conv = nn.Conv1d(
+ self.embedding_dim,
+ self.embedding_dim,
+ kernel_size=args.conv_pos,
+ padding=args.conv_pos // 2,
+ groups=args.conv_pos_groups,
+ )
+ dropout = 0
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
+ nn.init.constant_(self.pos_conv.bias, 0)
+
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
+
+ if hasattr(args, "relative_position_embedding"):
+ self.relative_position_embedding = args.relative_position_embedding
+ self.num_buckets = args.num_buckets
+ self.max_distance = args.max_distance
+ else:
+ self.relative_position_embedding = False
+ self.num_buckets = 0
+ self.max_distance = 0
+
+ self.layers = nn.ModuleList(
+ [
+ TransformerSentenceEncoderLayer(
+ embedding_dim=self.embedding_dim,
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
+ num_attention_heads=args.encoder_attention_heads,
+ dropout=self.dropout,
+ attention_dropout=args.attention_dropout,
+ activation_dropout=args.activation_dropout,
+ activation_fn=args.activation_fn,
+ layer_norm_first=args.layer_norm_first,
+ has_relative_attention_bias=(
+ self.relative_position_embedding and i == 0
+ ),
+ num_buckets=self.num_buckets,
+ max_distance=self.max_distance,
+ gru_rel_pos=args.gru_rel_pos,
+ )
+ for i in range(args.encoder_layers)
+ ]
+ )
+
+ self.layer_norm_first = args.layer_norm_first
+ self.layer_norm = LayerNorm(self.embedding_dim)
+ self.layerdrop = args.encoder_layerdrop
+
+ self.apply(init_bert_params)
+
+ def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
+ x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
+
+ if self.layer_norm_first and layer is None:
+ x = self.layer_norm(x)
+
+ return x, layer_results
+
+ def extract_features(
+ self, x, padding_mask=None, streaming_mask=None, tgt_layer=None
+ ):
+ if padding_mask is not None:
+ x[padding_mask] = 0
+
+ x_conv = self.pos_conv(x.transpose(1, 2))
+ x_conv = x_conv.transpose(1, 2)
+ x += x_conv
+
+ if not self.layer_norm_first:
+ x = self.layer_norm(x)
+
+ x = F.dropout(x, p=self.dropout, training=self.training)
+
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1)
+
+ layer_results = []
+ z = None
+ if tgt_layer is not None:
+ layer_results.append((x, z))
+ r = None
+ pos_bias = None
+ for i, layer in enumerate(self.layers):
+ dropout_probability = np.random.random()
+ if not self.training or (dropout_probability > self.layerdrop):
+ x, z, pos_bias = layer(
+ x,
+ self_attn_padding_mask=padding_mask,
+ need_weights=False,
+ self_attn_mask=streaming_mask,
+ pos_bias=pos_bias,
+ )
+ if tgt_layer is not None:
+ layer_results.append((x, z))
+ if i == tgt_layer:
+ r = x
+ break
+
+ if r is not None:
+ x = r
+
+ # T x B x C -> B x T x C
+ x = x.transpose(0, 1)
+
+ return x, layer_results
+
+
+class TransformerSentenceEncoderLayer(nn.Module):
+ """
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
+ models.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: float = 768,
+ ffn_embedding_dim: float = 3072,
+ num_attention_heads: float = 8,
+ dropout: float = 0.1,
+ attention_dropout: float = 0.1,
+ activation_dropout: float = 0.1,
+ activation_fn: str = "relu",
+ layer_norm_first: bool = False,
+ has_relative_attention_bias: bool = False,
+ num_buckets: int = 0,
+ max_distance: int = 0,
+ rescale_init: bool = False,
+ gru_rel_pos: bool = False,
+ ) -> None:
+ super().__init__()
+ # Initialize parameters
+ self.embedding_dim = embedding_dim
+ self.dropout = dropout
+ self.activation_dropout = activation_dropout
+
+ # Initialize blocks
+ self.activation_name = activation_fn
+ self.activation_fn = get_activation_fn(activation_fn)
+ self.self_attn = MultiheadAttention(
+ self.embedding_dim,
+ num_attention_heads,
+ dropout=attention_dropout,
+ self_attention=True,
+ has_relative_attention_bias=has_relative_attention_bias,
+ num_buckets=num_buckets,
+ max_distance=max_distance,
+ rescale_init=rescale_init,
+ gru_rel_pos=gru_rel_pos,
+ )
+
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(self.activation_dropout)
+ self.dropout3 = nn.Dropout(dropout)
+
+ self.layer_norm_first = layer_norm_first
+
+ # layer norm associated with the self attention layer
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
+
+ if self.activation_name == "glu":
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
+ else:
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
+
+ # layer norm associated with the position wise feed-forward NN
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ self_attn_mask: torch.Tensor = None,
+ self_attn_padding_mask: torch.Tensor = None,
+ need_weights: bool = False,
+ pos_bias=None,
+ ):
+ """
+ LayerNorm is applied either before or after the self-attention/ffn
+ modules similar to the original Transformer imlementation.
+ """
+ residual = x
+
+ if self.layer_norm_first:
+ x = self.self_attn_layer_norm(x)
+ x, attn, pos_bias = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ need_weights=False,
+ attn_mask=self_attn_mask,
+ position_bias=pos_bias,
+ )
+ x = self.dropout1(x)
+ x = residual + x
+
+ residual = x
+ x = self.final_layer_norm(x)
+ if self.activation_name == "glu":
+ x = self.fc1(x)
+ else:
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.dropout3(x)
+ x = residual + x
+ else:
+ x, attn, pos_bias = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ need_weights=need_weights,
+ attn_mask=self_attn_mask,
+ position_bias=pos_bias,
+ )
+
+ x = self.dropout1(x)
+ x = residual + x
+
+ x = self.self_attn_layer_norm(x)
+
+ residual = x
+ if self.activation_name == "glu":
+ x = self.fc1(x)
+ else:
+ x = self.activation_fn(self.fc1(x))
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.dropout3(x)
+ x = residual + x
+ x = self.final_layer_norm(x)
+
+ return x, attn, pos_bias
diff --git a/WavLM_modules.py b/WavLM_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..4446d66e33cff004136019bb980f46a0b39b6324
--- /dev/null
+++ b/WavLM_modules.py
@@ -0,0 +1,765 @@
+# --------------------------------------------------------
+# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
+# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Based on fairseq code bases
+# https://github.com/pytorch/fairseq
+# --------------------------------------------------------
+
+import math
+import warnings
+from typing import Dict, Optional, Tuple
+import torch
+from torch import Tensor, nn
+from torch.nn import Parameter
+import torch.nn.functional as F
+
+
+class TransposeLast(nn.Module):
+ def __init__(self, deconstruct_idx=None):
+ super().__init__()
+ self.deconstruct_idx = deconstruct_idx
+
+ def forward(self, x):
+ if self.deconstruct_idx is not None:
+ x = x[self.deconstruct_idx]
+ return x.transpose(-2, -1)
+
+
+class Fp32LayerNorm(nn.LayerNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input):
+ output = F.layer_norm(
+ input.float(),
+ self.normalized_shape,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ )
+ return output.type_as(input)
+
+
+class Fp32GroupNorm(nn.GroupNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, input):
+ output = F.group_norm(
+ input.float(),
+ self.num_groups,
+ self.weight.float() if self.weight is not None else None,
+ self.bias.float() if self.bias is not None else None,
+ self.eps,
+ )
+ return output.type_as(input)
+
+
+class GradMultiply(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, scale):
+ ctx.scale = scale
+ res = x.new(x)
+ return res
+
+ @staticmethod
+ def backward(ctx, grad):
+ return grad * ctx.scale, None
+
+
+class SamePad(nn.Module):
+ def __init__(self, kernel_size, causal=False):
+ super().__init__()
+ if causal:
+ self.remove = kernel_size - 1
+ else:
+ self.remove = 1 if kernel_size % 2 == 0 else 0
+
+ def forward(self, x):
+ if self.remove > 0:
+ x = x[:, :, : -self.remove]
+ return x
+
+
+class Swish(nn.Module):
+ """Swish function"""
+
+ def __init__(self):
+ """Construct an MultiHeadedAttention object."""
+ super(Swish, self).__init__()
+ self.act = torch.nn.Sigmoid()
+
+ def forward(self, x):
+ return x * self.act(x)
+
+
+class GLU_Linear(nn.Module):
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
+ super(GLU_Linear, self).__init__()
+
+ self.glu_type = glu_type
+ self.output_dim = output_dim
+
+ if glu_type == "sigmoid":
+ self.glu_act = torch.nn.Sigmoid()
+ elif glu_type == "swish":
+ self.glu_act = Swish()
+ elif glu_type == "relu":
+ self.glu_act = torch.nn.ReLU()
+ elif glu_type == "gelu":
+ self.glu_act = torch.nn.GELU()
+
+ if bias_in_glu:
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
+ else:
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
+
+ def forward(self, x):
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
+ x = self.linear(x)
+
+ if self.glu_type == "bilinear":
+ x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
+ else:
+ x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])
+
+ return x
+
+
+def gelu_accurate(x):
+ if not hasattr(gelu_accurate, "_a"):
+ gelu_accurate._a = math.sqrt(2 / math.pi)
+ return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+
+
+def gelu(x: torch.Tensor) -> torch.Tensor:
+ return torch.nn.functional.gelu(x.float()).type_as(x)
+
+
+def get_activation_fn(activation: str):
+ """Returns the activation function corresponding to `activation`"""
+
+ if activation == "relu":
+ return F.relu
+ elif activation == "gelu":
+ return gelu
+ elif activation == "gelu_fast":
+ warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate")
+ return gelu_accurate
+ elif activation == "gelu_accurate":
+ return gelu_accurate
+ elif activation == "tanh":
+ return torch.tanh
+ elif activation == "linear":
+ return lambda x: x
+ elif activation == "glu":
+ return lambda x: x
+ else:
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
+
+
+def init_bert_params(module):
+ """
+ Initialize the weights specific to the BERT Model.
+ This overrides the default initializations depending on the specified arguments.
+ 1. If normal_init_linear_weights is set then weights of linear
+ layer will be initialized using the normal distribution and
+ bais will be set to the specified value.
+ 2. If normal_init_embed_weights is set then weights of embedding
+ layer will be initialized using the normal distribution.
+ 3. If normal_init_proj_weights is set then weights of
+ in_project_weight for MultiHeadAttention initialized using
+ the normal distribution (to be validated).
+ """
+
+ def normal_(data):
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
+ # so that the RNG is consistent with and without FSDP
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
+
+ if isinstance(module, nn.Linear):
+ normal_(module.weight.data)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ if isinstance(module, nn.Embedding):
+ normal_(module.weight.data)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ if isinstance(module, MultiheadAttention):
+ normal_(module.q_proj.weight.data)
+ normal_(module.k_proj.weight.data)
+ normal_(module.v_proj.weight.data)
+
+
+def quant_noise(module, p, block_size):
+ """
+ Wraps modules and applies quantization noise to the weights for
+ subsequent quantization with Iterative Product Quantization as
+ described in "Training with Quantization Noise for Extreme Model Compression"
+ Args:
+ - module: nn.Module
+ - p: amount of Quantization Noise
+ - block_size: size of the blocks for subsequent quantization with iPQ
+ Remarks:
+ - Module weights must have the right sizes wrt the block size
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
+ - For more detail on how to quantize by blocks with convolutional weights,
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
+ - We implement the simplest form of noise here as stated in the paper
+ which consists in randomly dropping blocks
+ """
+
+ # if no quantization noise, don't register hook
+ if p <= 0:
+ return module
+
+ # supported modules
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
+
+ # test whether module.weight has the right sizes wrt block_size
+ is_conv = module.weight.ndim == 4
+
+ # 2D matrix
+ if not is_conv:
+ assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"
+
+ # 4D matrix
+ else:
+ # 1x1 convolutions
+ if module.kernel_size == (1, 1):
+ assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
+ # regular convolutions
+ else:
+ k = module.kernel_size[0] * module.kernel_size[1]
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
+
+ def _forward_pre_hook(mod, input):
+ # no noise for evaluation
+ if mod.training:
+ if not is_conv:
+ # gather weight and sizes
+ weight = mod.weight
+ in_features = weight.size(1)
+ out_features = weight.size(0)
+
+ # split weight matrix into blocks and randomly drop selected blocks
+ mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
+ mask.bernoulli_(p)
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
+
+ else:
+ # gather weight and sizes
+ weight = mod.weight
+ in_channels = mod.in_channels
+ out_channels = mod.out_channels
+
+ # split weight matrix into blocks and randomly drop selected blocks
+ if mod.kernel_size == (1, 1):
+ mask = torch.zeros(
+ int(in_channels // block_size * out_channels),
+ device=weight.device,
+ )
+ mask.bernoulli_(p)
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
+ else:
+ mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
+ mask.bernoulli_(p)
+ mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+
+ # scale weights and apply mask
+ mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript
+ s = 1 / (1 - p)
+ mod.weight.data = s * weight.masked_fill(mask, 0)
+
+ module.register_forward_pre_hook(_forward_pre_hook)
+ return module
+
+
+class MultiheadAttention(nn.Module):
+ """Multi-headed attention.
+ See "Attention Is All You Need" for more details.
+ """
+
+ def __init__(
+ self,
+ embed_dim,
+ num_heads,
+ kdim=None,
+ vdim=None,
+ dropout=0.0,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False,
+ self_attention=False,
+ encoder_decoder_attention=False,
+ q_noise=0.0,
+ qn_block_size=8,
+ has_relative_attention_bias=False,
+ num_buckets=32,
+ max_distance=128,
+ gru_rel_pos=False,
+ rescale_init=False,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout_module = nn.Dropout(dropout)
+
+ self.has_relative_attention_bias = has_relative_attention_bias
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ if self.has_relative_attention_bias:
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
+
+ self.head_dim = embed_dim // num_heads
+ self.q_head_dim = self.head_dim
+ self.k_head_dim = self.head_dim
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim**-0.5
+
+ self.self_attention = self_attention
+ self.encoder_decoder_attention = encoder_decoder_attention
+
+ assert not self.self_attention or self.qkv_same_dim, (
+ "Self-attention requires query, key and " "value to be of the same size"
+ )
+
+ k_bias = True
+ if rescale_init:
+ k_bias = False
+
+ k_embed_dim = embed_dim
+ q_embed_dim = embed_dim
+
+ self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
+ self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
+ self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)
+
+ self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self.gru_rel_pos = gru_rel_pos
+ if self.gru_rel_pos:
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.qkv_same_dim:
+ # Empirically observed the convergence to be much better with
+ # the scaled initialization
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
+ else:
+ nn.init.xavier_uniform_(self.k_proj.weight)
+ nn.init.xavier_uniform_(self.v_proj.weight)
+ nn.init.xavier_uniform_(self.q_proj.weight)
+
+ nn.init.xavier_uniform_(self.out_proj.weight)
+ if self.out_proj.bias is not None:
+ nn.init.constant_(self.out_proj.bias, 0.0)
+ if self.bias_k is not None:
+ nn.init.xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ nn.init.xavier_normal_(self.bias_v)
+ if self.has_relative_attention_bias:
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
+
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
+ num_buckets = self.num_buckets
+ max_distance = self.max_distance
+ relative_buckets = 0
+
+ if bidirectional:
+ num_buckets = num_buckets // 2
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
+ relative_positions = torch.abs(relative_positions)
+ else:
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
+
+ max_exact = num_buckets // 2
+ is_small = relative_positions < max_exact
+
+ relative_postion_if_large = max_exact + (
+ torch.log(relative_positions.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).to(torch.long)
+ relative_postion_if_large = torch.min(
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
+ )
+
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
+ return relative_buckets
+
+ def compute_bias(self, query_length, key_length):
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
+ relative_position = memory_position - context_position
+ relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
+ values = self.relative_attention_bias(relative_position_bucket)
+ values = values.permute([2, 0, 1])
+ return values
+
+ def forward(
+ self,
+ query,
+ key: Optional[Tensor],
+ value: Optional[Tensor],
+ key_padding_mask: Optional[Tensor] = None,
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+ need_weights: bool = True,
+ static_kv: bool = False,
+ attn_mask: Optional[Tensor] = None,
+ before_softmax: bool = False,
+ need_head_weights: bool = False,
+ position_bias: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ """Input shape: Time x Batch x Channel
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ is_tpu = query.device.type == "xla"
+
+ tgt_len, bsz, embed_dim = query.size()
+ src_len = tgt_len
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+ if key is not None:
+ src_len, key_bsz, _ = key.size()
+ if not torch.jit.is_scripting():
+ assert key_bsz == bsz
+ assert value is not None
+ assert src_len, bsz == value.shape[:2]
+
+ if self.has_relative_attention_bias and position_bias is None:
+ position_bias = self.compute_bias(tgt_len, src_len)
+ position_bias = (
+ position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
+ )
+
+ if (
+ not is_tpu # don't use PyTorch version on TPUs
+ and incremental_state is None
+ and not static_kv
+ # A workaround for quantization to work. Otherwise JIT compilation
+ # treats bias in linear module as method.
+ and not torch.jit.is_scripting()
+ and self.q_head_dim == self.head_dim
+ ):
+ assert key is not None and value is not None
+ assert attn_mask is None
+
+ attn_mask_rel_pos = None
+ if position_bias is not None:
+ attn_mask_rel_pos = position_bias
+ if self.gru_rel_pos:
+ query_layer = query.transpose(0, 1)
+ new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
+ query_layer = query_layer.view(*new_x_shape)
+ query_layer = query_layer.permute(0, 2, 1, 3)
+ _B, _H, _L, __ = query_layer.size()
+
+ gate_a, gate_b = torch.sigmoid(
+ self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
+ ).chunk(2, dim=-1)
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+
+ attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
+ k_proj_bias = self.k_proj.bias
+ if k_proj_bias is None:
+ k_proj_bias = torch.zeros_like(self.q_proj.bias)
+
+ x, attn = F.multi_head_attention_forward(
+ query,
+ key,
+ value,
+ self.embed_dim,
+ self.num_heads,
+ torch.empty([0]),
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
+ self.bias_k,
+ self.bias_v,
+ self.add_zero_attn,
+ self.dropout_module.p,
+ self.out_proj.weight,
+ self.out_proj.bias,
+ self.training,
+ # self.training or self.dropout_module.apply_during_inference,
+ key_padding_mask,
+ need_weights,
+ attn_mask_rel_pos,
+ use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ )
+ return x, attn, position_bias
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if saved_state is not None and "prev_key" in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ q = self.q_proj(query)
+ k = self.k_proj(query)
+ v = self.v_proj(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.q_proj(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.k_proj(key)
+ v = self.v_proj(key)
+
+ else:
+ assert key is not None and value is not None
+ q = self.q_proj(query)
+ k = self.k_proj(key)
+ v = self.v_proj(value)
+ q *= self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
+ ],
+ dim=1,
+ )
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if "prev_key" in saved_state:
+ _prev_key = saved_state["prev_key"]
+ assert _prev_key is not None
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ assert k is not None
+ k = torch.cat([prev_key, k], dim=1)
+ src_len = k.size(1)
+ if "prev_value" in saved_state:
+ _prev_value = saved_state["prev_value"]
+ assert _prev_value is not None
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ assert v is not None
+ v = torch.cat([prev_value, v], dim=1)
+ prev_key_padding_mask: Optional[Tensor] = None
+ if "prev_key_padding_mask" in saved_state:
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
+ assert k is not None and v is not None
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
+ key_padding_mask=key_padding_mask,
+ prev_key_padding_mask=prev_key_padding_mask,
+ batch_size=bsz,
+ src_len=k.size(1),
+ static_kv=static_kv,
+ )
+
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state["prev_key_padding_mask"] = key_padding_mask
+ # In this branch incremental_state is never None
+ assert incremental_state is not None
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
+ assert k is not None
+ assert k.size(1) == src_len
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ assert v is not None
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [
+ key_padding_mask,
+ torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
+ ],
+ dim=1,
+ )
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ attn_mask = attn_mask.unsqueeze(0)
+ attn_weights += attn_mask
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ if not is_tpu:
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+ float("-inf"),
+ )
+ else:
+ attn_weights = attn_weights.transpose(0, 2)
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
+ attn_weights = attn_weights.transpose(0, 2)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v, position_bias
+
+ if position_bias is not None:
+ if self.gru_rel_pos == 1:
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
+ _B, _H, _L, __ = query_layer.size()
+ gate_a, gate_b = torch.sigmoid(
+ self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
+ ).chunk(2, dim=-1)
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
+ position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
+
+ position_bias = position_bias.view(attn_weights.size())
+
+ attn_weights = attn_weights + position_bias
+
+ attn_weights_float = F.softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = self.dropout_module(attn_weights)
+
+ assert v is not None
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+ attn_weights: Optional[Tensor] = None
+ if need_weights:
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+
+ return attn, attn_weights, position_bias
+
+ @staticmethod
+ def _append_prev_key_padding_mask(
+ key_padding_mask: Optional[Tensor],
+ prev_key_padding_mask: Optional[Tensor],
+ batch_size: int,
+ src_len: int,
+ static_kv: bool,
+ ) -> Optional[Tensor]:
+ # saved key padding masks have shape (bsz, seq_len)
+ if prev_key_padding_mask is not None and static_kv:
+ new_key_padding_mask = prev_key_padding_mask
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
+ new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
+ # During incremental decoding, as the padding token enters and
+ # leaves the frame, there will be a time when prev or current
+ # is None
+ elif prev_key_padding_mask is not None:
+ if src_len > prev_key_padding_mask.size(1):
+ filler = torch.zeros(
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
+ device=prev_key_padding_mask.device,
+ )
+ new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
+ else:
+ new_key_padding_mask = prev_key_padding_mask.float()
+ elif key_padding_mask is not None:
+ if src_len > key_padding_mask.size(1):
+ filler = torch.zeros(
+ (batch_size, src_len - key_padding_mask.size(1)),
+ device=key_padding_mask.device,
+ )
+ new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
+ else:
+ new_key_padding_mask = key_padding_mask.float()
+ else:
+ new_key_padding_mask = prev_key_padding_mask
+ return new_key_padding_mask
+
+ def _get_input_buffer(
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
+ ) -> Dict[str, Optional[Tensor]]:
+ result = self.get_incremental_state(incremental_state, "attn_state")
+ if result is not None:
+ return result
+ else:
+ empty_result: Dict[str, Optional[Tensor]] = {}
+ return empty_result
+
+ def _set_input_buffer(
+ self,
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+ buffer: Dict[str, Optional[Tensor]],
+ ):
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
+
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
+ return attn_weights
diff --git a/__pycache__/WavLM.cpython-311.pyc b/__pycache__/WavLM.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f087d57939b223a1c9adb66b9f88c0f330f7ea4c
Binary files /dev/null and b/__pycache__/WavLM.cpython-311.pyc differ
diff --git a/__pycache__/WavLM_modules.cpython-311.pyc b/__pycache__/WavLM_modules.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86401e9108d0c131c1c34bc27e45c09582fc01c9
Binary files /dev/null and b/__pycache__/WavLM_modules.cpython-311.pyc differ
diff --git a/__pycache__/data_utils.cpython-311.pyc b/__pycache__/data_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..099487bea400e4a301ea7bc27336cc22e9fabebb
Binary files /dev/null and b/__pycache__/data_utils.cpython-311.pyc differ
diff --git a/__pycache__/dino_game.cpython-311.pyc b/__pycache__/dino_game.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e6578fd8afcd3ee6ec05d1ff79687c18845f0bd
Binary files /dev/null and b/__pycache__/dino_game.cpython-311.pyc differ
diff --git a/__pycache__/inference_functions.cpython-311.pyc b/__pycache__/inference_functions.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eed8b1be308834e7986aa528fd76e0c25960e2ef
Binary files /dev/null and b/__pycache__/inference_functions.cpython-311.pyc differ
diff --git a/__pycache__/landmarks_extractor.cpython-311.pyc b/__pycache__/landmarks_extractor.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..407a2fc7fcfaab66d423cca8a061a38a711752ad
Binary files /dev/null and b/__pycache__/landmarks_extractor.cpython-311.pyc differ
diff --git a/__pycache__/utils.cpython-311.pyc b/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..410cb6bd93979f08c96b9f0bc37c8ba6455225a4
Binary files /dev/null and b/__pycache__/utils.cpython-311.pyc differ
diff --git a/__pycache__/vae_wrapper.cpython-311.pyc b/__pycache__/vae_wrapper.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..18d0e9d364c362c906a072e3a47787da9a25d15d
Binary files /dev/null and b/__pycache__/vae_wrapper.cpython-311.pyc differ
diff --git a/__pycache__/wordle_game.cpython-311.pyc b/__pycache__/wordle_game.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d356c2b82c5144f3b6a32d9e66f5808bafb7acd8
Binary files /dev/null and b/__pycache__/wordle_game.cpython-311.pyc differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..648ae619bc1eeddff0a5672c1cd04c65b9e41e3a
--- /dev/null
+++ b/app.py
@@ -0,0 +1,978 @@
+import gradio as gr
+import torch
+import tempfile
+import os
+from vae_wrapper import VaeWrapper, encode_video_chunk
+from landmarks_extractor import LandmarksExtractor
+import decord
+from utils import (
+ get_raw_audio,
+ save_audio_video,
+ calculate_splits,
+ instantiate_from_config,
+ create_pipeline_inputs,
+)
+from transformers import HubertModel
+from einops import rearrange
+import numpy as np
+from WavLM import WavLM_wrapper
+from omegaconf import OmegaConf
+from inference_functions import (
+ sample_keyframes,
+ sample_interpolation,
+)
+from wordle_game import WordleGame
+import torch.cuda.amp as amp # Import amp for mixed precision
+
+
+# Set default tensor type to float16 for faster computation
+if torch.cuda.is_available():
+ # torch.set_default_tensor_type(torch.cuda.FloatTensor)
+ # Enable TF32 precision for better performance on Ampere+ GPUs
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+# Cache for video and audio processing
+cache = {
+ "video": {
+ "path": None,
+ "embedding": None,
+ "frames": None,
+ "landmarks": None,
+ },
+ "audio": {
+ "path": None,
+ "raw_audio": None,
+ "hubert_embedding": None,
+ "wavlm_embedding": None,
+ },
+}
+
+# Create mixed precision scaler
+scaler = amp.GradScaler()
+
+
+def load_model(
+ config: str,
+ device: str = "cuda",
+ ckpt: str = None,
+):
+ """
+ Load a model from configuration.
+
+ Args:
+ config: Path to model configuration file
+ device: Device to load the model on
+ num_frames: Number of frames to process
+ input_key: Input key for the model
+ ckpt: Optional checkpoint path
+
+ Returns:
+ Tuple of (model, filter, batch size)
+ """
+ config = OmegaConf.load(config)
+
+ config["model"]["params"]["input_key"] = "latents"
+
+ if ckpt is not None:
+ config.model.params.ckpt_path = ckpt
+
+ with torch.device(device):
+ model = instantiate_from_config(config.model).to(device).eval()
+ # Convert model to half precision
+ if torch.cuda.is_available():
+ model = model.half()
+ model.first_stage_model = model.first_stage_model.float()
+ print("Converted model to FP16 precision")
+
+ # Compile model for faster inference
+ if torch.cuda.is_available():
+ try:
+ model = torch.compile(model)
+ print(f"Successfully compiled model with torch.compile()")
+ except Exception as e:
+ print(f"Warning: Failed to compile model: {e}")
+
+ return model
+
+
+# keyframe_model = KeyframeModel(device=device)
+# interpolation_model = InterpolationModel(device=device)
+vae_model = VaeWrapper("video")
+if torch.cuda.is_available():
+ vae_model = vae_model.half() # Convert to half precision
+ try:
+ vae_model = torch.compile(vae_model)
+ print("Successfully compiled vae_model in FP16")
+ except Exception as e:
+ print(f"Warning: Failed to compile vae_model: {e}")
+
+hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
+if torch.cuda.is_available():
+ hubert_model = hubert_model.half() # Convert to half precision
+ try:
+ hubert_model = torch.compile(hubert_model)
+ print("Successfully compiled hubert_model in FP16")
+ except Exception as e:
+ print(f"Warning: Failed to compile hubert_model: {e}")
+
+wavlm_model = WavLM_wrapper(
+ model_size="Base+",
+ feed_as_frames=False,
+ merge_type="None",
+ model_path="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/WavLM-Base+.pt",
+).cuda()
+if torch.cuda.is_available():
+ wavlm_model = wavlm_model.half() # Convert to half precision
+ try:
+ wavlm_model = torch.compile(wavlm_model)
+ print("Successfully compiled wavlm_model in FP16")
+ except Exception as e:
+ print(f"Warning: Failed to compile wavlm_model: {e}")
+
+landmarks_extractor = LandmarksExtractor()
+keyframe_model = load_model(
+ config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/keyframe.yaml",
+ ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/keyframe_dub.pt",
+)
+interpolation_model = load_model(
+ config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/interpolation.yaml",
+ ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/interpolation_dub.pt",
+)
+keyframe_model.en_and_decode_n_samples_a_time = 2
+interpolation_model.en_and_decode_n_samples_a_time = 2
+
+# Default media paths
+DEFAULT_VIDEO_PATH = os.path.join(
+ os.path.dirname(__file__), "assets", "sample_video.mp4"
+)
+DEFAULT_AUDIO_PATH = os.path.join(
+ os.path.dirname(__file__), "assets", "sample_audio.wav"
+)
+
+
+@torch.no_grad()
+def compute_video_embedding(video_reader, min_len):
+ """Compute embeddings from video"""
+
+ total_frames = min_len
+
+ encoded = []
+ video_frames = []
+ chunk_size = 16
+ resolution = 512
+
+ # # Create a progress bar for Gradio
+ progress = gr.Progress()
+
+ # Calculate total chunks for progress tracking
+ total_chunks = (total_frames + chunk_size - 1) // chunk_size
+
+ for i, start_idx in enumerate(range(0, total_frames, chunk_size)):
+ # Update progress bar
+ progress(i / total_chunks, desc="Processing video chunks")
+
+ end_idx = min(start_idx + chunk_size, total_frames)
+ video_chunk = video_reader.get_batch(range(start_idx, end_idx))
+ # Interpolate video chunk to the target resolution
+ video_chunk = rearrange(video_chunk, "f h w c -> f c h w")
+ video_chunk = torch.nn.functional.interpolate(
+ video_chunk,
+ size=(resolution, resolution),
+ mode="bilinear",
+ align_corners=False,
+ )
+ video_chunk = rearrange(video_chunk, "f c h w -> f h w c")
+ video_frames.append(video_chunk)
+
+ # Convert chunk to FP16 if using CUDA
+ if torch.cuda.is_available():
+ video_chunk = video_chunk.half()
+
+ # Always use autocast for FP16 computation
+ with amp.autocast(enabled=True):
+ encoded.append(encode_video_chunk(vae_model, video_chunk, resolution))
+
+ encoded = torch.cat(encoded, dim=0)
+ video_frames = torch.cat(video_frames, dim=0)
+ video_frames = rearrange(video_frames, "f h w c -> f c h w")
+ torch.cuda.empty_cache()
+ return encoded, video_frames
+
+
+@torch.no_grad()
+def compute_hubert_embedding(raw_audio):
+ """Compute embeddings from audio"""
+ print(f"Computing audio embedding from {raw_audio.shape}")
+
+ audio = (
+ (raw_audio - raw_audio.mean()) / torch.sqrt(raw_audio.var() + 1e-7)
+ ).unsqueeze(0)
+ chunks = 16000 * 20
+
+ # Create a progress bar for Gradio
+ progress = gr.Progress()
+
+ # Get audio embeddings
+ audio_embeddings = []
+ splits = list(calculate_splits(audio, chunks))
+ total_splits = len(splits)
+
+ for i, chunk in enumerate(splits):
+ # Update progress bar
+ progress(i / total_splits, desc="Processing audio chunks")
+
+ # Convert audio chunk to half precision
+ if torch.cuda.is_available():
+ chunk_cuda = chunk.cuda().half()
+ else:
+ chunk_cuda = chunk.cuda()
+
+ # Always use autocast for FP16 computation
+ with amp.autocast(enabled=True):
+ hidden_states = hubert_model(chunk_cuda)[0]
+
+ audio_embeddings.append(hidden_states)
+ audio_embeddings = torch.cat(audio_embeddings, dim=1)
+
+ # audio_embeddings = self.model.wav2vec2(rearrange(audio_frames, "f s -> () (f s)"))[0]
+ if audio_embeddings.shape[1] % 2 != 0:
+ audio_embeddings = torch.cat(
+ [audio_embeddings, torch.zeros_like(audio_embeddings[:, :1])], dim=1
+ )
+ audio_embeddings = rearrange(audio_embeddings, "() (f d) c -> f d c", d=2)
+ torch.cuda.empty_cache()
+
+ return audio_embeddings
+
+
+@torch.no_grad()
+def compute_wavlm_embedding(raw_audio):
+ """Compute embeddings from audio"""
+ audio = rearrange(raw_audio, "(f s) -> f s", s=640)
+
+ if audio.shape[0] % 2 != 0:
+ audio = torch.cat([audio, torch.zeros(1, 640)], dim=0)
+ chunks = 500
+
+ # Create a progress bar for Gradio
+ progress = gr.Progress()
+
+ # Get audio embeddings
+ audio_embeddings = []
+ splits = list(calculate_splits(audio, chunks))
+ total_splits = len(splits)
+
+ for i, chunk in enumerate(splits):
+ # Update progress bar
+ progress(i / total_splits, desc="Processing audio chunks")
+
+ # Convert chunk to half precision
+ if torch.cuda.is_available():
+ chunk_cuda = chunk.unsqueeze(0).cuda().half()
+ else:
+ chunk_cuda = chunk.unsqueeze(0).cuda()
+
+ # Always use autocast for FP16 computation
+ with amp.autocast(enabled=True):
+ wavlm_hidden_states = wavlm_model(chunk_cuda).squeeze(0)
+
+ audio_embeddings.append(wavlm_hidden_states)
+ audio_embeddings = torch.cat(audio_embeddings, dim=0)
+
+ torch.cuda.empty_cache()
+
+ return audio_embeddings
+
+
+@torch.no_grad()
+def extract_video_landmarks(video_frames):
+ """Extract landmarks from video frames"""
+
+ # Create a progress bar for Gradio
+ progress = gr.Progress()
+
+ landmarks = []
+ batch_size = 10
+
+ for i in range(0, len(video_frames), batch_size):
+ # Update progress bar
+ progress(i / len(video_frames), desc="Extracting facial landmarks")
+
+ batch = video_frames[i : i + batch_size].cpu().float()
+ batch_landmarks = landmarks_extractor.extract_landmarks(batch)
+ landmarks.extend(batch_landmarks)
+
+ torch.cuda.empty_cache()
+
+ # Convert landmarks to a list of numpy arrays with consistent shape
+ processed_landmarks = []
+
+ expected_shape = (68, 2) # Common shape for facial landmarks
+
+ # Process each landmark to ensure consistent shape
+ last_valid_landmark = None
+ for i, lm in enumerate(landmarks):
+ if lm is not None and isinstance(lm, np.ndarray) and lm.shape == expected_shape:
+ processed_landmarks.append(lm)
+ last_valid_landmark = lm
+ else:
+ # Print information about inconsistent landmarks
+ if lm is None:
+ print(f"Warning: Landmark at index {i} is None")
+ elif not isinstance(lm, np.ndarray):
+ print(
+ f"Warning: Landmark at index {i} is not a numpy array, type: {type(lm)}"
+ )
+ elif lm.shape != expected_shape:
+ print(
+ f"Warning: Landmark at index {i} has shape {lm.shape}, expected {expected_shape}"
+ )
+
+ # Replace invalid landmarks with the closest valid landmark if available
+ if last_valid_landmark is not None:
+ processed_landmarks.append(last_valid_landmark.copy())
+ else:
+ # If no valid landmark has been seen yet, look ahead for a valid one
+ found_future_valid = False
+ for future_lm in landmarks[i + 1 :]:
+ if (
+ future_lm is not None
+ and isinstance(future_lm, np.ndarray)
+ and future_lm.shape == expected_shape
+ ):
+ processed_landmarks.append(future_lm.copy())
+ found_future_valid = True
+ break
+
+ # If no valid landmark found in the future, use zeros
+ if not found_future_valid:
+ processed_landmarks.append(np.zeros(expected_shape))
+
+ return np.array(processed_landmarks)
+
+
+@torch.no_grad()
+def sample(
+ audio_list,
+ gt_keyframes,
+ masks_keyframes,
+ to_remove,
+ test_keyframes_list,
+ num_frames,
+ device,
+ emb,
+ force_uc_zero_embeddings,
+ n_batch_keyframes,
+ n_batch,
+ test_interpolation_list,
+ audio_interpolation_list,
+ masks_interpolation,
+ gt_interpolation,
+ model_keyframes,
+ model,
+):
+ # Create a progress bar for Gradio
+ progress = gr.Progress()
+
+ condition = torch.zeros(1, 3, 512, 512).to(device)
+ if torch.cuda.is_available():
+ condition = condition.half()
+
+ audio_list = rearrange(audio_list, "(b t) c d -> b t c d", t=num_frames)
+ gt_keyframes = rearrange(gt_keyframes, "(b t) c h w -> b t c h w", t=num_frames)
+ # Rearrange masks_keyframes and save locally
+ masks_keyframes = rearrange(
+ masks_keyframes, "(b t) c h w -> b t c h w", t=num_frames
+ )
+
+ # Convert to_remove into chunks of num_frames
+ to_remove_chunks = [
+ to_remove[i : i + num_frames] for i in range(0, len(to_remove), num_frames)
+ ]
+ test_keyframes_list = [
+ test_keyframes_list[i : i + num_frames]
+ for i in range(0, len(test_keyframes_list), num_frames)
+ ]
+
+ audio_cond = audio_list
+ if emb is not None:
+ embbedings = emb.unsqueeze(0).to(device)
+ if torch.cuda.is_available():
+ embbedings = embbedings.half()
+ else:
+ embbedings = None
+
+ # One batch of keframes is approximately 7 seconds
+ chunk_size = 2
+ complete_video = []
+ start_idx = 0
+ last_frame_z = None
+ last_frame_x = None
+ last_keyframe_idx = None
+ last_to_remove = None
+
+ total_chunks = (len(audio_cond) + chunk_size - 1) // chunk_size
+
+ for chunk_idx, chunk_start in enumerate(range(0, len(audio_cond), chunk_size)):
+ # Update progress bar
+ progress(chunk_idx / total_chunks, desc="Generating video")
+
+ # Clear GPU cache between chunks
+ torch.cuda.empty_cache()
+
+ chunk_end = min(chunk_start + chunk_size, len(audio_cond))
+
+ chunk_audio_cond = audio_cond[chunk_start:chunk_end].cuda()
+ if torch.cuda.is_available():
+ chunk_audio_cond = chunk_audio_cond.half()
+
+ chunk_gt_keyframes = gt_keyframes[chunk_start:chunk_end].cuda()
+ chunk_masks = masks_keyframes[chunk_start:chunk_end].cuda()
+
+ if torch.cuda.is_available():
+ chunk_gt_keyframes = chunk_gt_keyframes.half()
+ chunk_masks = chunk_masks.half()
+
+ test_keyframes_list_unwrapped = [
+ elem
+ for sublist in test_keyframes_list[chunk_start:chunk_end]
+ for elem in sublist
+ ]
+ to_remove_chunks_unwrapped = [
+ elem
+ for sublist in to_remove_chunks[chunk_start:chunk_end]
+ for elem in sublist
+ ]
+
+ if last_keyframe_idx is not None:
+ test_keyframes_list_unwrapped = [
+ last_keyframe_idx
+ ] + test_keyframes_list_unwrapped
+ to_remove_chunks_unwrapped = [last_to_remove] + to_remove_chunks_unwrapped
+
+ last_keyframe_idx = test_keyframes_list_unwrapped[-1]
+ last_to_remove = to_remove_chunks_unwrapped[-1]
+ # Find the first non-None keyframe in the chunk
+ first_keyframe = next(
+ (kf for kf in test_keyframes_list_unwrapped if kf is not None), None
+ )
+
+ # Find the last non-None keyframe in the chunk
+ last_keyframe = next(
+ (kf for kf in reversed(test_keyframes_list_unwrapped) if kf is not None),
+ None,
+ )
+
+ start_idx = next(
+ (
+ idx
+ for idx, comb in enumerate(test_interpolation_list)
+ if comb[0] == first_keyframe
+ ),
+ None,
+ )
+ end_idx = next(
+ (
+ idx
+ for idx, comb in enumerate(reversed(test_interpolation_list))
+ if comb[1] == last_keyframe
+ ),
+ None,
+ )
+
+ if start_idx is not None and end_idx is not None:
+ end_idx = (
+ len(test_interpolation_list) - 1 - end_idx
+ ) # Adjust for reversed enumeration
+ end_idx += 1
+ if start_idx is None:
+ break
+ if end_idx < start_idx:
+ end_idx = len(audio_interpolation_list)
+
+ audio_interpolation_list_chunk = audio_interpolation_list[start_idx:end_idx]
+ chunk_masks_interpolation = masks_interpolation[start_idx:end_idx]
+ gt_interpolation_chunks = gt_interpolation[start_idx:end_idx]
+
+ if torch.cuda.is_available():
+ audio_interpolation_list_chunk = [
+ chunk.half() for chunk in audio_interpolation_list_chunk
+ ]
+ chunk_masks_interpolation = [
+ chunk.half() for chunk in chunk_masks_interpolation
+ ]
+ gt_interpolation_chunks = [
+ chunk.half() for chunk in gt_interpolation_chunks
+ ]
+
+ progress(chunk_idx / total_chunks, desc="Generating keyframes")
+
+ # Always use autocast for FP16 computation
+ with amp.autocast(enabled=True):
+ samples_z = sample_keyframes(
+ model_keyframes,
+ chunk_audio_cond,
+ chunk_gt_keyframes,
+ chunk_masks,
+ condition.cuda(),
+ num_frames,
+ 24,
+ 0.0,
+ device,
+ embbedings.cuda() if embbedings is not None else None,
+ force_uc_zero_embeddings,
+ n_batch_keyframes,
+ 0,
+ 1.0,
+ None,
+ gt_as_cond=False,
+ )
+
+ if last_frame_x is not None:
+ # samples_x = torch.cat([last_frame_x.unsqueeze(0), samples_x], axis=0)
+ samples_z = torch.cat([last_frame_z.unsqueeze(0), samples_z], axis=0)
+
+ # last_frame_x = samples_x[-1]
+ last_frame_z = samples_z[-1]
+
+ progress(chunk_idx / total_chunks, desc="Interpolating frames")
+
+ # Always use autocast for FP16 computation
+ with amp.autocast(enabled=True):
+ vid = sample_interpolation(
+ model,
+ samples_z,
+ # samples_x,
+ audio_interpolation_list_chunk,
+ gt_interpolation_chunks,
+ chunk_masks_interpolation,
+ condition.cuda(),
+ num_frames,
+ device,
+ 1,
+ 24,
+ 0.0,
+ force_uc_zero_embeddings,
+ n_batch,
+ chunk_size,
+ 1.0,
+ None,
+ cut_audio=False,
+ to_remove=to_remove_chunks_unwrapped,
+ )
+
+ if chunk_start == 0:
+ complete_video = vid
+ else:
+ complete_video = np.concatenate([complete_video[:-1], vid], axis=0)
+
+ return complete_video
+
+
+def process_video(video_input, audio_input, max_num_seconds):
+ """Main processing function to generate synchronized video"""
+
+ # Display a message to the user about the processing time
+ gr.Info("Processing video. This may take a while...", duration=10)
+ gr.Info(
+ "If you're tired of waiting, try playing the Wordle game in the other tab!",
+ duration=10,
+ )
+
+ # Use default media if none provided
+ if video_input is None:
+ video_input = DEFAULT_VIDEO_PATH
+ print(f"Using default video: {DEFAULT_VIDEO_PATH}")
+
+ if audio_input is None:
+ audio_input = DEFAULT_AUDIO_PATH
+ print(f"Using default audio: {DEFAULT_AUDIO_PATH}")
+
+ try:
+ # Calculate hashes for cache keys
+ video_path_hash = video_input
+ audio_path_hash = audio_input
+
+ # Check if we need to recompute video embeddings
+ video_cache_hit = cache["video"]["path"] == video_path_hash
+ audio_cache_hit = cache["audio"]["path"] == audio_path_hash
+
+ if video_cache_hit and audio_cache_hit:
+ print("Using cached video and audio computations")
+ # Make copies of cached data to avoid modifying cache
+ video_embedding = cache["video"]["embedding"].clone()
+ video_frames = cache["video"]["frames"].clone()
+ video_landmarks = cache["video"]["landmarks"].copy()
+ raw_audio = cache["audio"]["raw_audio"].clone()
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
+ hubert_embedding = cache["audio"]["hubert_embedding"].clone()
+ wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()
+
+ # Ensure all data is truncated to the same length if needed
+ min_len = min(
+ len(video_frames),
+ len(raw_audio),
+ len(hubert_embedding),
+ len(wavlm_embedding),
+ )
+ video_frames = video_frames[:min_len]
+ video_embedding = video_embedding[:min_len]
+ video_landmarks = video_landmarks[:min_len]
+ raw_audio = raw_audio[:min_len]
+ hubert_embedding = hubert_embedding[:min_len]
+ wavlm_embedding = wavlm_embedding[:min_len]
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
+
+ else:
+ # Process video if needed
+ if not video_cache_hit:
+ print("Computing video embeddings and landmarks")
+ video_reader = decord.VideoReader(video_input)
+ decord.bridge.set_bridge("torch")
+
+ if not audio_cache_hit:
+ # Need to process audio to determine min_len
+ raw_audio = get_raw_audio(audio_input, 16000)
+ if len(raw_audio) == 0 or len(video_reader) == 0:
+ raise ValueError("Empty audio or video input")
+
+ min_len = min(len(raw_audio), len(video_reader))
+
+ # Store full audio in cache
+ cache["audio"]["path"] = audio_path_hash
+ cache["audio"]["raw_audio"] = raw_audio.clone()
+
+ # Create truncated copy for processing
+ raw_audio = raw_audio[:min_len]
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
+ else:
+ # Use cached audio - make a copy
+ if cache["audio"]["raw_audio"] is None:
+ raise ValueError("Cached audio is None")
+
+ raw_audio = cache["audio"]["raw_audio"].clone()
+ if len(raw_audio) == 0 or len(video_reader) == 0:
+ raise ValueError("Empty cached audio or video input")
+
+ min_len = min(len(raw_audio), len(video_reader))
+
+ # Create truncated copy for processing
+ raw_audio = raw_audio[:min_len]
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
+
+ # Compute video embeddings and landmarks - store full version in cache
+ video_embedding, video_frames = compute_video_embedding(
+ video_reader, len(video_reader)
+ )
+ video_landmarks = extract_video_landmarks(video_frames)
+
+ # Update video cache with full versions
+ cache["video"]["path"] = video_path_hash
+ cache["video"]["embedding"] = video_embedding
+ cache["video"]["frames"] = video_frames
+ cache["video"]["landmarks"] = video_landmarks
+
+ # Create truncated copies for processing
+ video_embedding = video_embedding[:min_len]
+ video_frames = video_frames[:min_len]
+ video_landmarks = video_landmarks[:min_len]
+
+ else:
+ # Use cached video data - make copies
+ print("Using cached video computations")
+
+ if (
+ cache["video"]["embedding"] is None
+ or cache["video"]["frames"] is None
+ or cache["video"]["landmarks"] is None
+ ):
+ raise ValueError("One or more video cache entries are None")
+
+ if not audio_cache_hit:
+ # New audio with cached video
+ raw_audio = get_raw_audio(audio_input, 16000)
+ if len(raw_audio) == 0:
+ raise ValueError("Empty audio input")
+
+ # Store full audio in cache
+ cache["audio"]["path"] = audio_path_hash
+ cache["audio"]["raw_audio"] = raw_audio.clone()
+
+ # Make copies of video data
+ video_embedding = cache["video"]["embedding"].clone()
+ video_frames = cache["video"]["frames"].clone()
+ video_landmarks = cache["video"]["landmarks"].copy()
+
+ # Determine truncation length and create truncated copies
+ min_len = min(len(raw_audio), len(video_frames))
+ raw_audio = raw_audio[:min_len]
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
+ video_frames = video_frames[:min_len]
+ video_embedding = video_embedding[:min_len]
+ video_landmarks = video_landmarks[:min_len]
+ else:
+ # Both video and audio are cached - should not reach here
+ # as it's handled in the first if statement
+ pass
+
+ # Process audio if needed
+ if not audio_cache_hit:
+ print("Computing audio embeddings")
+
+ # Compute audio embeddings with the truncated audio
+ hubert_embedding = compute_hubert_embedding(raw_audio_reshape)
+ wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape)
+
+ # Update audio cache with full embeddings
+ # Note: raw_audio was already cached above
+ cache["audio"]["hubert_embedding"] = hubert_embedding.clone()
+ cache["audio"]["wavlm_embedding"] = wavlm_embedding.clone()
+ else:
+ # Use cached audio data - make copies
+ if (
+ cache["audio"]["hubert_embedding"] is None
+ or cache["audio"]["wavlm_embedding"] is None
+ ):
+ raise ValueError(
+ "One or more audio embedding cache entries are None"
+ )
+
+ hubert_embedding = cache["audio"]["hubert_embedding"].clone()
+ wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()
+
+ # Make sure embeddings match the truncated video length if needed
+ if "min_len" in locals() and (
+ min_len < len(hubert_embedding) or min_len < len(wavlm_embedding)
+ ):
+ hubert_embedding = hubert_embedding[:min_len]
+ wavlm_embedding = wavlm_embedding[:min_len]
+
+ # Apply max_num_seconds limit if specified
+ if max_num_seconds > 0:
+ # Convert seconds to frames (assuming 25 fps)
+ max_frames = int(max_num_seconds * 25)
+
+ # Truncate all data to max_frames
+ video_embedding = video_embedding[:max_frames]
+ video_frames = video_frames[:max_frames]
+ video_landmarks = video_landmarks[:max_frames]
+ hubert_embedding = hubert_embedding[:max_frames]
+ wavlm_embedding = wavlm_embedding[:max_frames]
+ raw_audio = raw_audio[:max_frames]
+ raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
+
+ # Validate shapes before proceeding
+ assert video_embedding.shape[0] == hubert_embedding.shape[0], (
+ f"Video embedding length ({video_embedding.shape[0]}) doesn't match Hubert embedding length ({hubert_embedding.shape[0]})"
+ )
+ assert video_embedding.shape[0] == wavlm_embedding.shape[0], (
+ f"Video embedding length ({video_embedding.shape[0]}) doesn't match WavLM embedding length ({wavlm_embedding.shape[0]})"
+ )
+ assert video_embedding.shape[0] == video_landmarks.shape[0], (
+ f"Video embedding length ({video_embedding.shape[0]}) doesn't match landmarks length ({video_landmarks.shape[0]})"
+ )
+
+ print(f"Hubert embedding shape: {hubert_embedding.shape}")
+ print(f"WavLM embedding shape: {wavlm_embedding.shape}")
+ print(f"Video embedding shape: {video_embedding.shape}")
+ print(f"Video landmarks shape: {video_landmarks.shape}")
+
+ # Create pipeline inputs for models
+ (
+ interpolation_chunks,
+ keyframe_chunks,
+ audio_interpolation_chunks,
+ audio_keyframe_chunks,
+ emb_cond,
+ masks_keyframe_chunks,
+ masks_interpolation_chunks,
+ to_remove,
+ audio_interpolation_idx,
+ audio_keyframe_idx,
+ ) = create_pipeline_inputs(
+ hubert_embedding,
+ wavlm_embedding,
+ 14,
+ video_embedding,
+ video_landmarks,
+ overlap=1,
+ add_zero_flag=True,
+ mask_arms=None,
+ nose_index=28,
+ )
+
+ complete_video = sample(
+ audio_keyframe_chunks,
+ keyframe_chunks,
+ masks_keyframe_chunks,
+ to_remove,
+ audio_keyframe_idx,
+ 14,
+ "cuda",
+ emb_cond,
+ [],
+ 3,
+ 3,
+ audio_interpolation_idx,
+ audio_interpolation_chunks,
+ masks_interpolation_chunks,
+ interpolation_chunks,
+ keyframe_model,
+ interpolation_model,
+ )
+
+ complete_audio = rearrange(
+ raw_audio[: complete_video.shape[0]], "f s -> () (f s)"
+ )
+
+ # 4. Convert frames to video and combine with audio
+ with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
+ output_path = temp_video.name
+
+ print("Saving video to", output_path)
+
+ save_audio_video(complete_video, audio=complete_audio, save_path=output_path)
+ torch.cuda.empty_cache()
+ return output_path
+
+ except Exception as e:
+ raise e
+ print(f"Error processing video: {str(e)}")
+ return None
+
+
+def get_max_duration(video_input, audio_input):
+ """Get the maximum duration in seconds for the slider"""
+ try:
+ # Default to 60 seconds if files don't exist
+ if video_input is None or not os.path.exists(video_input):
+ video_input = DEFAULT_VIDEO_PATH
+
+ if audio_input is None or not os.path.exists(audio_input):
+ audio_input = DEFAULT_AUDIO_PATH
+
+ # Get video duration
+ video_reader = decord.VideoReader(video_input)
+ video_duration = len(video_reader) / video_reader.get_avg_fps()
+
+ # Get audio duration
+ raw_audio = get_raw_audio(audio_input, 16000)
+ audio_duration = len(raw_audio) / 25 # Assuming 25 fps
+
+ # Return the minimum of the two durations
+ return min(video_duration, audio_duration)
+ except Exception as e:
+ print(f"Error getting max duration: {str(e)}")
+ return 60 # Default to 60 seconds
+
+
+def new_game_click(state):
+ """Handle the 'New Game' button click."""
+ message = state.new_game()
+ feedback_history = state.get_feedback_history()
+ return state, feedback_history, message
+
+
+def submit_guess_click(guess, state):
+ """Handle the 'Submit Guess' button click."""
+ message = state.submit_guess(guess)
+ feedback_history = state.get_feedback_history()
+ return state, feedback_history, message
+
+
+# Create Gradio interface
+with gr.Blocks(title="Video Synchronization with Diffusion Models") as demo:
+ gr.Markdown("# Video Synchronization with Diffusion Models")
+ gr.Markdown(
+ "Upload a video and audio to create a synchronized video with the same visuals but synchronized to the new audio."
+ )
+
+ with gr.Tabs():
+ with gr.TabItem("Video Synchronization"):
+ with gr.Row():
+ with gr.Column():
+ video_input = gr.Video(
+ label="Input Video",
+ value=DEFAULT_VIDEO_PATH
+ if os.path.exists(DEFAULT_VIDEO_PATH)
+ else None,
+ width=512,
+ height=512,
+ )
+ audio_input = gr.Audio(
+ label="Input Audio",
+ type="filepath",
+ value=DEFAULT_AUDIO_PATH
+ if os.path.exists(DEFAULT_AUDIO_PATH)
+ else None,
+ )
+
+ max_duration = gr.State(value=60) # Default max duration
+
+ max_seconds_slider = gr.Slider(
+ minimum=0,
+ maximum=60, # Will be updated dynamically
+ value=0,
+ step=1,
+ label="Max Duration (seconds, 0 = full length)",
+ info="Limit the processing duration (0 means use full length)",
+ )
+
+ process_button = gr.Button("Generate Synchronized Video")
+
+ with gr.Column("Output Video"):
+ video_output = gr.Video(label="Output Video", width=512, height=512)
+
+ # Update slider max value when inputs change
+ def update_slider_max(video, audio):
+ max_dur = get_max_duration(video, audio)
+ return {"maximum": max_dur, "__type__": "update"}
+
+ video_input.change(
+ update_slider_max, [video_input, audio_input], [max_seconds_slider]
+ )
+ audio_input.change(
+ update_slider_max, [video_input, audio_input], [max_seconds_slider]
+ )
+
+ # Show Wordle message when processing starts and hide when complete
+ process_button.click(
+ fn=process_video,
+ inputs=[video_input, audio_input, max_seconds_slider],
+ outputs=video_output,
+ )
+
+ with gr.TabItem("Wordle Game"):
+ state = gr.State(WordleGame()) # Persist the WordleGame instance
+ guess_input = gr.Textbox(label="Your guess (5 letters)", max_length=5)
+ submit_btn = gr.Button("Submit Guess")
+ new_game_btn = gr.Button("New Game")
+ feedback_display = gr.HTML(label="Guesses")
+ message_display = gr.Textbox(
+ label="Message", interactive=False, value="Click 'New Game' to start."
+ )
+ # Connect the 'New Game' button
+ new_game_btn.click(
+ fn=new_game_click,
+ inputs=[state],
+ outputs=[state, feedback_display, message_display],
+ )
+ # Connect the 'Submit Guess' button
+ submit_btn.click(
+ fn=submit_guess_click,
+ inputs=[guess_input, state],
+ outputs=[state, feedback_display, message_display],
+ )
+
+ gr.Markdown("## How it works")
+ gr.Markdown("""
+ 1. The system extracts embeddings and landmarks from the input video
+ 2. Audio embeddings are computed from the input audio
+ 3. A keyframe model generates key visual frames
+ 4. An interpolation model creates a smooth video between keyframes
+ 5. The final video is rendered with the new audio
+ """)
+
+if __name__ == "__main__":
+ demo.launch()
diff --git a/data_utils.py b/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..38f37a022387b9faddbddcedffb9a2236a3cf147
--- /dev/null
+++ b/data_utils.py
@@ -0,0 +1,635 @@
+import torch
+import numpy as np
+from PIL import Image, ImageDraw
+import cv2
+from functools import partial
+import math
+
+
+def get_size(img):
+ if isinstance(img, (np.ndarray, torch.Tensor)):
+ return img.shape[1::-1]
+ else:
+ return img.size
+
+
+def imresample(img, sz):
+ im_data = torch.nn.functional.interpolate(img, size=sz, mode="area")
+ return im_data
+
+
+def crop_resize(img, box, image_size):
+ if isinstance(img, np.ndarray):
+ img = img[box[1] : box[3], box[0] : box[2]]
+ out = cv2.resize(
+ img, (image_size, image_size), interpolation=cv2.INTER_AREA
+ ).copy()
+ elif isinstance(img, torch.Tensor):
+ img = img[box[1] : box[3], box[0] : box[2]]
+ out = (
+ imresample(
+ img.permute(2, 0, 1).unsqueeze(0).float(), (image_size, image_size)
+ )
+ .byte()
+ .squeeze(0)
+ .permute(1, 2, 0)
+ )
+ else:
+ out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR)
+ return out
+
+
+def fixed_image_standardization(image_tensor):
+ processed_tensor = (image_tensor - 127.5) / 128.0
+ return processed_tensor
+
+
+def extract_face(img, landmarks, image_size=160, margin=0, postprocess=False):
+ """Extract face + margin from images given facial landmarks.
+
+ Arguments:
+ img {PIL.Image/torch.Tensor/np.ndarray} -- Input image(s) with shape (B, H, W, C)
+ landmarks {numpy.ndarray} -- Facial landmarks with shape (B, 68, 2)
+ image_size {int} -- Output image size in pixels. The image will be square.
+ margin {int} -- Margin to add to bounding box, in terms of pixels in the final image.
+ postprocess {bool} -- Whether to apply standardization
+
+ Returns:
+ torch.tensor -- tensor representing the extracted faces with shape (B, H, W, C)
+ """
+ # Calculate bounding boxes from landmarks for all faces in batch
+ x_min = np.min(landmarks, axis=1)[:, 0] # Shape: (B,)
+ y_min = np.min(landmarks, axis=1)[:, 1] # Shape: (B,)
+ x_max = np.max(landmarks, axis=1)[:, 0] # Shape: (B,)
+ y_max = np.max(landmarks, axis=1)[:, 1] # Shape: (B,)
+
+ # Calculate margin for top only
+ box_height = y_max - y_min
+ top_margin = margin * box_height / (image_size - margin)
+
+ # Create boxes for all faces
+ boxes = np.stack(
+ [
+ x_min,
+ np.maximum(y_min - top_margin, 0), # Only add margin to top
+ x_max,
+ y_max,
+ ],
+ axis=1,
+ ).astype(int) # Shape: (B, 4)
+
+ # Process each face in the batch
+ faces = []
+ for i in range(len(boxes)):
+ face = crop_resize(img[i], boxes[i], image_size)
+ faces.append(face)
+
+ faces = torch.stack(faces, dim=0)
+ faces = faces.float()
+
+ if postprocess:
+ faces = fixed_image_standardization(faces)
+
+ return faces
+
+
+def crop_mouth_region(images, landmarks, crop_size=96):
+ """
+ Takes a fixed-size square crop centered on the mouth region.
+
+ Parameters:
+ - images: tensor/array of shape (num_frames, height, width, channels) or (height, width, channels)
+ - landmarks: numpy array of shape (num_frames, 68, 2) or (68, 2)
+ - crop_size: size of the square crop (both height and width)
+ - padding: percentage of padding around the mouth region (0.0 to 1.0)
+
+ Returns:
+ - List of fixed-size crops or single crop if input is single image
+ """
+ # Handle single image case
+ single_image = False
+ if len(images.shape) == 3:
+ images = images[None]
+ landmarks = landmarks[None]
+ single_image = True
+
+ num_frames = len(images)
+ crops = []
+
+ # Mouth landmarks indices (48-67 for mouth region)
+ mouth_indices = range(48, 68)
+
+ for i in range(num_frames):
+ # Get mouth landmarks for current frame
+ mouth_landmarks = landmarks[i][mouth_indices]
+
+ # Find center of mouth
+ center_x = int(np.mean(mouth_landmarks[:, 0]))
+ center_y = int(np.mean(mouth_landmarks[:, 1]))
+
+ # Calculate crop boundaries
+ half_size = crop_size // 2
+ left = max(0, center_x - half_size)
+ right = min(images.shape[2], center_x + half_size)
+ top = max(0, center_y - half_size)
+ bottom = min(images.shape[1], center_y + half_size)
+
+ # Adjust if crop would go out of bounds
+ if left == 0:
+ right = crop_size
+ if right == images.shape[2]:
+ left = images.shape[2] - crop_size
+ if top == 0:
+ bottom = crop_size
+ if bottom == images.shape[1]:
+ top = images.shape[1] - crop_size
+
+ # Take the crop
+ crop = images[i, top:bottom, left:right]
+ crops.append(crop)
+
+ return crops[0] if single_image else crops
+
+
+def create_masks_from_landmarks_box(
+ landmark_list, img_shape, nose_index=28, dtype="uint8", box_expand=0.0
+):
+ height, width = img_shape[:2]
+ num_frames = landmark_list.shape[0]
+
+ # Initialize the masks array
+ masks = np.zeros((num_frames, height, width), dtype=dtype)
+
+ if 0 <= box_expand < 1:
+ box_expand = int(box_expand * width)
+
+ for i in range(num_frames):
+ # Get the landmarks for the current frame
+ landmarks = landmark_list[i]
+
+ # Get the y-coordinate of the nose landmark
+ nose_point_h = landmarks[nose_index, 1]
+ cut_h = nose_point_h
+
+ # Find the leftmost and rightmost landmarks
+ far_left_index = np.argmin(landmarks[:, 0])
+ far_right_index = np.argmax(landmarks[:, 0])
+
+ # Define the points for the mask contour
+ left_up_point = np.array(
+ [landmarks[far_left_index][0], cut_h - box_expand], dtype=np.int32
+ )
+ left_down_point = np.array(
+ [landmarks[far_left_index][0], height], dtype=np.int32
+ )
+ right_up_point = np.array(
+ [landmarks[far_right_index][0], cut_h - box_expand], dtype=np.int32
+ )
+ right_down_point = np.array(
+ [landmarks[far_right_index][0], height], dtype=np.int32
+ )
+
+ # Define the contour
+ contour = np.array(
+ [[left_up_point, left_down_point, right_down_point, right_up_point]]
+ )
+
+ # Draw the contour on the mask
+ cv2.drawContours(masks[i], [contour], -1, color=(1), thickness=cv2.FILLED)
+
+ return torch.from_numpy(masks)
+
+
+def create_masks_from_landmarks_full_size(
+ landmarks_batch,
+ image_height,
+ image_width,
+ start_index=48,
+ end_index=68,
+ offset=0,
+ nose_index=33,
+):
+ """
+ Efficiently creates a batch of masks using vectorized operations where each mask has ones from the highest
+ landmark in the specified range (adjusted by an offset) to the bottom of the image, and zeros otherwise.
+
+ Parameters:
+ - landmarks_batch (np.array): An array of shape (B, 68, 2) containing facial landmarks for multiple samples.
+ - image_height (int): The height of the image for which masks are created.
+ - image_width (int): The width of the image for which masks are created.
+ - start_index (int): The starting index of the range to check (inclusive).
+ - end_index (int): The ending index of the range to check (inclusive).
+ - offset (int): An offset to add or subtract from the y-coordinate of the highest landmark.
+
+ Returns:
+ - np.array: An array of masks of shape (B, image_height, image_width) for each batch.
+ """
+ # Extract the y-coordinates for the specified range across all batches
+ y_coords = landmarks_batch[:, nose_index : nose_index + 1, 1]
+
+ # Find the index of the minimum y-coordinate in the specified range for each batch
+ min_y_indices = np.argmin(y_coords, axis=1)
+
+ # Gather the highest landmarks' y-coordinates using the indices found
+ highest_y_coords = y_coords[np.arange(len(y_coords)), min_y_indices]
+
+ if abs(offset) < 1 and abs(offset) > 0:
+ offset = int(offset * image_height)
+
+ # Apply the offset to the highest y-coordinate
+ adjusted_y_coords = highest_y_coords + offset
+
+ # Clip the coordinates to stay within image boundaries
+ adjusted_y_coords = np.clip(adjusted_y_coords, 0, image_height - 1)
+
+ # Use broadcasting to create a mask without loops
+ # Create a range of indices from 0 to image_height - 1
+ all_indices = np.arange(image_height)
+
+ # Compare each index in 'all_indices' to each 'adjusted_y_coord' in the batch
+ # 'all_indices' has shape (image_height,), we reshape to (1, image_height) to broadcast against (B, 1)
+ mask_2d = (all_indices >= adjusted_y_coords[:, None]).astype(int)
+
+ # Extend the 2D mask to a full 3D mask of size (B, image_height, image_width)
+ full_mask = np.tile(mask_2d[:, :, np.newaxis], (1, 1, image_width))
+
+ return torch.from_numpy(full_mask)
+
+
+def expand_polygon(polygon, expand_size):
+ """
+ Expands the polygon outward by a specified number of pixels.
+
+ Parameters:
+ - polygon (list of tuples): The polygon points as (x, y).
+ - expand_size (int): The number of pixels to expand the polygon outward.
+
+ Returns:
+ - expanded_polygon (list of tuples): The expanded polygon points as (x, y).
+ """
+ if expand_size == 0:
+ return polygon
+
+ # Calculate centroid of the polygon
+ centroid_x = sum([point[0] for point in polygon]) / len(polygon)
+ centroid_y = sum([point[1] for point in polygon]) / len(polygon)
+
+ # Expand each point outward from the centroid
+ expanded_polygon = []
+ for x, y in polygon:
+ vector_x = x - centroid_x
+ vector_y = y - centroid_y
+ length = np.sqrt(vector_x**2 + vector_y**2)
+ if length == 0:
+ expanded_polygon.append((x, y))
+ else:
+ new_x = x + expand_size * (vector_x / length)
+ new_y = y + expand_size * (vector_y / length)
+ expanded_polygon.append((int(new_x), int(new_y)))
+
+ return expanded_polygon
+
+
+def create_masks_from_landmarks_mouth(
+ landmark_list, img_shape, nose_index=33, dtype="uint8", box_expand=0.0
+):
+ height, width = img_shape[:2]
+ num_frames = landmark_list.shape[0]
+
+ # Initialize the masks array
+ masks = np.zeros((num_frames, height, width), dtype=dtype)
+
+ if 0 <= box_expand < 1:
+ box_expand = int(box_expand * width)
+
+ for i in range(num_frames):
+ # Get the landmarks for the current frame
+ landmarks = landmark_list[i]
+
+ # Get the y-coordinate of the nose landmark
+ nose_point_h = landmarks[nose_index, 1]
+ cut_h = nose_point_h
+
+ # Find the leftmost and rightmost landmarks
+ far_left_index = np.argmin(landmarks[:, 0])
+ far_right_index = np.argmax(landmarks[:, 0])
+
+ # Find lowest landmark y-coordinate
+ lowest_y = np.max(landmarks[:, 1])
+ # Add box_expand to the lowest point
+ lowest_y = min(height, lowest_y + box_expand)
+
+ # Define the points for the mask contour
+ left_up_point = np.array(
+ [landmarks[far_left_index][0], cut_h - box_expand], dtype=np.int32
+ )
+ left_down_point = np.array(
+ [landmarks[far_left_index][0], lowest_y], dtype=np.int32
+ )
+ right_up_point = np.array(
+ [landmarks[far_right_index][0], cut_h - box_expand], dtype=np.int32
+ )
+ right_down_point = np.array(
+ [landmarks[far_right_index][0], lowest_y], dtype=np.int32
+ )
+
+ # Define the contour
+ contour = np.array(
+ [[left_up_point, left_down_point, right_down_point, right_up_point]]
+ )
+
+ # Draw the contour on the mask
+ cv2.drawContours(masks[i], [contour], -1, color=(1), thickness=cv2.FILLED)
+
+ return torch.from_numpy(masks)
+
+
+def create_face_mask_from_landmarks(
+ landmarks_batch, image_height, image_width, mask_expand=0
+):
+ """
+ Creates a batch of masks where each mask covers the face region using landmarks.
+
+ Parameters:
+ - landmarks_batch (np.array): An array of shape (B, 68, 2) containing facial landmarks for multiple samples.
+ - image_height (int): The height of the image for which masks are created.
+ - image_width (int): The width of the image for which masks are created.
+ - mask_expand (int): The number of pixels to expand the mask outward.
+
+ Returns:
+ - np.array: An array of masks of shape (B, image_height, image_width) for each batch.
+ """
+ # Initialize an array to hold all masks
+ masks = np.zeros(
+ (landmarks_batch.shape[0], image_height, image_width), dtype=np.uint8
+ )
+
+ if abs(mask_expand) < 1 and abs(mask_expand) > 0:
+ mask_expand = int(mask_expand * image_height)
+
+ for i, landmarks in enumerate(landmarks_batch):
+ # Create a blank image for each mask
+ mask = Image.new("L", (image_width, image_height), 0)
+ draw = ImageDraw.Draw(mask)
+
+ # Extract relevant landmarks for the face
+ jawline_landmarks = landmarks[2:15] # Jawline
+ # upper_face_landmarks = landmarks[17:27] # Eyebrows and top of nose bridge
+
+ # Combine landmarks to form a polygon around the face
+ # face_polygon = np.concatenate((jawline_landmarks, upper_face_landmarks[::-1]), axis=0)
+ face_polygon = jawline_landmarks
+
+ # Convert landmarks to a list of tuples
+ face_polygon = [(int(x), int(y)) for x, y in face_polygon]
+
+ # Expand the polygon if necessary
+ expanded_polygon = expand_polygon(face_polygon, mask_expand)
+
+ # Draw the polygon and fill it
+ draw.polygon(expanded_polygon, outline=1, fill=1)
+
+ # Convert mask to numpy array and add it to the batch of masks
+ masks[i] = np.array(mask)
+
+ return torch.from_numpy(masks)
+
+
+ALL_FIXED_POINTS = (
+ [i for i in range(0, 4)]
+ + [i for i in range(13, 17)]
+ + [i for i in range(27, 36)]
+ + [36, 39, 42, 45]
+)
+
+
+def gaussian_kernel(sigma, width, height):
+ """Create a 2D Gaussian kernel."""
+ x = torch.arange(0, width, 1) - width // 2
+ y = torch.arange(0, height, 1) - height // 2
+ x = x.float()
+ y = y.float()
+ x2 = x**2
+ y2 = y[:, None] ** 2
+ g = torch.exp(-(x2 + y2) / (2 * sigma**2))
+ return g / g.sum()
+
+
+def generate_hm(landmarks, height, width, n_points="all", sigma=3):
+ if n_points == "all":
+ Nlandmarks = range(len(landmarks))
+ elif n_points == "fixed":
+ Nlandmarks = ALL_FIXED_POINTS
+ elif n_points == "stable":
+ Nlandmarks = [33, 36, 39, 42, 45]
+
+ kernel = gaussian_kernel(sigma, width, height)
+ hm = torch.zeros((height, width))
+ for I in Nlandmarks:
+ x0, y0 = landmarks[I]
+ x0, y0 = int(x0), int(y0)
+ left, right = max(0, x0 - width // 2), min(width, x0 + width // 2)
+ top, bottom = max(0, y0 - height // 2), min(height, y0 + height // 2)
+ hm[top:bottom, left:right] += kernel[
+ max(0, -y0 + height // 2) : min(height, height - y0 + height // 2),
+ max(0, -x0 + width // 2) : min(width, width - x0 + width // 2),
+ ]
+ # Normalize the heatmap to have values between 0 and 1
+ max_val = hm.max()
+ if max_val > 0:
+ hm /= max_val
+ return hm
+
+
+def get_heatmap(landmarks, image_size, or_im_size, n_points="stable", sigma=4):
+ stack = []
+ seq_length = landmarks.shape[0]
+ if or_im_size[0] != image_size[0] or or_im_size[1] != image_size[1]:
+ landmarks = scale_landmarks(landmarks, or_im_size, image_size)
+ gen_single_heatmap = partial(
+ generate_hm,
+ height=image_size[0],
+ width=image_size[1],
+ n_points=n_points,
+ sigma=sigma,
+ )
+ for i in range(seq_length):
+ stack.append(gen_single_heatmap(landmarks[i]))
+
+ return torch.stack(stack, axis=0).unsqueeze(0) # (1, seq_length, height, width)
+
+
+def scale_landmarks(landmarks, original_size, target_size):
+ """
+ Scale landmarks from original size to target size.
+
+ Parameters:
+ - landmarks (np.array): An array of shape (N, 2) containing facial landmarks.
+ - original_size (tuple): The size (height, width) for which the landmarks are currently scaled.
+ - target_size (tuple): The size (height, width) to which landmarks should be scaled.
+
+ Returns:
+ - scaled_landmarks (np.array): Scaled landmarks.
+ """
+ scale_y = target_size[0] / original_size[0]
+ scale_x = target_size[1] / original_size[1]
+ scaled_landmarks = landmarks * np.array([scale_x, scale_y])
+ return scaled_landmarks.astype(int)
+
+
+def draw_kps_image(
+ image_shape,
+ original_size,
+ landmarks,
+ color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255)],
+ rgb=True,
+ pts_width=4,
+):
+ stick_width = pts_width
+ limb_seq = np.array([[0, 2], [1, 2]])
+ kps = landmarks[[36, 45, 33], :]
+ kps = scale_landmarks(kps, original_size, image_shape)
+ if not rgb: # Grayscale image
+ canvas = np.zeros((image_shape[0], image_shape[1], 1))
+ color_mode = "grayscale"
+ else: # Color image
+ canvas = np.zeros((image_shape[0], image_shape[1], 3))
+ color_mode = "color"
+
+ polygon_cache = {}
+
+ for index in limb_seq:
+ color = color_list[index[0]]
+ if color_mode == "grayscale":
+ color = (
+ int(0.299 * color[2] + 0.587 * color[1] + 0.114 * color[0]),
+ ) # Convert to grayscale intensity
+
+ x = kps[index][:, 0]
+ y = kps[index][:, 1]
+ length = np.sqrt((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2)
+ angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
+
+ cache_key = (
+ color,
+ int(np.mean(x)),
+ int(np.mean(y)),
+ int(length / 2),
+ int(angle),
+ )
+ if cache_key not in polygon_cache:
+ polygon_cache[cache_key] = cv2.ellipse2Poly(
+ (int(np.mean(x)), int(np.mean(y))),
+ (int(length / 2), stick_width),
+ int(angle),
+ 0,
+ 360,
+ 1,
+ )
+
+ polygon = polygon_cache[cache_key]
+ cv2.fillConvexPoly(canvas, polygon, [int(c * 0.6) for c in color])
+
+ for idx, kp in enumerate(kps):
+ if color_mode == "grayscale":
+ color = (
+ int(
+ 0.299 * color_list[idx][2]
+ + 0.587 * color_list[idx][1]
+ + 0.114 * color_list[idx][0]
+ ),
+ )
+ else:
+ color = color_list[idx]
+ cv2.circle(canvas, (int(kp[0]), int(kp[1])), pts_width, color, -1)
+
+ return canvas.transpose(2, 0, 1)
+
+
+def create_landmarks_image(
+ landmarks,
+ original_size=(772, 772),
+ target_size=(772, 772),
+ point_size=3,
+ n_points="all",
+ dim=3,
+):
+ """
+ Creates an image of landmarks on a black background using efficient NumPy operations.
+
+ Parameters:
+ - landmarks (np.array): An array of shape (68, 2) containing facial landmarks.
+ - image_size (tuple): The size of the output image (height, width).
+ - point_size (int): The radius of each landmark point in pixels.
+
+ Returns:
+ - img (np.array): An image array with landmarks plotted.
+ """
+ if n_points == "all":
+ indexes = range(len(landmarks))
+ elif n_points == "fixed":
+ indexes = ALL_FIXED_POINTS
+ elif n_points == "stable":
+ indexes = [33, 36, 39, 42, 45]
+
+ landmarks = landmarks[indexes]
+
+ img = np.zeros(target_size, dtype=np.uint8)
+
+ landmarks = scale_landmarks(landmarks, original_size, target_size)
+
+ # Ensure the landmarks are in bounds and integer
+ landmarks = np.clip(
+ landmarks, [0, 0], [target_size[1] - 1, target_size[0] - 1]
+ ).astype(int)
+
+ # Get x and y coordinates from landmarks
+ x, y = landmarks[:, 0], landmarks[:, 1]
+
+ # Define a grid offset based on point_size around each landmark
+ offset = np.arange(-point_size // 2, point_size // 2 + 1)
+ grid_x, grid_y = np.meshgrid(offset, offset, indexing="ij")
+
+ # Calculate the full set of x and y coordinates for the points
+ full_x = x[:, np.newaxis, np.newaxis] + grid_x[np.newaxis, :, :]
+ full_y = y[:, np.newaxis, np.newaxis] + grid_y[np.newaxis, :, :]
+
+ # Clip the coordinates to stay within image boundaries
+ full_x = np.clip(full_x, 0, target_size[1] - 1)
+ full_y = np.clip(full_y, 0, target_size[0] - 1)
+
+ # Flatten the arrays to use them as indices
+ full_x = full_x.ravel()
+ full_y = full_y.ravel()
+
+ # Set the points in the image
+ img[full_y, full_x] = 255
+
+ return np.stack([img] * dim, axis=0)
+
+
+def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
+ len_file = audio.shape[-1]
+
+ if max_len_sec or max_len_raw:
+ max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
+ if len_file < int(max_len):
+ # dummy = np.zeros((1, int(max_len_sec * sr) - len_file))
+ # extened_wav = np.concatenate((audio_data, dummy[0]))
+ extened_wav = torch.nn.functional.pad(
+ audio, (0, int(max_len) - len_file), "constant"
+ )
+ else:
+ extened_wav = audio[:, : int(max_len)]
+ else:
+ extened_wav = audio
+
+ return extened_wav
+
+
+def ssim_to_bin(ssim_score):
+ # Normalize the SSIM score to a 0-100 scale
+ normalized_diff_ssim = (1 - ((ssim_score + 1) / 2)) * 100
+ # Assign to one of the 100 bins
+ bin_index = float(min(np.floor(normalized_diff_ssim), 99))
+ return bin_index
diff --git a/inference_functions.py b/inference_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b396ed8e46e4bcb94d06be12e7efcd7000251fa
--- /dev/null
+++ b/inference_functions.py
@@ -0,0 +1,493 @@
+import torch
+from typing import Any, Dict, List, Optional, Tuple, Union
+import numpy as np
+from einops import rearrange, repeat
+import math
+
+
+def get_unique_embedder_keys_from_conditioner(conditioner):
+ return list(set([x.input_key for x in conditioner.embedders]))
+
+
+def get_batch(keys, value_dict, N, T, device):
+ batch = {}
+ batch_uc = {}
+
+ for key in keys:
+ if key == "fps_id":
+ batch[key] = (
+ torch.tensor([value_dict["fps_id"]])
+ .to(device)
+ .repeat(int(math.prod(N)))
+ )
+ elif key == "motion_bucket_id":
+ batch[key] = (
+ torch.tensor([value_dict["motion_bucket_id"]])
+ .to(device)
+ .repeat(int(math.prod(N)))
+ )
+ elif key == "cond_aug":
+ batch[key] = repeat(
+ torch.tensor([value_dict["cond_aug"]]).to(device),
+ "1 -> b",
+ b=math.prod(N),
+ )
+ elif key == "cond_frames":
+ batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
+ elif key == "cond_frames_without_noise":
+ batch[key] = repeat(
+ value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
+ )
+ else:
+ batch[key] = value_dict[key]
+
+ if T is not None:
+ batch["num_video_frames"] = T
+
+ for key in batch.keys():
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
+ batch_uc[key] = torch.clone(batch[key])
+ return batch, batch_uc
+
+
+def merge_overlapping_segments(segments: torch.Tensor, overlap: int) -> torch.Tensor:
+ """
+ Merges overlapping segments by averaging overlapping frames.
+ Segments have shape (b, t, ...), where 'b' is the number of segments,
+ 't' is frames per segment, and '...' are other dimensions.
+
+ Args:
+ segments: Tensor of shape (b, t, ...)
+ overlap: Integer, number of frames that overlap between consecutive segments
+
+ Returns:
+ Tensor of the merged video
+ """
+ # Get the shape details
+ b, t, *other_dims = segments.shape
+ num_frames = (b - 1) * (
+ t - overlap
+ ) + t # Calculate the total number of frames in the merged video
+
+ # Initialize the output tensor and a count tensor to keep track of contributions for averaging
+ output_shape = [num_frames] + other_dims
+ output = torch.zeros(output_shape, dtype=segments.dtype, device=segments.device)
+ count = torch.zeros(output_shape, dtype=torch.float32, device=segments.device)
+
+ current_index = 0
+ for i in range(b):
+ end_index = current_index + t
+ # Add the segment to the output tensor
+ output[current_index:end_index] += rearrange(segments[i], "... -> ...")
+ # Increment the count tensor for each frame that's added
+ count[current_index:end_index] += 1
+ # Update the starting index for the next segment
+ current_index += t - overlap
+
+ # Avoid division by zero
+ count[count == 0] = 1
+ # Average the frames where there's overlap
+ output /= count
+
+ return output
+
+
+def get_batch_overlap(
+ keys: List[str],
+ value_dict: Dict[str, Any],
+ N: Tuple[int, ...],
+ T: Optional[int],
+ device: str,
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """
+ Create a batch dictionary with overlapping frames for model input.
+
+ Args:
+ keys: List of keys to include in the batch
+ value_dict: Dictionary containing values for each key
+ N: Batch dimensions
+ T: Number of frames (optional)
+ device: Device to place tensors on
+
+ Returns:
+ Tuple of (batch dictionary, unconditional batch dictionary)
+ """
+ batch = {}
+ batch_uc = {}
+
+ for key in keys:
+ if key == "fps_id":
+ batch[key] = (
+ torch.tensor([value_dict["fps_id"]])
+ .to(device)
+ .repeat(int(math.prod(N)))
+ )
+ elif key == "motion_bucket_id":
+ batch[key] = (
+ torch.tensor([value_dict["motion_bucket_id"]])
+ .to(device)
+ .repeat(int(math.prod(N)))
+ )
+ elif key == "cond_aug":
+ batch[key] = repeat(
+ torch.tensor([value_dict["cond_aug"]]).to(device),
+ "1 -> b",
+ b=math.prod(N),
+ )
+ elif key == "cond_frames":
+ batch[key] = repeat(value_dict["cond_frames"], "b ... -> (b t) ...", t=N[0])
+ elif key == "cond_frames_without_noise":
+ batch[key] = repeat(
+ value_dict["cond_frames_without_noise"], "b ... -> (b t) ...", t=N[0]
+ )
+ else:
+ batch[key] = value_dict[key]
+
+ if T is not None:
+ batch["num_video_frames"] = T
+
+ for key in batch.keys():
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
+ batch_uc[key] = torch.clone(batch[key])
+ return batch, batch_uc
+
+
+@torch.inference_mode()
+def sample_keyframes(
+ model_keyframes: Any,
+ audio_list: torch.Tensor,
+ gt_list: torch.Tensor,
+ masks_list: torch.Tensor,
+ condition: torch.Tensor,
+ num_frames: int,
+ fps_id: int,
+ cond_aug: float,
+ device: str,
+ embbedings: Optional[torch.Tensor],
+ force_uc_zero_embeddings: List[str],
+ n_batch_keyframes: int,
+ added_frames: int,
+ strength: float,
+ scale: Optional[Union[float, List[float]]],
+ gt_as_cond: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Sample keyframes using the keyframe generation model.
+
+ Args:
+ model_keyframes: The keyframe generation model
+ audio_list: List of audio embeddings
+ gt_list: List of ground truth frames
+ masks_list: List of masks
+ condition: Conditioning tensor
+ num_frames: Number of frames to generate
+ fps_id: FPS ID
+ cond_aug: Conditioning augmentation factor
+ device: Device to use for computation
+ embbedings: Optional embeddings
+ force_uc_zero_embeddings: List of embeddings to force to zero in unconditional case
+ n_batch_keyframes: Batch size for keyframe generation
+ added_frames: Number of additional frames
+ strength: Strength parameter for sampling
+ scale: Scale parameter for guidance
+ gt_as_cond: Whether to use ground truth as conditioning
+
+ Returns:
+ Tuple of (latent samples, decoded samples)
+ """
+ if scale is not None:
+ model_keyframes.sampler.guider.set_scale(scale)
+ # samples_list = []
+ samples_z_list = []
+ # samples_x_list = []
+
+ for i in range(audio_list.shape[0]):
+ H, W = condition.shape[-2:]
+ assert condition.shape[1] == 3
+ F = 8
+ C = 4
+ shape = (num_frames, C, H // F, W // F)
+
+ audio_cond = audio_list[i].unsqueeze(0)
+
+ value_dict: Dict[str, Any] = {}
+ value_dict["fps_id"] = fps_id
+ value_dict["cond_aug"] = cond_aug
+ value_dict["cond_frames_without_noise"] = condition
+ if embbedings is not None:
+ value_dict["cond_frames"] = embbedings + cond_aug * torch.randn_like(
+ embbedings
+ )
+ else:
+ value_dict["cond_frames"] = condition + cond_aug * torch.randn_like(
+ condition
+ )
+ gt = rearrange(gt_list[i].unsqueeze(0), "b t c h w -> b c t h w").to(device)
+
+ if gt_as_cond:
+ value_dict["cond_frames"] = gt[:, :, 0]
+
+ value_dict["cond_aug"] = cond_aug
+ value_dict["audio_emb"] = audio_cond
+
+ value_dict["gt"] = gt
+ value_dict["masks"] = masks_list[i].unsqueeze(0).transpose(1, 2).to(device)
+
+ with torch.no_grad():
+ batch, batch_uc = get_batch(
+ get_unique_embedder_keys_from_conditioner(model_keyframes.conditioner),
+ value_dict,
+ [1, 1],
+ T=num_frames,
+ device=device,
+ )
+
+ c, uc = model_keyframes.conditioner.get_unconditional_conditioning(
+ batch,
+ batch_uc=batch_uc,
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
+ )
+
+ for k in ["crossattn"]:
+ if c[k].shape[1] != num_frames:
+ uc[k] = repeat(
+ uc[k],
+ "b ... -> b t ...",
+ t=num_frames,
+ )
+ uc[k] = rearrange(
+ uc[k],
+ "b t ... -> (b t) ...",
+ t=num_frames,
+ )
+ c[k] = repeat(
+ c[k],
+ "b ... -> b t ...",
+ t=num_frames,
+ )
+ c[k] = rearrange(
+ c[k],
+ "b t ... -> (b t) ...",
+ t=num_frames,
+ )
+
+ video = torch.randn(shape, device=device)
+
+ additional_model_inputs: Dict[str, torch.Tensor] = {}
+ additional_model_inputs["image_only_indicator"] = torch.zeros(
+ n_batch_keyframes, num_frames
+ ).to(device)
+ additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
+
+ def denoiser(
+ input: torch.Tensor, sigma: torch.Tensor, c: Dict[str, torch.Tensor]
+ ) -> torch.Tensor:
+ return model_keyframes.denoiser(
+ model_keyframes.model,
+ input,
+ sigma,
+ c,
+ **additional_model_inputs,
+ )
+
+ samples_z = model_keyframes.sampler(
+ denoiser, video, cond=c, uc=uc, strength=strength
+ )
+ samples_z_list.append(samples_z)
+ # samples_x_list.append(samples_x)
+ # samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+ # samples_list.append(samples)
+
+ video = None
+
+ # samples = (
+ # torch.concat(samples_list)[:-added_frames]
+ # if added_frames > 0
+ # else torch.concat(samples_list)
+ # )
+ samples_z = (
+ torch.concat(samples_z_list)[:-added_frames]
+ if added_frames > 0
+ else torch.concat(samples_z_list)
+ )
+ # samples_x = (
+ # torch.concat(samples_x_list)[:-added_frames]
+ # if added_frames > 0
+ # else torch.concat(samples_x_list)
+ # )
+
+ return samples_z
+
+
+@torch.inference_mode()
+def sample_interpolation(
+ model: Any,
+ samples_z: torch.Tensor,
+ # samples_x: torch.Tensor,
+ audio_interpolation_list: List[torch.Tensor],
+ gt_chunks: List[torch.Tensor],
+ masks_chunks: List[torch.Tensor],
+ condition: torch.Tensor,
+ num_frames: int,
+ device: str,
+ overlap: int,
+ fps_id: int,
+ cond_aug: float,
+ force_uc_zero_embeddings: List[str],
+ n_batch: int,
+ chunk_size: Optional[int],
+ strength: float,
+ scale: Optional[float] = None,
+ cut_audio: bool = False,
+ to_remove: List[bool] = [],
+) -> np.ndarray:
+ """
+ Sample interpolation frames between keyframes.
+
+ Args:
+ model: The interpolation model
+ samples_z: Latent samples from keyframe generation
+ samples_x: Decoded samples from keyframe generation
+ audio_interpolation_list: List of audio embeddings for interpolation
+ gt_chunks: Ground truth video chunks
+ masks_chunks: Mask chunks for conditional generation
+ condition: Visual conditioning
+ num_frames: Number of frames to generate
+ device: Device to run inference on
+ overlap: Number of frames to overlap between segments
+ fps_id: FPS ID for conditioning
+ motion_bucket_id: Motion bucket ID for conditioning
+ cond_aug: Conditioning augmentation strength
+ force_uc_zero_embeddings: Keys to zero out in unconditional embeddings
+ n_batch: Batch size for generation
+ chunk_size: Size of chunks for processing (to manage memory)
+ strength: Strength of the conditioning
+ scale: Optional scale for classifier-free guidance
+ cut_audio: Whether to cut audio embeddings
+ to_remove: List of flags indicating which frames to remove
+
+ Returns:
+ Generated video frames as numpy array
+ """
+ if scale is not None:
+ model.sampler.guider.set_scale(scale)
+
+ # Creating condition for interpolation model. We need to create a list of inputs, each input is [first, last]
+ # The first and last are the first and last frames of the interpolation
+ # interpolation_cond_list = []
+ interpolation_cond_list_emb = []
+
+ # samples_x = [sample for i, sample in zip(to_remove, samples_x) if not i]
+ samples_z = [sample for i, sample in zip(to_remove, samples_z) if not i]
+
+ for i in range(0, len(samples_z) - 1, overlap if overlap > 0 else 2):
+ # interpolation_cond_list.append(
+ # torch.stack([samples_x[i], samples_x[i + 1]], dim=1)
+ # )
+ interpolation_cond_list_emb.append(
+ torch.stack([samples_z[i], samples_z[i + 1]], dim=1)
+ )
+
+ # condition = torch.stack(interpolation_cond_list).to(device)
+ audio_cond = torch.stack(audio_interpolation_list).to(device)
+ embbedings = torch.stack(interpolation_cond_list_emb).to(device)
+
+ gt_chunks = torch.stack(gt_chunks).to(device)
+ masks_chunks = torch.stack(masks_chunks).to(device)
+
+ H, W = 512, 512
+ F = 8
+ C = 4
+ shape = (num_frames * audio_cond.shape[0], C, H // F, W // F)
+
+ value_dict: Dict[str, Any] = {}
+ value_dict["fps_id"] = fps_id
+ value_dict["cond_aug"] = cond_aug
+ # value_dict["cond_frames_without_noise"] = condition
+
+ value_dict["cond_frames"] = embbedings
+ value_dict["cond_aug"] = cond_aug
+ if cut_audio:
+ value_dict["audio_emb"] = audio_cond[:, :, :, :768]
+ else:
+ value_dict["audio_emb"] = audio_cond
+
+ value_dict["gt"] = rearrange(gt_chunks, "b t c h w -> b c t h w").to(device)
+ value_dict["masks"] = masks_chunks.transpose(1, 2).to(device)
+
+ with torch.no_grad():
+ with torch.autocast(device):
+ batch, batch_uc = get_batch_overlap(
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
+ value_dict,
+ [1, num_frames],
+ T=num_frames,
+ device=device,
+ )
+
+ c, uc = model.conditioner.get_unconditional_conditioning(
+ batch,
+ batch_uc=batch_uc,
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
+ )
+
+ for k in ["crossattn"]:
+ if c[k].shape[1] != num_frames:
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
+ c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
+
+ video = torch.randn(shape, device=device)
+
+ additional_model_inputs: Dict[str, torch.Tensor] = {}
+ additional_model_inputs["image_only_indicator"] = torch.zeros(
+ n_batch, num_frames
+ ).to(device)
+ additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
+
+ # Debug information
+ print(
+ f"Shapes - Embeddings: {embbedings.shape}, "
+ f"Audio: {audio_cond.shape}, Video: {shape}, Additional inputs: {additional_model_inputs}"
+ )
+
+ if chunk_size is not None:
+ chunk_size = chunk_size * num_frames
+
+ def denoiser(
+ input: torch.Tensor, sigma: torch.Tensor, c: Dict[str, torch.Tensor]
+ ) -> torch.Tensor:
+ return model.denoiser(
+ model.model,
+ input,
+ sigma,
+ c,
+ num_overlap_frames=overlap,
+ num_frames=num_frames,
+ n_skips=n_batch,
+ chunk_size=chunk_size,
+ **additional_model_inputs,
+ )
+
+ samples_z = model.sampler(denoiser, video, cond=c, uc=uc, strength=strength)
+ samples_z = rearrange(samples_z, "(b t) c h w -> b t c h w", t=num_frames)
+ samples_z[:, 0] = embbedings[:, :, 0]
+ samples_z[:, -1] = embbedings[:, :, 1]
+ samples_z = rearrange(samples_z, "b t c h w -> (b t) c h w")
+
+ samples_x = model.decode_first_stage(samples_z)
+
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+ # Free up memory
+ video = None
+
+ samples = rearrange(samples, "(b t) c h w -> b t c h w", t=num_frames)
+ samples = merge_overlapping_segments(samples, overlap)
+
+ vid = (
+ (rearrange(samples, "t c h w -> t c h w") * 255).cpu().numpy().astype(np.uint8)
+ )
+
+ return vid
diff --git a/landmarks_extractor.py b/landmarks_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c37aa30b0a7b7f01cadd80eed5ce45bc9539fa64
--- /dev/null
+++ b/landmarks_extractor.py
@@ -0,0 +1,35 @@
+from skimage import io
+import face_alignment
+
+
+class LandmarksExtractor:
+ def __init__(self, device="cuda", landmarks_type="2D", flip=False):
+ self.fa = face_alignment.FaceAlignment(
+ face_alignment.LandmarksType.TWO_D
+ if landmarks_type == "2D"
+ else face_alignment.LandmarksType.THREE_D,
+ flip_input=flip,
+ device=device,
+ face_detector="sfd",
+ )
+
+ self.landmarks = []
+
+ def cuda(self):
+ return self
+
+ def extract_landmarks(self, image):
+ # image: either a path to an image or a numpy array (H, W, C) or tensor batch (B, C, H, W)
+ if isinstance(image, str):
+ image = io.imread(image)
+
+ # Ensure image is on CPU
+ if hasattr(image, "device"):
+ image = image.cpu()
+
+ if len(image.shape) == 3:
+ preds = self.fa.get_landmarks(image)
+ else:
+ preds = self.fa.get_landmarks_from_batch(image)
+
+ return preds
diff --git a/sgm/__init__.py b/sgm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..24bc84af8b1041de34b9816e0507cb1ac207bd13
--- /dev/null
+++ b/sgm/__init__.py
@@ -0,0 +1,4 @@
+from .models import AutoencodingEngine, DiffusionEngine
+from .util import get_configs_path, instantiate_from_config
+
+__version__ = "0.1.0"
diff --git a/sgm/__pycache__/__init__.cpython-311.pyc b/sgm/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9fcbff465ee3c88089f2dd8c569b2e15a702a37c
Binary files /dev/null and b/sgm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/__pycache__/lr_scheduler.cpython-311.pyc b/sgm/__pycache__/lr_scheduler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..adc5e9ad10a3e02b2734ff87cd27282ddb1d355a
Binary files /dev/null and b/sgm/__pycache__/lr_scheduler.cpython-311.pyc differ
diff --git a/sgm/__pycache__/util.cpython-311.pyc b/sgm/__pycache__/util.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..411006a5c815e10450718fba34b7605c0c78b629
Binary files /dev/null and b/sgm/__pycache__/util.cpython-311.pyc differ
diff --git a/sgm/callbacks/__pycache__/video_logger.cpython-311.pyc b/sgm/callbacks/__pycache__/video_logger.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a183ae7a00c1cbbdf99d5ede1c153d14ebbfb668
Binary files /dev/null and b/sgm/callbacks/__pycache__/video_logger.cpython-311.pyc differ
diff --git a/sgm/callbacks/custom_ddp.py b/sgm/callbacks/custom_ddp.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3c6ee5db213f24cf549fded38e3774fafa042db
--- /dev/null
+++ b/sgm/callbacks/custom_ddp.py
@@ -0,0 +1,10 @@
+# from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.strategies import DDPStrategy
+
+
+class CustomDDPPlugin(DDPStrategy):
+ def configure_ddp(self):
+ # self.pre_configure_ddp()
+ self._model = self._setup_model((self.model))
+ self._register_ddp_hooks()
+ self._model._set_static_graph() # THIS IS THE MAGIC LINE
diff --git a/sgm/callbacks/image_logger.py b/sgm/callbacks/image_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b3a467855795383ebf554a3f91883644b870e36
--- /dev/null
+++ b/sgm/callbacks/image_logger.py
@@ -0,0 +1,193 @@
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.loggers import WandbLogger
+import numpy as np
+from pytorch_lightning.utilities import rank_zero_only
+from typing import Union
+import pytorch_lightning as pl
+import os
+from matplotlib import pyplot as plt
+from sgm.util import exists, isheatmap
+import torchvision
+from PIL import Image
+import torch
+import wandb
+from einops import rearrange
+
+
+class ImageLogger(Callback):
+ def __init__(
+ self,
+ batch_frequency,
+ max_images,
+ clamp=True,
+ increase_log_steps=True,
+ rescale=True,
+ disabled=False,
+ log_on_batch_idx=False,
+ log_first_step=False,
+ log_images_kwargs=None,
+ log_before_first_step=False,
+ enable_autocast=True,
+ ):
+ super().__init__()
+ self.enable_autocast = enable_autocast
+ self.rescale = rescale
+ self.batch_freq = batch_frequency
+ self.max_images = max_images
+ self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
+ if not increase_log_steps:
+ self.log_steps = [self.batch_freq]
+ self.clamp = clamp
+ self.disabled = disabled
+ self.log_on_batch_idx = log_on_batch_idx
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
+ self.log_first_step = log_first_step
+ self.log_before_first_step = log_before_first_step
+
+ @rank_zero_only
+ def log_local(
+ self,
+ save_dir,
+ split,
+ images,
+ global_step,
+ current_epoch,
+ batch_idx,
+ pl_module: Union[None, pl.LightningModule] = None,
+ ):
+ root = os.path.join(save_dir, "images", split)
+ for k in images:
+ if isheatmap(images[k]):
+ fig, ax = plt.subplots()
+ ax = ax.matshow(images[k].cpu().numpy(), cmap="hot", interpolation="lanczos")
+ plt.colorbar(ax)
+ plt.axis("off")
+
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
+ os.makedirs(root, exist_ok=True)
+ path = os.path.join(root, filename)
+ plt.savefig(path)
+ plt.close()
+ # TODO: support wandb
+ else:
+ grid = torchvision.utils.make_grid(images[k].squeeze(2), nrow=4)
+ if self.rescale:
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
+ # print(grid.shape, grid.dtype, grid.min(), grid.max(), k)
+ grid = rearrange(grid.squeeze(1), "c h w -> h w c")
+ # grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ grid = grid.numpy()
+ grid = (grid * 255).astype(np.uint8)
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
+ path = os.path.join(root, filename)
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
+ img = Image.fromarray(grid)
+ img.save(path)
+ if exists(pl_module):
+ assert isinstance(
+ pl_module.logger, WandbLogger
+ ), "logger_log_image only supports WandbLogger currently"
+ pl_module.logger.log_image(
+ key=f"{split}/{k}",
+ images=[
+ img,
+ ],
+ step=pl_module.global_step,
+ )
+
+ @rank_zero_only
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
+ check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
+ if (
+ self.check_frequency(check_idx)
+ and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
+ and callable(pl_module.log_images)
+ and
+ # batch_idx > 5 and
+ self.max_images > 0
+ ):
+ logger = type(pl_module.logger)
+ is_train = pl_module.training
+ if is_train:
+ pl_module.eval()
+
+ gpu_autocast_kwargs = {
+ "enabled": self.enable_autocast, # torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
+ images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
+
+ for k in images:
+ N = min(images[k].shape[0], self.max_images)
+ if not isheatmap(images[k]):
+ images[k] = images[k][:N]
+ if isinstance(images[k], torch.Tensor):
+ images[k] = images[k].detach().float().cpu()
+ if self.clamp and not isheatmap(images[k]):
+ images[k] = torch.clamp(images[k], -1.0, 1.0)
+
+ self.log_local(
+ pl_module.logger.save_dir,
+ split,
+ images,
+ pl_module.global_step,
+ pl_module.current_epoch,
+ batch_idx,
+ pl_module=pl_module if isinstance(pl_module.logger, WandbLogger) else None,
+ )
+
+ if is_train:
+ pl_module.train()
+
+ def check_frequency(self, check_idx):
+ if check_idx:
+ check_idx -= 1
+ if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
+ check_idx > 0 or self.log_first_step
+ ):
+ try:
+ self.log_steps.pop(0)
+ except IndexError as e:
+ print(e)
+ pass
+ return True
+ return False
+
+ @rank_zero_only
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+ if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
+ self.log_img(pl_module, batch, batch_idx, split="train")
+
+ @rank_zero_only
+ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+ if self.log_before_first_step and pl_module.global_step == 0:
+ print(f"{self.__class__.__name__}: logging before training")
+ self.log_img(pl_module, batch, batch_idx, split="train")
+
+ @rank_zero_only
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs):
+ if not self.disabled and pl_module.global_step > 0:
+ self.log_img(pl_module, batch, batch_idx, split="val")
+ if hasattr(pl_module, "calibrate_grad_norm"):
+ if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
+ self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
+
+
+@rank_zero_only
+def init_wandb(save_dir, opt, config, group_name, name_str):
+ print(f"setting WANDB_DIR to {save_dir}")
+ os.makedirs(save_dir, exist_ok=True)
+
+ os.environ["WANDB_DIR"] = save_dir
+ if opt.debug:
+ wandb.init(project=opt.projectname, mode="offline", group=group_name)
+ else:
+ wandb.init(
+ project=opt.projectname,
+ config=config,
+ settings=wandb.Settings(code_dir="./sgm"),
+ group=group_name,
+ name=name_str,
+ )
diff --git a/sgm/callbacks/setup_callback.py b/sgm/callbacks/setup_callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ca75338db67110f7771ba3349db4d423375dff8
--- /dev/null
+++ b/sgm/callbacks/setup_callback.py
@@ -0,0 +1,86 @@
+from pytorch_lightning.callbacks import Callback
+import pytorch_lightning as pl
+import os
+from omegaconf import OmegaConf
+from pytorch_lightning.utilities import rank_zero_only
+
+MULTINODE_HACKS = True
+
+
+class SetupCallback(Callback):
+ def __init__(
+ self,
+ resume,
+ now,
+ logdir,
+ ckptdir,
+ cfgdir,
+ config,
+ lightning_config,
+ debug,
+ ckpt_name=None,
+ ):
+ super().__init__()
+ self.resume = resume
+ self.now = now
+ self.logdir = logdir
+ self.ckptdir = ckptdir
+ self.cfgdir = cfgdir
+ self.config = config
+ self.lightning_config = lightning_config
+ self.debug = debug
+ self.ckpt_name = ckpt_name
+
+ @rank_zero_only
+ def on_exception(self, trainer: pl.Trainer, pl_module, exception):
+ print("Exception occurred: {}".format(exception))
+ if not self.debug and trainer.global_rank == 0:
+ print("Summoning checkpoint.")
+ if self.ckpt_name is None:
+ ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
+ else:
+ ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)
+ trainer.save_checkpoint(ckpt_path)
+
+ @rank_zero_only
+ def on_fit_start(self, trainer, pl_module):
+ if trainer.global_rank == 0:
+ # Create logdirs and save configs
+ os.makedirs(self.logdir, exist_ok=True)
+ os.makedirs(self.ckptdir, exist_ok=True)
+ os.makedirs(self.cfgdir, exist_ok=True)
+
+ if "callbacks" in self.lightning_config:
+ if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]:
+ os.makedirs(
+ os.path.join(self.ckptdir, "trainstep_checkpoints"),
+ exist_ok=True,
+ )
+ print("Project config")
+ print(OmegaConf.to_yaml(self.config))
+ if MULTINODE_HACKS:
+ import time
+
+ time.sleep(5)
+ OmegaConf.save(
+ self.config,
+ os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
+ )
+
+ print("Lightning config")
+ print(OmegaConf.to_yaml(self.lightning_config))
+ OmegaConf.save(
+ OmegaConf.create({"lightning": self.lightning_config}),
+ os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
+ )
+
+ else:
+ # ModelCheckpoint callback created log directory --- remove it
+ if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
+ dst, name = os.path.split(self.logdir)
+ dst = os.path.join(dst, "child_runs", name)
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
+ try:
+ os.rename(self.logdir, dst)
+ except FileNotFoundError:
+ pass
diff --git a/sgm/callbacks/video_logger.py b/sgm/callbacks/video_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ac45d85d296d8f73f29e61e5ce04341235c3e6e
--- /dev/null
+++ b/sgm/callbacks/video_logger.py
@@ -0,0 +1,294 @@
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.loggers import WandbLogger
+import numpy as np
+from pytorch_lightning.utilities import rank_zero_only
+from typing import Union
+import pytorch_lightning as pl
+import os
+from sgm.util import exists, suppress_output, default
+import torchvision
+from PIL import Image
+import torch
+import wandb
+import moviepy.editor as mpy
+from einops import rearrange
+import torchaudio
+# import tempfile
+# import cv2
+# import scipy.io.wavfile as wav
+# import ffmpeg
+
+
+@suppress_output
+def save_audio_video(
+ video, audio=None, frame_rate=25, sample_rate=16000, save_path="temp.mp4", keep_intermediate=False
+):
+ """Save audio and video to a single file.
+ video: (t, c, h, w)
+ audio: (channels t)
+ """
+
+ # temp_filename = next(tempfile._get_candidate_names())
+ # if save_path:
+ # save_path = save_path
+ # else:
+ # save_path = "/tmp/" + next(tempfile._get_candidate_names()) + ".mp4"
+ save_path = str(save_path)
+ try:
+ torchvision.io.write_video(
+ "temp_video.mp4", rearrange(video.detach().cpu(), "t c h w -> t h w c").to(torch.uint8), frame_rate
+ )
+ video_clip = mpy.VideoFileClip("temp_video.mp4")
+ if audio is not None:
+ torchaudio.save("temp_audio.wav", audio.detach().cpu(), sample_rate)
+ audio_clip = mpy.AudioFileClip("temp_audio.wav")
+ video_clip = video_clip.set_audio(audio_clip)
+ video_clip.write_videofile(save_path, fps=frame_rate, codec="libx264", audio_codec="aac", verbose=False)
+ if not keep_intermediate:
+ os.remove("temp_video.mp4")
+ if audio is not None:
+ os.remove("temp_audio.wav")
+ return 1
+ except Exception as e:
+ print(e)
+ print("Saving video to file failed")
+ return 0
+
+
+# def write_video_opencv(video, video_rate, video_path):
+# fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+# out = cv2.VideoWriter(video_path, fourcc, video_rate, (video.shape[2], video.shape[3]), 0)
+# for frame in list(video):
+# frame = np.squeeze(frame)
+# out.write(np.squeeze(frame))
+# out.release()
+
+
+# # Code mostly inherited from bulletin
+# def save_av_sample(video, video_rate, audio=None, audio_rate=16_000, path=None):
+# # Save video sample in train dir for debugging
+# # video_save = 0.5 * video.detach().cpu().numpy() + 0.5
+# video_save = rearrange(video, "t c h w -> t h w c").detach().cpu().numpy()
+# temp_filename = next(tempfile._get_candidate_names())
+# if path:
+# video_path = path
+# else:
+# video_path = "/tmp/" + next(tempfile._get_candidate_names()) + ".mp4"
+# write_video_opencv((video_save).astype(np.uint8), video_rate, "/tmp/" + temp_filename + ".mp4")
+# audio_save = audio.detach().squeeze().cpu().numpy()
+# wav.write("/tmp/" + temp_filename + ".wav", audio_rate, audio_save)
+# try:
+# in1 = ffmpeg.input("/tmp/" + temp_filename + ".mp4")
+# in2 = ffmpeg.input("/tmp/" + temp_filename + ".wav")
+# out = ffmpeg.output(in1["v"], in2["a"], video_path, loglevel="panic").overwrite_output()
+# out.run(capture_stdout=True, capture_stderr=True)
+# except ffmpeg.Error as e:
+# print("stdout:", e.stdout.decode("utf8"))
+# print("stderr:", e.stderr.decode("utf8"))
+# raise e
+# return video_path
+
+
+class VideoLogger(Callback):
+ def __init__(
+ self,
+ batch_frequency,
+ max_videos,
+ clamp=True,
+ increase_log_steps=True,
+ rescale=True,
+ disabled=False,
+ log_on_batch_idx=False,
+ log_first_step=False,
+ log_videos_kwargs=None,
+ log_before_first_step=False,
+ enable_autocast=True,
+ batch_frequency_val=None,
+ ):
+ super().__init__()
+ self.enable_autocast = enable_autocast
+ self.rescale = rescale
+ self.batch_freq = batch_frequency
+ self.max_videos = max_videos
+ self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
+ if not increase_log_steps:
+ self.log_steps = [self.batch_freq]
+ self.batch_freq_val = default(batch_frequency_val, self.batch_freq)
+ self.log_steps_val = [2**n for n in range(int(np.log2(self.batch_freq_val)) + 1)]
+ if not increase_log_steps:
+ self.log_steps_val = [self.batch_freq_val]
+ self.clamp = clamp
+ self.disabled = disabled
+ self.log_on_batch_idx = log_on_batch_idx
+ self.log_videos_kwargs = log_videos_kwargs if log_videos_kwargs else {}
+ self.log_first_step = log_first_step
+ self.log_before_first_step = log_before_first_step
+
+ @rank_zero_only
+ def log_local(
+ self,
+ save_dir,
+ split,
+ log_elements,
+ raw_audio,
+ global_step,
+ current_epoch,
+ batch_idx,
+ pl_module: Union[None, pl.LightningModule] = None,
+ ):
+ root = os.path.join(save_dir, "videos", split)
+ for k in log_elements:
+ element = log_elements[k]
+ if len(element.shape) == 4:
+ grid = torchvision.utils.make_grid(element, nrow=4)
+ if self.rescale:
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ grid = grid.numpy()
+ grid = (grid * 255).astype(np.uint8)
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
+ path = os.path.join(root, filename)
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
+ img = Image.fromarray(grid)
+ img.save(path)
+ if exists(pl_module):
+ assert isinstance(
+ pl_module.logger, WandbLogger
+ ), "logger_log_image only supports WandbLogger currently"
+ pl_module.logger.log_image(
+ key=f"{split}/{k}",
+ images=[
+ img,
+ ],
+ step=pl_module.global_step,
+ )
+ elif len(element.shape) == 5:
+ video = element
+ if self.rescale:
+ video = (video + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
+ video = video * 255.0
+ video = video.permute(0, 2, 1, 3, 4).cpu().detach().to(torch.uint8) # b,t,c,h,w
+ for i in range(video.shape[0]):
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}_{}.mp4".format(k, global_step, current_epoch, batch_idx, i)
+ path = os.path.join(root, filename)
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
+ log_audio = raw_audio[i] if raw_audio is not None else None
+ success = save_audio_video(
+ video[i],
+ audio=log_audio.unsqueeze(0) if log_audio is not None else None,
+ frame_rate=25,
+ sample_rate=16000,
+ save_path=path,
+ keep_intermediate=False,
+ )
+
+ # video_path = save_av_sample(video[i], 25, audio=raw_audio, audio_rate=16000, path=None)
+ if exists(pl_module):
+ assert isinstance(
+ pl_module.logger, WandbLogger
+ ), "logger_log_image only supports WandbLogger currently"
+ pl_module.logger.experiment.log(
+ {
+ f"{split}/{k}": wandb.Video(
+ path if success else video,
+ # caption=f"diffused videos w {n_frames} frames (condition left, generated right)",
+ fps=25,
+ format="mp4",
+ )
+ },
+ )
+
+ @rank_zero_only
+ def log_video(self, pl_module, batch, batch_idx, split="train"):
+ check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
+ # print(f"check_idx: {check_idx}", f"split: {split}")
+ if (
+ self.check_frequency(check_idx, split=split)
+ and hasattr(pl_module, "log_videos") # batch_idx % self.batch_freq == 0
+ and callable(pl_module.log_videos)
+ and
+ # batch_idx > 5 and
+ self.max_videos > 0
+ ):
+ logger = type(pl_module.logger)
+ is_train = pl_module.training
+ if is_train:
+ pl_module.eval()
+
+ gpu_autocast_kwargs = {
+ "enabled": self.enable_autocast, # torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
+ videos = pl_module.log_videos(batch, split=split, **self.log_videos_kwargs)
+
+ for k in videos:
+ N = min(videos[k].shape[0], self.max_videos)
+ videos[k] = videos[k][:N]
+ if isinstance(videos[k], torch.Tensor):
+ videos[k] = videos[k].detach().float().cpu()
+ if self.clamp:
+ videos[k] = torch.clamp(videos[k], -1.0, 1.0)
+
+ raw_audio = batch.get("raw_audio", None)
+
+ self.log_local(
+ pl_module.logger.save_dir,
+ split,
+ videos,
+ raw_audio,
+ pl_module.global_step,
+ pl_module.current_epoch,
+ batch_idx,
+ pl_module=pl_module if isinstance(pl_module.logger, WandbLogger) else None,
+ )
+
+ if is_train:
+ pl_module.train()
+
+ def check_frequency(self, check_idx, split="train"):
+ if split == "val":
+ if check_idx:
+ check_idx -= 1
+ if ((check_idx % self.batch_freq_val) == 0 or (check_idx in self.log_steps_val)) and (
+ check_idx > 0 or self.log_first_step
+ ):
+ try:
+ self.log_steps_val.pop(0)
+ except IndexError as e:
+ print(e)
+ pass
+ return True
+ return False
+ if check_idx:
+ check_idx -= 1
+ if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
+ check_idx > 0 or self.log_first_step
+ ):
+ try:
+ self.log_steps.pop(0)
+ except IndexError as e:
+ print(e)
+ pass
+ return True
+ return False
+
+ @rank_zero_only
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+ if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
+ self.log_video(pl_module, batch, batch_idx, split="train")
+
+ @rank_zero_only
+ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+ if self.log_before_first_step and pl_module.global_step == 0:
+ print(f"{self.__class__.__name__}: logging before training")
+ self.log_video(pl_module, batch, batch_idx, split="train")
+
+ @rank_zero_only
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs):
+ if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
+ self.log_video(pl_module, batch, batch_idx, split="val")
+ if hasattr(pl_module, "calibrate_grad_norm"):
+ if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
+ self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
diff --git a/sgm/data/__init__.py b/sgm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3076683da945963a7b12d4444a382fb43ddab58
--- /dev/null
+++ b/sgm/data/__init__.py
@@ -0,0 +1 @@
+# from .dataset import StableDataModuleFromConfig
diff --git a/sgm/data/__pycache__/__init__.cpython-311.pyc b/sgm/data/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..10d1f6bb6c7c5c6b8ecb27300ed74697df0557a0
Binary files /dev/null and b/sgm/data/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/data/__pycache__/data_utils.cpython-311.pyc b/sgm/data/__pycache__/data_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b10618770a486e3a9cb7150f069226e15a67e0d
Binary files /dev/null and b/sgm/data/__pycache__/data_utils.cpython-311.pyc differ
diff --git a/sgm/data/__pycache__/mask.cpython-311.pyc b/sgm/data/__pycache__/mask.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3553f277ba58e5668a8e999e4fcef6130ed46c1b
Binary files /dev/null and b/sgm/data/__pycache__/mask.cpython-311.pyc differ
diff --git a/sgm/data/__pycache__/video_datamodule_latent.cpython-311.pyc b/sgm/data/__pycache__/video_datamodule_latent.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ff4dccc9248f32b81c3e47c503d6963179adaaf6
Binary files /dev/null and b/sgm/data/__pycache__/video_datamodule_latent.cpython-311.pyc differ
diff --git a/sgm/data/__pycache__/video_dataset_latent.cpython-311.pyc b/sgm/data/__pycache__/video_dataset_latent.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5086b19282bf6b23ae44dbd9aa075b957e0df7f5
Binary files /dev/null and b/sgm/data/__pycache__/video_dataset_latent.cpython-311.pyc differ
diff --git a/sgm/data/data_utils.py b/sgm/data/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e1103d56f9b9c31b860f4c8b18bcdce5f7bd0be
--- /dev/null
+++ b/sgm/data/data_utils.py
@@ -0,0 +1,561 @@
+import torch
+import numpy as np
+from PIL import Image, ImageDraw
+import cv2
+from functools import partial
+import math
+
+
+def get_size(img):
+ if isinstance(img, (np.ndarray, torch.Tensor)):
+ return img.shape[1::-1]
+ else:
+ return img.size
+
+
+def imresample(img, sz):
+ im_data = torch.nn.functional.interpolate(img, size=sz, mode="area")
+ return im_data
+
+
+def crop_resize(img, box, image_size):
+ if isinstance(img, np.ndarray):
+ img = img[box[1] : box[3], box[0] : box[2]]
+ out = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_AREA).copy()
+ elif isinstance(img, torch.Tensor):
+ img = img[box[1] : box[3], box[0] : box[2]]
+ out = (
+ imresample(img.permute(2, 0, 1).unsqueeze(0).float(), (image_size, image_size))
+ .byte()
+ .squeeze(0)
+ .permute(1, 2, 0)
+ )
+ else:
+ out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR)
+ return out
+
+
+def fixed_image_standardization(image_tensor):
+ processed_tensor = (image_tensor - 127.5) / 128.0
+ return processed_tensor
+
+
+def extract_face(img, landmarks, image_size=160, margin=0, postprocess=False):
+ """Extract face + margin from images given facial landmarks.
+
+ Arguments:
+ img {PIL.Image/torch.Tensor/np.ndarray} -- Input image(s) with shape (B, H, W, C)
+ landmarks {numpy.ndarray} -- Facial landmarks with shape (B, 68, 2)
+ image_size {int} -- Output image size in pixels. The image will be square.
+ margin {int} -- Margin to add to bounding box, in terms of pixels in the final image.
+ postprocess {bool} -- Whether to apply standardization
+
+ Returns:
+ torch.tensor -- tensor representing the extracted faces with shape (B, H, W, C)
+ """
+ # Calculate bounding boxes from landmarks for all faces in batch
+ x_min = np.min(landmarks, axis=1)[:, 0] # Shape: (B,)
+ y_min = np.min(landmarks, axis=1)[:, 1] # Shape: (B,)
+ x_max = np.max(landmarks, axis=1)[:, 0] # Shape: (B,)
+ y_max = np.max(landmarks, axis=1)[:, 1] # Shape: (B,)
+
+ # Calculate margin for top only
+ box_height = y_max - y_min
+ top_margin = margin * box_height / (image_size - margin)
+
+ # Create boxes for all faces
+ boxes = np.stack(
+ [
+ x_min,
+ np.maximum(y_min - top_margin, 0), # Only add margin to top
+ x_max,
+ y_max,
+ ],
+ axis=1,
+ ).astype(int) # Shape: (B, 4)
+
+ # Process each face in the batch
+ faces = []
+ for i in range(len(boxes)):
+ face = crop_resize(img[i], boxes[i], image_size)
+ faces.append(face)
+
+ faces = torch.stack(faces, dim=0)
+ faces = faces.float()
+
+ if postprocess:
+ faces = fixed_image_standardization(faces)
+
+ return faces
+
+
+def crop_mouth_region(images, landmarks, crop_size=96):
+ """
+ Takes a fixed-size square crop centered on the mouth region.
+
+ Parameters:
+ - images: tensor/array of shape (num_frames, height, width, channels) or (height, width, channels)
+ - landmarks: numpy array of shape (num_frames, 68, 2) or (68, 2)
+ - crop_size: size of the square crop (both height and width)
+ - padding: percentage of padding around the mouth region (0.0 to 1.0)
+
+ Returns:
+ - List of fixed-size crops or single crop if input is single image
+ """
+ # Handle single image case
+ single_image = False
+ if len(images.shape) == 3:
+ images = images[None]
+ landmarks = landmarks[None]
+ single_image = True
+
+ num_frames = len(images)
+ crops = []
+
+ # Mouth landmarks indices (48-67 for mouth region)
+ mouth_indices = range(48, 68)
+
+ for i in range(num_frames):
+ # Get mouth landmarks for current frame
+ mouth_landmarks = landmarks[i][mouth_indices]
+
+ # Find center of mouth
+ center_x = int(np.mean(mouth_landmarks[:, 0]))
+ center_y = int(np.mean(mouth_landmarks[:, 1]))
+
+ # Calculate crop boundaries
+ half_size = crop_size // 2
+ left = max(0, center_x - half_size)
+ right = min(images.shape[2], center_x + half_size)
+ top = max(0, center_y - half_size)
+ bottom = min(images.shape[1], center_y + half_size)
+
+ # Adjust if crop would go out of bounds
+ if left == 0:
+ right = crop_size
+ if right == images.shape[2]:
+ left = images.shape[2] - crop_size
+ if top == 0:
+ bottom = crop_size
+ if bottom == images.shape[1]:
+ top = images.shape[1] - crop_size
+
+ # Take the crop
+ crop = images[i, top:bottom, left:right]
+ crops.append(crop)
+
+ return crops[0] if single_image else crops
+
+
+def create_masks_from_landmarks_box(landmark_list, img_shape, nose_index=28, dtype="uint8", box_expand=0.0):
+ height, width = img_shape[:2]
+ num_frames = landmark_list.shape[0]
+
+ # Initialize the masks array
+ masks = np.zeros((num_frames, height, width), dtype=dtype)
+
+ if 0 <= box_expand < 1:
+ box_expand = int(box_expand * width)
+
+ for i in range(num_frames):
+ # Get the landmarks for the current frame
+ landmarks = landmark_list[i]
+
+ # Get the y-coordinate of the nose landmark
+ nose_point_h = landmarks[nose_index, 1]
+ cut_h = nose_point_h
+
+ # Find the leftmost and rightmost landmarks
+ far_left_index = np.argmin(landmarks[:, 0])
+ far_right_index = np.argmax(landmarks[:, 0])
+
+ # Define the points for the mask contour
+ left_up_point = np.array([landmarks[far_left_index][0], cut_h - box_expand], dtype=np.int32)
+ left_down_point = np.array([landmarks[far_left_index][0], height], dtype=np.int32)
+ right_up_point = np.array([landmarks[far_right_index][0], cut_h - box_expand], dtype=np.int32)
+ right_down_point = np.array([landmarks[far_right_index][0], height], dtype=np.int32)
+
+ # Define the contour
+ contour = np.array([[left_up_point, left_down_point, right_down_point, right_up_point]])
+
+ # Draw the contour on the mask
+ cv2.drawContours(masks[i], [contour], -1, color=(1), thickness=cv2.FILLED)
+
+ return torch.from_numpy(masks)
+
+
+def create_masks_from_landmarks_full_size(
+ landmarks_batch, image_height, image_width, start_index=48, end_index=68, offset=0, nose_index=33
+):
+ """
+ Efficiently creates a batch of masks using vectorized operations where each mask has ones from the highest
+ landmark in the specified range (adjusted by an offset) to the bottom of the image, and zeros otherwise.
+
+ Parameters:
+ - landmarks_batch (np.array): An array of shape (B, 68, 2) containing facial landmarks for multiple samples.
+ - image_height (int): The height of the image for which masks are created.
+ - image_width (int): The width of the image for which masks are created.
+ - start_index (int): The starting index of the range to check (inclusive).
+ - end_index (int): The ending index of the range to check (inclusive).
+ - offset (int): An offset to add or subtract from the y-coordinate of the highest landmark.
+
+ Returns:
+ - np.array: An array of masks of shape (B, image_height, image_width) for each batch.
+ """
+ # Extract the y-coordinates for the specified range across all batches
+ y_coords = landmarks_batch[:, nose_index : nose_index + 1, 1]
+
+ # Find the index of the minimum y-coordinate in the specified range for each batch
+ min_y_indices = np.argmin(y_coords, axis=1)
+
+ # Gather the highest landmarks' y-coordinates using the indices found
+ highest_y_coords = y_coords[np.arange(len(y_coords)), min_y_indices]
+
+ if abs(offset) < 1 and abs(offset) > 0:
+ offset = int(offset * image_height)
+
+ # Apply the offset to the highest y-coordinate
+ adjusted_y_coords = highest_y_coords + offset
+
+ # Clip the coordinates to stay within image boundaries
+ adjusted_y_coords = np.clip(adjusted_y_coords, 0, image_height - 1)
+
+ # Use broadcasting to create a mask without loops
+ # Create a range of indices from 0 to image_height - 1
+ all_indices = np.arange(image_height)
+
+ # Compare each index in 'all_indices' to each 'adjusted_y_coord' in the batch
+ # 'all_indices' has shape (image_height,), we reshape to (1, image_height) to broadcast against (B, 1)
+ mask_2d = (all_indices >= adjusted_y_coords[:, None]).astype(int)
+
+ # Extend the 2D mask to a full 3D mask of size (B, image_height, image_width)
+ full_mask = np.tile(mask_2d[:, :, np.newaxis], (1, 1, image_width))
+
+ return torch.from_numpy(full_mask)
+
+
+def expand_polygon(polygon, expand_size):
+ """
+ Expands the polygon outward by a specified number of pixels.
+
+ Parameters:
+ - polygon (list of tuples): The polygon points as (x, y).
+ - expand_size (int): The number of pixels to expand the polygon outward.
+
+ Returns:
+ - expanded_polygon (list of tuples): The expanded polygon points as (x, y).
+ """
+ if expand_size == 0:
+ return polygon
+
+ # Calculate centroid of the polygon
+ centroid_x = sum([point[0] for point in polygon]) / len(polygon)
+ centroid_y = sum([point[1] for point in polygon]) / len(polygon)
+
+ # Expand each point outward from the centroid
+ expanded_polygon = []
+ for x, y in polygon:
+ vector_x = x - centroid_x
+ vector_y = y - centroid_y
+ length = np.sqrt(vector_x**2 + vector_y**2)
+ if length == 0:
+ expanded_polygon.append((x, y))
+ else:
+ new_x = x + expand_size * (vector_x / length)
+ new_y = y + expand_size * (vector_y / length)
+ expanded_polygon.append((int(new_x), int(new_y)))
+
+ return expanded_polygon
+
+
+def create_masks_from_landmarks_mouth(landmark_list, img_shape, nose_index=33, dtype="uint8", box_expand=0.0):
+ height, width = img_shape[:2]
+ num_frames = landmark_list.shape[0]
+
+ # Initialize the masks array
+ masks = np.zeros((num_frames, height, width), dtype=dtype)
+
+ if 0 <= box_expand < 1:
+ box_expand = int(box_expand * width)
+
+ for i in range(num_frames):
+ # Get the landmarks for the current frame
+ landmarks = landmark_list[i]
+
+ # Get the y-coordinate of the nose landmark
+ nose_point_h = landmarks[nose_index, 1]
+ cut_h = nose_point_h
+
+ # Find the leftmost and rightmost landmarks
+ far_left_index = np.argmin(landmarks[:, 0])
+ far_right_index = np.argmax(landmarks[:, 0])
+
+ # Find lowest landmark y-coordinate
+ lowest_y = np.max(landmarks[:, 1])
+ # Add box_expand to the lowest point
+ lowest_y = min(height, lowest_y + box_expand)
+
+ # Define the points for the mask contour
+ left_up_point = np.array([landmarks[far_left_index][0], cut_h - box_expand], dtype=np.int32)
+ left_down_point = np.array([landmarks[far_left_index][0], lowest_y], dtype=np.int32)
+ right_up_point = np.array([landmarks[far_right_index][0], cut_h - box_expand], dtype=np.int32)
+ right_down_point = np.array([landmarks[far_right_index][0], lowest_y], dtype=np.int32)
+
+ # Define the contour
+ contour = np.array([[left_up_point, left_down_point, right_down_point, right_up_point]])
+
+ # Draw the contour on the mask
+ cv2.drawContours(masks[i], [contour], -1, color=(1), thickness=cv2.FILLED)
+
+ return torch.from_numpy(masks)
+
+
+def create_face_mask_from_landmarks(landmarks_batch, image_height, image_width, mask_expand=0):
+ """
+ Creates a batch of masks where each mask covers the face region using landmarks.
+
+ Parameters:
+ - landmarks_batch (np.array): An array of shape (B, 68, 2) containing facial landmarks for multiple samples.
+ - image_height (int): The height of the image for which masks are created.
+ - image_width (int): The width of the image for which masks are created.
+ - mask_expand (int): The number of pixels to expand the mask outward.
+
+ Returns:
+ - np.array: An array of masks of shape (B, image_height, image_width) for each batch.
+ """
+ # Initialize an array to hold all masks
+ masks = np.zeros((landmarks_batch.shape[0], image_height, image_width), dtype=np.uint8)
+
+ if abs(mask_expand) < 1 and abs(mask_expand) > 0:
+ mask_expand = int(mask_expand * image_height)
+
+ for i, landmarks in enumerate(landmarks_batch):
+ # Create a blank image for each mask
+ mask = Image.new("L", (image_width, image_height), 0)
+ draw = ImageDraw.Draw(mask)
+
+ # Extract relevant landmarks for the face
+ jawline_landmarks = landmarks[2:15] # Jawline
+ # upper_face_landmarks = landmarks[17:27] # Eyebrows and top of nose bridge
+
+ # Combine landmarks to form a polygon around the face
+ # face_polygon = np.concatenate((jawline_landmarks, upper_face_landmarks[::-1]), axis=0)
+ face_polygon = jawline_landmarks
+
+ # Convert landmarks to a list of tuples
+ face_polygon = [(int(x), int(y)) for x, y in face_polygon]
+
+ # Expand the polygon if necessary
+ expanded_polygon = expand_polygon(face_polygon, mask_expand)
+
+ # Draw the polygon and fill it
+ draw.polygon(expanded_polygon, outline=1, fill=1)
+
+ # Convert mask to numpy array and add it to the batch of masks
+ masks[i] = np.array(mask)
+
+ return torch.from_numpy(masks)
+
+
+ALL_FIXED_POINTS = (
+ [i for i in range(0, 4)] + [i for i in range(13, 17)] + [i for i in range(27, 36)] + [36, 39, 42, 45]
+)
+
+
+def gaussian_kernel(sigma, width, height):
+ """Create a 2D Gaussian kernel."""
+ x = torch.arange(0, width, 1) - width // 2
+ y = torch.arange(0, height, 1) - height // 2
+ x = x.float()
+ y = y.float()
+ x2 = x**2
+ y2 = y[:, None] ** 2
+ g = torch.exp(-(x2 + y2) / (2 * sigma**2))
+ return g / g.sum()
+
+
+def generate_hm(landmarks, height, width, n_points="all", sigma=3):
+ if n_points == "all":
+ Nlandmarks = range(len(landmarks))
+ elif n_points == "fixed":
+ Nlandmarks = ALL_FIXED_POINTS
+ elif n_points == "stable":
+ Nlandmarks = [33, 36, 39, 42, 45]
+
+ kernel = gaussian_kernel(sigma, width, height)
+ hm = torch.zeros((height, width))
+ for I in Nlandmarks:
+ x0, y0 = landmarks[I]
+ x0, y0 = int(x0), int(y0)
+ left, right = max(0, x0 - width // 2), min(width, x0 + width // 2)
+ top, bottom = max(0, y0 - height // 2), min(height, y0 + height // 2)
+ hm[top:bottom, left:right] += kernel[
+ max(0, -y0 + height // 2) : min(height, height - y0 + height // 2),
+ max(0, -x0 + width // 2) : min(width, width - x0 + width // 2),
+ ]
+ # Normalize the heatmap to have values between 0 and 1
+ max_val = hm.max()
+ if max_val > 0:
+ hm /= max_val
+ return hm
+
+
+def get_heatmap(landmarks, image_size, or_im_size, n_points="stable", sigma=4):
+ stack = []
+ seq_length = landmarks.shape[0]
+ if or_im_size[0] != image_size[0] or or_im_size[1] != image_size[1]:
+ landmarks = scale_landmarks(landmarks, or_im_size, image_size)
+ gen_single_heatmap = partial(
+ generate_hm,
+ height=image_size[0],
+ width=image_size[1],
+ n_points=n_points,
+ sigma=sigma,
+ )
+ for i in range(seq_length):
+ stack.append(gen_single_heatmap(landmarks[i]))
+
+ return torch.stack(stack, axis=0).unsqueeze(0) # (1, seq_length, height, width)
+
+
+def scale_landmarks(landmarks, original_size, target_size):
+ """
+ Scale landmarks from original size to target size.
+
+ Parameters:
+ - landmarks (np.array): An array of shape (N, 2) containing facial landmarks.
+ - original_size (tuple): The size (height, width) for which the landmarks are currently scaled.
+ - target_size (tuple): The size (height, width) to which landmarks should be scaled.
+
+ Returns:
+ - scaled_landmarks (np.array): Scaled landmarks.
+ """
+ scale_y = target_size[0] / original_size[0]
+ scale_x = target_size[1] / original_size[1]
+ scaled_landmarks = landmarks * np.array([scale_x, scale_y])
+ return scaled_landmarks.astype(int)
+
+
+def draw_kps_image(
+ image_shape, original_size, landmarks, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255)], rgb=True, pts_width=4
+):
+ stick_width = pts_width
+ limb_seq = np.array([[0, 2], [1, 2]])
+ kps = landmarks[[36, 45, 33], :]
+ kps = scale_landmarks(kps, original_size, image_shape)
+ if not rgb: # Grayscale image
+ canvas = np.zeros((image_shape[0], image_shape[1], 1))
+ color_mode = "grayscale"
+ else: # Color image
+ canvas = np.zeros((image_shape[0], image_shape[1], 3))
+ color_mode = "color"
+
+ polygon_cache = {}
+
+ for index in limb_seq:
+ color = color_list[index[0]]
+ if color_mode == "grayscale":
+ color = (int(0.299 * color[2] + 0.587 * color[1] + 0.114 * color[0]),) # Convert to grayscale intensity
+
+ x = kps[index][:, 0]
+ y = kps[index][:, 1]
+ length = np.sqrt((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2)
+ angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
+
+ cache_key = (color, int(np.mean(x)), int(np.mean(y)), int(length / 2), int(angle))
+ if cache_key not in polygon_cache:
+ polygon_cache[cache_key] = cv2.ellipse2Poly(
+ (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stick_width), int(angle), 0, 360, 1
+ )
+
+ polygon = polygon_cache[cache_key]
+ cv2.fillConvexPoly(canvas, polygon, [int(c * 0.6) for c in color])
+
+ for idx, kp in enumerate(kps):
+ if color_mode == "grayscale":
+ color = (int(0.299 * color_list[idx][2] + 0.587 * color_list[idx][1] + 0.114 * color_list[idx][0]),)
+ else:
+ color = color_list[idx]
+ cv2.circle(canvas, (int(kp[0]), int(kp[1])), pts_width, color, -1)
+
+ return canvas.transpose(2, 0, 1)
+
+
+def create_landmarks_image(
+ landmarks, original_size=(772, 772), target_size=(772, 772), point_size=3, n_points="all", dim=3
+):
+ """
+ Creates an image of landmarks on a black background using efficient NumPy operations.
+
+ Parameters:
+ - landmarks (np.array): An array of shape (68, 2) containing facial landmarks.
+ - image_size (tuple): The size of the output image (height, width).
+ - point_size (int): The radius of each landmark point in pixels.
+
+ Returns:
+ - img (np.array): An image array with landmarks plotted.
+ """
+ if n_points == "all":
+ indexes = range(len(landmarks))
+ elif n_points == "fixed":
+ indexes = ALL_FIXED_POINTS
+ elif n_points == "stable":
+ indexes = [33, 36, 39, 42, 45]
+
+ landmarks = landmarks[indexes]
+
+ img = np.zeros(target_size, dtype=np.uint8)
+
+ landmarks = scale_landmarks(landmarks, original_size, target_size)
+
+ # Ensure the landmarks are in bounds and integer
+ landmarks = np.clip(landmarks, [0, 0], [target_size[1] - 1, target_size[0] - 1]).astype(int)
+
+ # Get x and y coordinates from landmarks
+ x, y = landmarks[:, 0], landmarks[:, 1]
+
+ # Define a grid offset based on point_size around each landmark
+ offset = np.arange(-point_size // 2, point_size // 2 + 1)
+ grid_x, grid_y = np.meshgrid(offset, offset, indexing="ij")
+
+ # Calculate the full set of x and y coordinates for the points
+ full_x = x[:, np.newaxis, np.newaxis] + grid_x[np.newaxis, :, :]
+ full_y = y[:, np.newaxis, np.newaxis] + grid_y[np.newaxis, :, :]
+
+ # Clip the coordinates to stay within image boundaries
+ full_x = np.clip(full_x, 0, target_size[1] - 1)
+ full_y = np.clip(full_y, 0, target_size[0] - 1)
+
+ # Flatten the arrays to use them as indices
+ full_x = full_x.ravel()
+ full_y = full_y.ravel()
+
+ # Set the points in the image
+ img[full_y, full_x] = 255
+
+ return np.stack([img] * dim, axis=0)
+
+
+def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
+ len_file = audio.shape[-1]
+
+ if max_len_sec or max_len_raw:
+ max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
+ if len_file < int(max_len):
+ # dummy = np.zeros((1, int(max_len_sec * sr) - len_file))
+ # extened_wav = np.concatenate((audio_data, dummy[0]))
+ extened_wav = torch.nn.functional.pad(audio, (0, int(max_len) - len_file), "constant")
+ else:
+ extened_wav = audio[:, : int(max_len)]
+ else:
+ extened_wav = audio
+
+ return extened_wav
+
+
+def ssim_to_bin(ssim_score):
+ # Normalize the SSIM score to a 0-100 scale
+ normalized_diff_ssim = (1 - ((ssim_score + 1) / 2)) * 100
+ # Assign to one of the 100 bins
+ bin_index = float(min(np.floor(normalized_diff_ssim), 99))
+ return bin_index
diff --git a/sgm/data/dataset.py b/sgm/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b726149996591c6c3db69230e1bb68c07d2faa12
--- /dev/null
+++ b/sgm/data/dataset.py
@@ -0,0 +1,80 @@
+from typing import Optional
+
+import torchdata.datapipes.iter
+import webdataset as wds
+from omegaconf import DictConfig
+from pytorch_lightning import LightningDataModule
+
+try:
+ from sdata import create_dataset, create_dummy_dataset, create_loader
+except ImportError as e:
+ print("#" * 100)
+ print("Datasets not yet available")
+ print("to enable, we need to add stable-datasets as a submodule")
+ print("please use ``git submodule update --init --recursive``")
+ print("and do ``pip install -e stable-datasets/`` from the root of this repo")
+ print("#" * 100)
+ exit(1)
+
+
+class StableDataModuleFromConfig(LightningDataModule):
+ def __init__(
+ self,
+ train: DictConfig,
+ validation: Optional[DictConfig] = None,
+ test: Optional[DictConfig] = None,
+ skip_val_loader: bool = False,
+ dummy: bool = False,
+ ):
+ super().__init__()
+ self.train_config = train
+ assert (
+ "datapipeline" in self.train_config and "loader" in self.train_config
+ ), "train config requires the fields `datapipeline` and `loader`"
+
+ self.val_config = validation
+ if not skip_val_loader:
+ if self.val_config is not None:
+ assert (
+ "datapipeline" in self.val_config and "loader" in self.val_config
+ ), "validation config requires the fields `datapipeline` and `loader`"
+ else:
+ print(
+ "Warning: No Validation datapipeline defined, using that one from training"
+ )
+ self.val_config = train
+
+ self.test_config = test
+ if self.test_config is not None:
+ assert (
+ "datapipeline" in self.test_config and "loader" in self.test_config
+ ), "test config requires the fields `datapipeline` and `loader`"
+
+ self.dummy = dummy
+ if self.dummy:
+ print("#" * 100)
+ print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+ print("#" * 100)
+
+ def setup(self, stage: str) -> None:
+ print("Preparing datasets")
+ if self.dummy:
+ data_fn = create_dummy_dataset
+ else:
+ data_fn = create_dataset
+
+ self.train_datapipeline = data_fn(**self.train_config.datapipeline)
+ if self.val_config:
+ self.val_datapipeline = data_fn(**self.val_config.datapipeline)
+ if self.test_config:
+ self.test_datapipeline = data_fn(**self.test_config.datapipeline)
+
+ def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
+ loader = create_loader(self.train_datapipeline, **self.train_config.loader)
+ return loader
+
+ def val_dataloader(self) -> wds.DataPipeline:
+ return create_loader(self.val_datapipeline, **self.val_config.loader)
+
+ def test_dataloader(self) -> wds.DataPipeline:
+ return create_loader(self.test_datapipeline, **self.test_config.loader)
diff --git a/sgm/data/mask.py b/sgm/data/mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..7196958377e34e0ad1bed21c781fefd2667e7573
--- /dev/null
+++ b/sgm/data/mask.py
@@ -0,0 +1,525 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+"""
+Functions taken from https://github.com/DanBigioi/DiffusionVideoEditing
+
+
+"""
+
+import cv2
+import numpy as np
+import torch
+
+" Countour from 2:15 not good for head poses "
+
+
+def face_mask(img_shape, landmark_list, dtype="uint8"):
+ height, width = img_shape[:2]
+ mask = np.ones((height, width, 1), dtype=dtype)
+ cv2.drawContours(
+ mask, np.int32([landmark_list[2:15]]), -1, color=(0), thickness=cv2.FILLED
+ )
+
+ return mask
+
+
+def face_mask_jaw_box(img_shape, landmark_list, dtype="uint8", kernel_size=10):
+ nose = 33
+ jaw = 8
+
+ height, width = img_shape[:2]
+ mask = np.ones((height, width, 1), dtype=dtype)
+ combined_landmarks = np.concatenate((landmark_list[2:15], [landmark_list[33]]))
+
+ # Draw the combined contour on the mask
+ cv2.drawContours(
+ mask, [np.int32(combined_landmarks)], -1, color=(0), thickness=cv2.FILLED
+ )
+
+ inverted_mask = 1 - mask
+ kernel = np.ones((kernel_size, kernel_size), np.uint8)
+ mask = cv2.dilate(inverted_mask, kernel, iterations=1)
+ mask = np.expand_dims(
+ mask, axis=-1
+ ) # Add a singleton dimension to match the number of channels
+ mask = 1 - mask
+
+ cut_h = landmark_list[nose][1]
+
+ far_left = int(np.argmin(landmark_list[:, 0]))
+ far_right = int(np.argmax(landmark_list[:, 0]))
+ left_up_point = np.int32([landmark_list[far_left][0], cut_h]) # 2
+ right_up_point = np.int32([landmark_list[far_right][0], cut_h]) # 15
+ height_landmarks = min(landmark_list[jaw, 1] + 20, height)
+ left_down_point = np.int32([landmark_list[far_left][0], height_landmarks])
+ right_down_point = np.int32([landmark_list[far_right][0], height_landmarks])
+
+ # print(cut_h, cut_h + 10, height_landmarks)
+
+ mask_box = [left_up_point, left_down_point, right_down_point, right_up_point]
+
+ return mask, mask_box
+
+
+" Stretch the tight face mask - Countour from 2:15 but dilate, not good for extreme head poses "
+
+
+def face_mask_stretch(img_shape, landmark_list, dtype="uint8", kernel_size=10):
+ height, width = img_shape[:2]
+ mask = np.ones((height, width, 1), dtype=dtype)
+ combined_landmarks = np.concatenate((landmark_list[2:15], [landmark_list[33]]))
+
+ # Draw the combined contour on the mask
+ cv2.drawContours(
+ mask, [np.int32(combined_landmarks)], -1, color=(0), thickness=cv2.FILLED
+ )
+
+ # cv2.drawContours(mask, np.int32([landmark_list[2:15]]), -1, color=(0), thickness=cv2.FILLED)
+ inverted_mask = 1 - mask
+
+ kernel = np.ones((kernel_size, kernel_size), np.uint8)
+ mask = cv2.dilate(inverted_mask, kernel, iterations=1)
+ mask = np.expand_dims(
+ mask, axis=-1
+ ) # Add a singleton dimension to match the number of channels
+ mask = 1 - mask
+
+ return mask
+
+
+" Small box around mouth - Use far left, far right points for extreme head poses, cut between nose and upper mouth point"
+
+
+def face_mask_box_pose(img_shape, landmark_list, dtype="uint8"):
+ """
+ When the head pose is different than frontal then the normal cropping with landmarks does not work correctly.
+ Crop using as height the middle nose point
+ Take the left/right corners using the far_left and far_right landmarks
+ TODO: Maybe it is better to add some more pixels to have a bigger mask, especially on large head poses
+ """
+
+ height, width = img_shape[:2]
+
+ nose = 33
+ upper_lip = 51
+ jaw = 8
+
+ nose_point_h = landmark_list[nose, 1]
+ upper_lip_point = landmark_list[upper_lip, 1]
+ cut_h = (upper_lip_point - nose_point_h) / 2 + nose_point_h
+
+ # cut_h = landmark_list[nose][1]
+
+ mask = np.ones((height, width, 1), dtype=dtype)
+
+ far_left = int(np.argmin(landmark_list[:, 0]))
+ far_right = int(np.argmax(landmark_list[:, 0]))
+
+ left_up_point = np.int32([landmark_list[far_left][0], cut_h]) # 2
+ right_up_point = np.int32([landmark_list[far_right][0], cut_h]) # 15
+
+ height_landmarks = min(landmark_list[jaw, 1] + 20, height)
+ left_down_point = np.int32([landmark_list[far_left][0], height_landmarks])
+ right_down_point = np.int32([landmark_list[far_right][0], height_landmarks])
+
+ cv2.drawContours(
+ mask,
+ np.int32(
+ [
+ [
+ left_up_point,
+ left_down_point,
+ right_up_point,
+ right_down_point,
+ left_up_point,
+ right_up_point,
+ left_down_point,
+ right_down_point,
+ ]
+ ]
+ ),
+ -1,
+ color=(0),
+ thickness=cv2.FILLED,
+ )
+
+ return mask
+
+
+" Small box around mouth - Use far left, far right points for extreme head poses, cut from nose"
+
+
+def face_mask_box_pose_nose(
+ img_shape,
+ landmark_list,
+ dtype="uint8",
+ get_box=False,
+ pixels_above_nose=None,
+ pixels_under_jaw=None,
+):
+ height, width = img_shape[:2]
+
+ nose = 33
+ jaw = 8
+
+ cut_h = landmark_list[nose][1]
+ if pixels_above_nose is not None:
+ # this is only for inference to take a bigger mask and blend it back to the original frame
+ cut_h = cut_h - pixels_above_nose
+
+ mask = np.ones((height, width, 1), dtype=dtype)
+
+ far_left = int(np.argmin(landmark_list[:, 0]))
+ far_right = int(np.argmax(landmark_list[:, 0]))
+
+ left_up_point = np.int32([landmark_list[far_left][0], cut_h]) # 2
+ right_up_point = np.int32([landmark_list[far_right][0], cut_h]) # 15
+
+ height_landmarks = min(landmark_list[jaw, 1] + 20, height)
+ if pixels_under_jaw is not None:
+ height_landmarks = min(landmark_list[jaw, 1] + pixels_under_jaw, height)
+ left_down_point = np.int32([landmark_list[far_left][0], height_landmarks])
+ right_down_point = np.int32([landmark_list[far_right][0], height_landmarks])
+
+ cv2.drawContours(
+ mask,
+ np.int32(
+ [
+ [
+ left_up_point,
+ left_down_point,
+ right_up_point,
+ right_down_point,
+ left_up_point,
+ right_up_point,
+ left_down_point,
+ right_down_point,
+ ]
+ ]
+ ),
+ -1,
+ color=(0),
+ thickness=cv2.FILLED,
+ )
+
+ if get_box:
+ mask_box = [left_up_point, left_down_point, right_down_point, right_up_point]
+ return mask, mask_box
+ else:
+ return mask
+
+
+def face_mask_box_pose_big(
+ img_shape, landmark_list, dtype="uint8", cut_h=None, far_left=None, far_right=None
+):
+ height, width = img_shape[:2]
+ mask = np.ones((height, width, 1), dtype=dtype)
+ nose = 33
+ nose_point_h = landmark_list[nose, 1]
+ if cut_h is None:
+ cut_h = nose_point_h
+
+ if far_right is None and far_left is None:
+ far_left = int(np.argmin(landmark_list[:, 0]))
+ far_right = int(np.argmax(landmark_list[:, 0]))
+
+ left_up_point = np.int32([landmark_list[far_left][0], cut_h])
+ left_down_point = np.int32([landmark_list[far_left][0], height])
+
+ right_up_point = np.int32([landmark_list[far_right][0], cut_h])
+ right_down_point = np.int32([landmark_list[far_right][0], height])
+ else:
+ left_up_point = np.int32([far_left, cut_h])
+ left_down_point = np.int32([far_left, height])
+
+ right_up_point = np.int32([far_right, cut_h])
+ right_down_point = np.int32([far_right, height])
+
+ cv2.drawContours(
+ mask,
+ np.int32(
+ [
+ [
+ left_up_point,
+ left_down_point,
+ right_up_point,
+ right_down_point,
+ left_up_point,
+ right_up_point,
+ left_down_point,
+ right_down_point,
+ ]
+ ]
+ ),
+ -1,
+ color=(0),
+ thickness=cv2.FILLED,
+ )
+
+ return mask
+
+
+def face_mask_box_pose_big_cover_nose(img_shape, landmark_list, dtype="uint8"):
+ height, width = img_shape[:2]
+
+ middle_nose_point = 29
+
+ cut_h = landmark_list[middle_nose_point, 1]
+
+ mask = np.ones((height, width, 1), dtype=dtype)
+
+ far_left = int(np.argmin(landmark_list[:, 0]))
+ far_right = int(np.argmax(landmark_list[:, 0]))
+
+ left_up_point = np.int32([landmark_list[far_left][0], cut_h])
+ left_down_point = np.int32([landmark_list[far_left][0], height])
+
+ right_up_point = np.int32([landmark_list[far_right][0], cut_h])
+ right_down_point = np.int32([landmark_list[far_right][0], height])
+
+ cv2.drawContours(
+ mask,
+ np.int32(
+ [
+ [
+ left_up_point,
+ left_down_point,
+ right_up_point,
+ right_down_point,
+ left_up_point,
+ right_up_point,
+ left_down_point,
+ right_down_point,
+ ]
+ ]
+ ),
+ -1,
+ color=(0),
+ thickness=cv2.FILLED,
+ )
+
+ return mask
+
+
+def face_mask_square(img_shape, landmark_list, dtype="uint8"):
+ height, width = img_shape[:2]
+
+ mask = np.ones((height, width, 1), dtype=dtype)
+
+ far_left = np.min(landmark_list[:, 0])
+ far_right = np.max(landmark_list[:, 1])
+ print("far_left {}, far_right {}".format(far_left, far_right))
+
+ left_p = 2
+ right_p = 14
+
+ print(
+ "left_p {}, right_p {}".format(
+ landmark_list[left_p][0], landmark_list[right_p][0]
+ )
+ )
+
+ cv2.drawContours(
+ mask,
+ np.int32(
+ [
+ [
+ landmark_list[left_p],
+ [landmark_list[left_p][0], height],
+ landmark_list[right_p],
+ [landmark_list[right_p][0], height],
+ landmark_list[left_p],
+ landmark_list[right_p],
+ [landmark_list[left_p][0], height],
+ [landmark_list[right_p][0], height],
+ ]
+ ]
+ ),
+ -1,
+ color=(0),
+ thickness=cv2.FILLED,
+ )
+
+ return mask
+
+
+" Used for half face "
+
+
+def bbox2mask(img_shape, bbox, dtype="uint8"):
+ """Generate mask in ndarray from bbox.
+
+ The returned mask has the shape of (h, w, 1). '1' indicates the
+ hole and '0' indicates the valid regions.
+
+ We prefer to use `uint8` as the data type of masks, which may be different
+ from other codes in the community.
+
+ Args:
+ img_shape (tuple[int]): The size of the image.
+ bbox (tuple[int]): Configuration tuple, (top, left, height, width)
+ dtype (str): Indicate the data type of returned masks. Default: 'uint8'
+
+ Return:
+ numpy.ndarray: Mask in the shape of (h, w, 1).
+ """
+
+ height, width = img_shape[:2]
+
+ mask = np.ones((height, width, 1), dtype=dtype)
+ mask[bbox[0] : bbox[0] + bbox[2], bbox[1] : bbox[1] + bbox[3], :] = 0.0
+
+ return mask
+
+
+def face_mask_cheeks(img_shape, landmark_list, dtype="uint8"):
+ height, width = img_shape[:2]
+ mask = np.ones((height, width, 1), dtype=dtype)
+
+ middle_nose_point = 29
+ nose = 33
+ cut_h = int(landmark_list[middle_nose_point, 1])
+
+ far_left = int(np.argmin(landmark_list[:, 0]))
+ far_right = int(np.argmax(landmark_list[:, 0]))
+
+ left_up_point = np.int32([landmark_list[far_left][0], cut_h])
+ left_down_point = np.int32([landmark_list[far_left][0], height])
+
+ right_up_point = np.int32([landmark_list[far_right][0], cut_h])
+ right_down_point = np.int32([landmark_list[far_right][0], height])
+
+ cv2.drawContours(
+ mask,
+ np.int32(
+ [
+ [
+ left_up_point,
+ left_down_point,
+ right_up_point,
+ right_down_point,
+ left_up_point,
+ right_up_point,
+ left_down_point,
+ right_down_point,
+ ]
+ ]
+ ),
+ -1,
+ color=(0),
+ thickness=cv2.FILLED,
+ )
+
+ # Calculate the bounding box coordinates for the nose
+ nose_jaw_dist = (
+ abs(landmark_list[2][0] - landmark_list[middle_nose_point][0]) * 0.10
+ ) # 1, 15
+ # nose_right_dist = (landmark_list[middle_nose_point][0] - landmark_list[1][0]) * 0.10
+ # nose_left_dist = (landmark_list[15][0] - landmark_list[middle_nose_point][0]) * 0.10
+ #
+
+ nose_min_x = int(landmark_list[31][0] - nose_jaw_dist)
+ nose_max_x = int(landmark_list[35][0] + nose_jaw_dist)
+ # nose_min_x = int(landmark_list[31][0] - nose_right_dist)
+ # nose_max_x = int(landmark_list[35][0] + nose_left_dist)
+ nose_min_y = cut_h
+ nose_max_y = int(landmark_list[nose, 1])
+
+ # Clear the nose area from the mask using a rectangle
+ mask_nose = np.ones((height, width, 1), dtype=dtype)
+ cv2.rectangle(
+ mask_nose,
+ (nose_min_x, nose_min_y),
+ (nose_max_x, nose_max_y),
+ color=(0),
+ thickness=cv2.FILLED,
+ )
+
+ mask_nose = 1 - mask_nose
+ mask = mask + mask_nose
+
+ return mask
+
+
+def face_mask_cheeks_batch(
+ img_shape, landmark_list, dtype="uint8", box_expand=0.0, show_nose=True
+):
+ height, width = img_shape[:2]
+
+ # Handle both single and multiple landmarks
+ if len(landmark_list.shape) == 2:
+ landmark_list = landmark_list[None, ...] # Add batch dimension
+ num_frames = landmark_list.shape[0]
+
+ # Initialize masks for all frames
+ masks = np.ones((num_frames, height, width), dtype=dtype)
+
+ for i in range(num_frames):
+ landmarks = landmark_list[i]
+ middle_nose_point = 29
+ nose = 33
+ cut_h = int(landmarks[middle_nose_point, 1])
+
+ # Add height expansion
+ if box_expand > 0:
+ cut_h = max(0, cut_h - int(box_expand * height))
+
+ far_left = int(np.argmin(landmarks[:, 0]))
+ far_right = int(np.argmax(landmarks[:, 0]))
+
+ left_up_point = np.int32([landmarks[far_left][0], cut_h])
+ left_down_point = np.int32([landmarks[far_left][0], height])
+
+ right_up_point = np.int32([landmarks[far_right][0], cut_h])
+ right_down_point = np.int32([landmarks[far_right][0], height])
+
+ cv2.drawContours(
+ masks[i],
+ np.int32(
+ [
+ [
+ left_up_point,
+ left_down_point,
+ right_up_point,
+ right_down_point,
+ left_up_point,
+ right_up_point,
+ left_down_point,
+ right_down_point,
+ ]
+ ]
+ ),
+ -1,
+ color=(0),
+ thickness=cv2.FILLED,
+ )
+
+ if show_nose:
+ # Calculate the bounding box coordinates for the nose
+ nose_jaw_dist = (
+ abs(landmarks[2][0] - landmarks[middle_nose_point][0]) * 0.10
+ ) # 1, 15
+
+ nose_min_x = int(landmarks[31][0] - nose_jaw_dist)
+ nose_max_x = int(landmarks[35][0] + nose_jaw_dist)
+ nose_min_y = cut_h
+ nose_max_y = int(landmarks[nose, 1])
+
+ # Clear the nose area from the mask using a rectangle
+ mask_nose = np.ones((height, width), dtype=dtype)
+ cv2.rectangle(
+ mask_nose,
+ (nose_min_x, nose_min_y),
+ (nose_max_x, nose_max_y),
+ color=(0),
+ thickness=cv2.FILLED,
+ )
+
+ mask_nose = 1 - mask_nose
+ masks[i] = masks[i] + mask_nose
+
+ # If input was single frame, return single mask
+ if landmark_list.shape[0] == 1:
+ return masks[0]
+
+ return 1 - torch.from_numpy(masks)
diff --git a/sgm/data/video_datamodule_latent.py b/sgm/data/video_datamodule_latent.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f20cc5548235f2a81d542d84e2398282d88017a
--- /dev/null
+++ b/sgm/data/video_datamodule_latent.py
@@ -0,0 +1,138 @@
+from typing import Any, Dict, Optional
+
+from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader
+from omegaconf import DictConfig
+
+import sys
+import pyrootutils
+
+root = pyrootutils.setup_root(__file__, pythonpath=True)
+sys.path.append(root)
+from sgm.data.video_dataset_latent import VideoDataset
+
+
+class VideoDataModule(LightningDataModule):
+ """
+ A DataModule implements 5 key methods:
+
+ def prepare_data(self):
+ # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
+ # download data, pre-process, split, save to disk, etc...
+ def setup(self, stage):
+ # things to do on every process in DDP
+ # load data, set variables, etc...
+ def train_dataloader(self):
+ # return train dataloader
+ def val_dataloader(self):
+ # return validation dataloader
+ def test_dataloader(self):
+ # return test dataloader
+ def teardown(self):
+ # called on every process in DDP
+ # clean up after fit or test
+
+ This allows you to share a full dataset without explaining how to download,
+ split, transform and process the data.
+
+ Read the docs:
+ https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
+ """
+
+ def __init__(
+ self,
+ train: DictConfig,
+ validation: Optional[DictConfig] = None,
+ test: Optional[DictConfig] = None,
+ skip_val_loader: bool = False,
+ ):
+ super().__init__()
+
+ # this line allows to access init params with 'self.hparams' attribute
+ # also ensures init params will be stored in ckpt
+ self.train_config = train
+ assert "datapipeline" in self.train_config and "loader" in self.train_config, (
+ "train config requires the fields `datapipeline` and `loader`"
+ )
+
+ self.val_config = validation
+ if not skip_val_loader:
+ if self.val_config is not None:
+ assert (
+ "datapipeline" in self.val_config and "loader" in self.val_config
+ ), "validation config requires the fields `datapipeline` and `loader`"
+ else:
+ print(
+ "Warning: No Validation datapipeline defined, using that one from training"
+ )
+ self.val_config = train
+
+ self.test_config = test
+ if self.test_config is not None:
+ assert (
+ "datapipeline" in self.test_config and "loader" in self.test_config
+ ), "test config requires the fields `datapipeline` and `loader`"
+
+ def setup(self, stage: Optional[str] = None):
+ """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+
+ This method is called by lightning with both `trainer.fit()` and `trainer.test()`, so be
+ careful not to execute things like random split twice!
+ """
+ print("Preparing datasets")
+
+ self.train_datapipeline = VideoDataset(**self.train_config.datapipeline)
+ if self.val_config:
+ self.val_datapipeline = VideoDataset(**self.val_config.datapipeline)
+ if self.test_config:
+ self.test_datapipeline = VideoDataset(**self.test_config.datapipeline)
+
+ def train_dataloader(self):
+ return DataLoader(self.train_datapipeline, **self.train_config.loader)
+
+ def val_dataloader(self):
+ if self.val_datapipeline:
+ return DataLoader(self.val_datapipeline, **self.val_config.loader)
+ else:
+ return None
+
+ def test_dataloader(self):
+ if self.test_datapipeline:
+ return DataLoader(self.test_datapipeline, **self.test_config.loader)
+ else:
+ return None
+
+ def teardown(self, stage: Optional[str] = None):
+ """Clean up after fit or test."""
+ pass
+
+ def state_dict(self):
+ """Extra things to save to checkpoint."""
+ return {}
+
+ def load_state_dict(self, state_dict: Dict[str, Any]):
+ """Things to do when loading checkpoint."""
+ pass
+
+
+if __name__ == "__main__":
+ import hydra
+ import omegaconf
+ import pyrootutils
+ import cv2
+
+ root = pyrootutils.setup_root(__file__, pythonpath=True)
+ cfg = omegaconf.OmegaConf.load(
+ root / "configs" / "datamodule" / "image_datamodule.yaml"
+ )
+ # cfg.data_dir = str(root / "data")
+ data = hydra.utils.instantiate(cfg)
+ data.prepare_data()
+ data.setup()
+ print(data.data_train.__getitem__(0)[0].shape)
+ batch = next(iter(data.train_dataloader()))
+ identity, target = batch
+ image_identity = (identity[0].permute(1, 2, 0).numpy() + 1) / 2 * 255
+ image_other = (target[0].permute(1, 2, 0).numpy() + 1) / 2 * 255
+ cv2.imwrite("image_identity.png", image_identity[:, :, ::-1])
+ cv2.imwrite("image_other.png", image_other[:, :, ::-1])
diff --git a/sgm/data/video_dataset_latent.py b/sgm/data/video_dataset_latent.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ff9f63f998380732a8d3f6c4738434c33a71f54
--- /dev/null
+++ b/sgm/data/video_dataset_latent.py
@@ -0,0 +1,780 @@
+import random
+import numpy as np
+from functools import partial
+from torch.utils.data import Dataset, WeightedRandomSampler
+import torch.nn.functional as F
+import torch
+import math
+import decord
+from einops import rearrange
+from more_itertools import sliding_window
+from omegaconf import ListConfig
+import torchaudio
+import soundfile as sf
+from torchvision.transforms import RandomHorizontalFlip
+from audiomentations import Compose, AddGaussianNoise, PitchShift
+from safetensors.torch import load_file
+from tqdm import tqdm
+import cv2
+from sgm.data.data_utils import (
+ create_masks_from_landmarks_full_size,
+ create_face_mask_from_landmarks,
+ create_masks_from_landmarks_box,
+ create_masks_from_landmarks_mouth,
+)
+from sgm.data.mask import face_mask_cheeks_batch
+
+torchaudio.set_audio_backend("sox_io")
+decord.bridge.set_bridge("torch")
+
+
+def exists(x):
+ return x is not None
+
+
+def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
+ len_file = audio.shape[-1]
+
+ if max_len_sec or max_len_raw:
+ max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
+ if len_file < int(max_len):
+ extened_wav = torch.nn.functional.pad(
+ audio, (0, int(max_len) - len_file), "constant"
+ )
+ else:
+ extened_wav = audio[:, : int(max_len)]
+ else:
+ extened_wav = audio
+
+ return extened_wav
+
+
+# Similar to regular video dataset but trades flexibility for speed
+class VideoDataset(Dataset):
+ def __init__(
+ self,
+ filelist,
+ resize_size=None,
+ audio_folder="Audio",
+ video_folder="CroppedVideos",
+ emotions_folder="emotions",
+ landmarks_folder=None,
+ audio_emb_folder=None,
+ video_extension=".avi",
+ audio_extension=".wav",
+ audio_rate=16000,
+ latent_folder=None,
+ audio_in_video=False,
+ fps=25,
+ num_frames=5,
+ need_cond=True,
+ step=1,
+ mode="prediction",
+ scale_audio=False,
+ augment=False,
+ augment_audio=False,
+ use_latent=False,
+ latent_type="stable",
+ latent_scale=1, # For backwards compatibility
+ from_audio_embedding=False,
+ load_all_possible_indexes=False,
+ audio_emb_type="wavlm",
+ cond_noise=[-3.0, 0.5],
+ motion_id=255.0,
+ data_mean=None,
+ data_std=None,
+ use_latent_condition=False,
+ skip_frames=0,
+ get_separate_id=False,
+ virtual_increase=1,
+ filter_by_length=False,
+ select_randomly=False,
+ balance_datasets=True,
+ use_emotions=False,
+ get_original_frames=False,
+ add_extra_audio_emb=False,
+ expand_box=0.0,
+ nose_index=28,
+ what_mask="full",
+ get_masks=False,
+ ):
+ self.audio_folder = audio_folder
+ self.from_audio_embedding = from_audio_embedding
+ self.audio_emb_type = audio_emb_type
+ self.cond_noise = cond_noise
+ self.latent_condition = use_latent_condition
+ precomputed_latent = latent_type
+ self.audio_emb_folder = (
+ audio_emb_folder if audio_emb_folder is not None else audio_folder
+ )
+ self.skip_frames = skip_frames
+ self.get_separate_id = get_separate_id
+ self.fps = fps
+ self.virtual_increase = virtual_increase
+ self.select_randomly = select_randomly
+ self.use_emotions = use_emotions
+ self.emotions_folder = emotions_folder
+ self.get_original_frames = get_original_frames
+ self.add_extra_audio_emb = add_extra_audio_emb
+ self.expand_box = expand_box
+ self.nose_index = nose_index
+ self.landmarks_folder = landmarks_folder
+ self.what_mask = what_mask
+ self.get_masks = get_masks
+
+ assert not (exists(data_mean) ^ exists(data_std)), (
+ "Both data_mean and data_std should be provided"
+ )
+
+ if data_mean is not None:
+ data_mean = rearrange(torch.as_tensor(data_mean), "c -> c () () ()")
+ data_std = rearrange(torch.as_tensor(data_std), "c -> c () () ()")
+ self.data_mean = data_mean
+ self.data_std = data_std
+ self.motion_id = motion_id
+ self.latent_folder = (
+ latent_folder if latent_folder is not None else video_folder
+ )
+ self.audio_in_video = audio_in_video
+
+ self.filelist = []
+ self.audio_filelist = []
+ self.landmark_filelist = [] if get_masks else None
+ with open(filelist, "r") as files:
+ for f in files.readlines():
+ f = f.rstrip()
+
+ audio_path = f.replace(video_folder, audio_folder).replace(
+ video_extension, audio_extension
+ )
+
+ self.filelist += [f]
+ self.audio_filelist += [audio_path]
+ if self.get_masks:
+ landmark_path = f.replace(video_folder, landmarks_folder).replace(
+ video_extension, ".npy"
+ )
+ self.landmark_filelist += [landmark_path]
+
+ self.resize_size = resize_size
+ if use_latent and not precomputed_latent:
+ self.resize_size *= 4 if latent_type in ["stable", "ldm"] else 8
+ self.scale_audio = scale_audio
+ self.step = step
+ self.use_latent = use_latent
+ self.precomputed_latent = precomputed_latent
+ self.latent_type = latent_type
+ self.latent_scale = latent_scale
+ self.video_ext = video_extension
+ self.video_folder = video_folder
+
+ self.augment = augment
+ self.maybe_augment = RandomHorizontalFlip(p=0.5) if augment else lambda x: x
+ self.maybe_augment_audio = (
+ Compose(
+ [
+ AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.002, p=0.25),
+ # TimeStretch(min_rate=0.8, max_rate=1.25, p=0.3),
+ PitchShift(min_semitones=-1, max_semitones=1, p=0.25),
+ # Shift(min_fraction=-0.5, max_fraction=0.5, p=0.333),
+ ]
+ )
+ if augment_audio
+ else lambda x, sample_rate: x
+ )
+ self.maybe_augment_audio = partial(
+ self.maybe_augment_audio, sample_rate=audio_rate
+ )
+
+ self.mode = mode
+ if mode == "interpolation":
+ need_cond = False # Interpolation does not need condition as first and last frame becomes the condition
+ self.need_cond = need_cond # If need cond will extract one more frame than the number of frames
+ if get_separate_id:
+ self.need_cond = True
+ # It is used for the conditional model when the condition is not on the temporal dimension
+ num_frames = num_frames if not self.need_cond else num_frames + 1
+
+ vr = decord.VideoReader(self.filelist[0])
+ self.video_rate = math.ceil(vr.get_avg_fps())
+ print(f"Video rate: {self.video_rate}")
+ self.audio_rate = audio_rate
+ a2v_ratio = fps / float(self.audio_rate)
+ self.samples_per_frame = math.ceil(1 / a2v_ratio)
+
+ if get_separate_id:
+ assert mode == "prediction", (
+ "Separate identity frame is only supported for prediction mode"
+ )
+ # No need for extra frame if we are getting a separate identity frame
+ self.need_cond = True
+ num_frames -= 1
+ self.num_frames = num_frames
+ self.load_all_possible_indexes = load_all_possible_indexes
+ if load_all_possible_indexes:
+ self._indexes = self._get_indexes(
+ self.filelist, self.audio_filelist, self.landmark_filelist
+ )
+ else:
+ if filter_by_length:
+ self._indexes = self.filter_by_length(
+ self.filelist, self.audio_filelist, self.landmark_filelist
+ )
+ else:
+ if self.get_masks:
+ self._indexes = list(
+ zip(self.filelist, self.audio_filelist, self.landmark_filelist)
+ )
+ else:
+ self._indexes = list(
+ zip(
+ self.filelist,
+ self.audio_filelist,
+ [None] * len(self.filelist),
+ )
+ )
+
+ self.balance_datasets = balance_datasets
+ if self.balance_datasets:
+ self.weights = self._calculate_weights()
+ self.sampler = WeightedRandomSampler(
+ self.weights, num_samples=len(self._indexes), replacement=True
+ )
+
+ def __len__(self):
+ return len(self._indexes) * self.virtual_increase
+
+ def _load_landmarks(self, filename, original_size, target_size, indexes):
+ landmarks = np.load(filename, allow_pickle=True)[indexes, :]
+ if self.what_mask == "full":
+ mask = create_masks_from_landmarks_full_size(
+ landmarks,
+ original_size[0],
+ original_size[1],
+ offset=self.expand_box,
+ nose_index=self.nose_index,
+ )
+ elif self.what_mask == "box":
+ mask = create_masks_from_landmarks_box(
+ landmarks,
+ (original_size[0], original_size[1]),
+ box_expand=self.expand_box,
+ nose_index=self.nose_index,
+ )
+ elif self.what_mask == "heart":
+ mask = face_mask_cheeks_batch(
+ original_size, landmarks, box_expand=0.0, show_nose=True
+ )
+ elif self.what_mask == "mouth":
+ mask = create_masks_from_landmarks_mouth(
+ landmarks,
+ (original_size[0], original_size[1]),
+ box_expand=0.01,
+ nose_index=self.nose_index,
+ )
+ else:
+ mask = create_face_mask_from_landmarks(
+ landmarks, original_size[0], original_size[1], mask_expand=0.05
+ )
+ # Interpolate the mask to the target size
+ mask = F.interpolate(
+ mask.unsqueeze(1).float(), size=target_size, mode="nearest"
+ )
+
+ return mask, landmarks
+
+ def get_emotions(self, video_file, video_indexes):
+ emotions_path = video_file.replace(
+ self.video_folder, self.emotions_folder
+ ).replace(self.video_ext, ".pt")
+ emotions = torch.load(emotions_path)
+ return (
+ emotions["valence"][video_indexes],
+ emotions["arousal"][video_indexes],
+ emotions["labels"][video_indexes],
+ )
+
+ def get_frame_indices(self, total_video_frames, select_randomly=False, start_idx=0):
+ if select_randomly:
+ # Randomly select self.num_frames indices from the available range
+ available_indices = list(range(start_idx, total_video_frames))
+ if len(available_indices) < self.num_frames:
+ raise ValueError(
+ "Not enough frames in the video to sample with given parameters."
+ )
+ indexes = random.sample(available_indices, self.num_frames)
+ return sorted(indexes) # Sort to maintain temporal order
+ else:
+ # Calculate the maximum possible start index
+ max_start_idx = total_video_frames - (
+ (self.num_frames - 1) * (self.skip_frames + 1) + 1
+ )
+
+ # Generate a random start index
+ if max_start_idx > 0:
+ start_idx = np.random.randint(start_idx, max_start_idx)
+ else:
+ raise ValueError(
+ "Not enough frames in the video to sample with given parameters."
+ )
+
+ # Generate the indices
+ indexes = [
+ start_idx + i * (self.skip_frames + 1) for i in range(self.num_frames)
+ ]
+
+ return indexes
+
+ def _load_audio(self, filename, max_len_sec, start=None, indexes=None):
+ audio, sr = sf.read(
+ filename,
+ start=math.ceil(start * self.audio_rate),
+ frames=math.ceil(self.audio_rate * max_len_sec),
+ always_2d=True,
+ ) # e.g (16000, 1)
+ audio = audio.T # (1, 16000)
+ assert sr == self.audio_rate, (
+ f"Audio rate is {sr} but should be {self.audio_rate}"
+ )
+ audio = audio.mean(0, keepdims=True)
+ audio = self.maybe_augment_audio(audio)
+ audio = torch.from_numpy(audio).float()
+ # audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=self.audio_rate)
+ audio = trim_pad_audio(audio, self.audio_rate, max_len_sec=max_len_sec)
+ return audio[0]
+
+ def ensure_shape(self, tensors):
+ target_length = self.samples_per_frame
+ processed_tensors = []
+ for tensor in tensors:
+ current_length = tensor.shape[1]
+ diff = current_length - target_length
+ assert abs(diff) <= 5, (
+ f"Expected shape {target_length}, but got {current_length}"
+ )
+ if diff < 0:
+ # Calculate how much padding is needed
+ padding_needed = target_length - current_length
+ # Pad the tensor
+ padded_tensor = F.pad(tensor, (0, padding_needed))
+ processed_tensors.append(padded_tensor)
+ elif diff > 0:
+ # Trim the tensor
+ trimmed_tensor = tensor[:, :target_length]
+ processed_tensors.append(trimmed_tensor)
+ else:
+ # If it's already the correct size
+ processed_tensors.append(tensor)
+ return torch.cat(processed_tensors)
+
+ def normalize_latents(self, latents):
+ if self.data_mean is not None:
+ # Normalize latents to 0 mean and 0.5 std
+ latents = ((latents - self.data_mean) / self.data_std) * 0.5
+ return latents
+
+ def convert_indexes(self, indexes_25fps, fps_from=25, fps_to=60):
+ ratio = fps_to / fps_from
+ indexes_60fps = [int(index * ratio) for index in indexes_25fps]
+ return indexes_60fps
+
+ def _get_frames_and_audio(self, idx):
+ if self.load_all_possible_indexes:
+ indexes, video_file, audio_file, land_file = self._indexes[idx]
+ if self.audio_in_video:
+ vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
+ else:
+ vr = decord.VideoReader(video_file)
+ len_video = len(vr)
+ if "AA_processed" in video_file or "1000actors_nsv" in video_file:
+ len_video *= 25 / 60
+ len_video = int(len_video)
+ else:
+ video_file, audio_file, land_file = self._indexes[idx]
+ if self.audio_in_video:
+ vr = decord.AVReader(video_file, sample_rate=self.audio_rate)
+ else:
+ vr = decord.VideoReader(video_file)
+ len_video = len(vr)
+ if "AA_processed" in video_file or "1000actors_nsv" in video_file:
+ len_video *= 25 / 60
+ len_video = int(len_video)
+
+ indexes = self.get_frame_indices(
+ len_video,
+ select_randomly=self.select_randomly,
+ start_idx=120 if "1000actors_nsv" in video_file else 0,
+ )
+
+ if self.get_separate_id:
+ id_idx = np.random.randint(0, len_video)
+ indexes.insert(0, id_idx)
+
+ if "AA_processed" in video_file or "1000actors_nsv" in video_file:
+ video_indexes = self.convert_indexes(indexes, fps_from=25, fps_to=60)
+ audio_file = audio_file.replace("_output_output", "")
+ if self.audio_emb_type == "wav2vec2" and "AA_processed" in video_file:
+ audio_path_extra = ".safetensors"
+ else:
+ audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"
+
+ video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
+ audio_path_extra_extra = (
+ ".pt" if "AA_processed" in video_file else "_beats_emb.pt"
+ )
+
+ else:
+ video_indexes = indexes
+ audio_path_extra = f"_{self.audio_emb_type}_emb.safetensors"
+ video_path_extra = f"_{self.latent_type}_512_latent.safetensors"
+ audio_path_extra_extra = "_beats_emb.pt"
+
+ emotions = None
+ if self.use_emotions:
+ emotions = self.get_emotions(video_file, video_indexes)
+ if self.get_separate_id:
+ emotions = (emotions[0][1:], emotions[1][1:], emotions[2][1:])
+
+ raw_audio = None
+ if self.audio_in_video:
+ raw_audio, frames_video = vr.get_batch(video_indexes)
+ raw_audio = rearrange(self.ensure_shape(raw_audio), "f s -> (f s)")
+
+ if self.use_latent and self.precomputed_latent:
+ latent_file = video_file.replace(self.video_ext, video_path_extra).replace(
+ self.video_folder, self.latent_folder
+ )
+ frames = load_file(latent_file)["latents"][video_indexes, :, :, :]
+
+ if frames.shape[-1] != 64:
+ print(f"Frames shape: {frames.shape}, video file: {video_file}")
+
+ frames = rearrange(frames, "t c h w -> c t h w") * self.latent_scale
+ frames = self.normalize_latents(frames)
+ else:
+ if self.audio_in_video:
+ frames = frames_video.permute(3, 0, 1, 2).float()
+ else:
+ frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()
+
+ if raw_audio is None:
+ # Audio is not in video
+ raw_audio = self._load_audio(
+ audio_file,
+ max_len_sec=frames.shape[1] / self.fps,
+ start=indexes[0] / self.fps,
+ # indexes=indexes,
+ )
+ if not self.from_audio_embedding:
+ audio = raw_audio
+ audio_frames = rearrange(audio, "(f s) -> f s", s=self.samples_per_frame)
+ else:
+ audio = load_file(
+ audio_file.replace(self.audio_folder, self.audio_emb_folder).split(".")[
+ 0
+ ]
+ + audio_path_extra
+ )["audio"]
+ audio_frames = audio[indexes, :]
+ if self.add_extra_audio_emb:
+ audio_extra = torch.load(
+ audio_file.replace(self.audio_folder, self.audio_emb_folder).split(
+ "."
+ )[0]
+ + audio_path_extra_extra
+ )
+ audio_extra = audio_extra[indexes, :]
+ audio_frames = torch.cat([audio_frames, audio_extra], dim=-1)
+
+ audio_frames = (
+ audio_frames[1:] if self.need_cond else audio_frames
+ ) # Remove audio of first frame
+
+ if self.get_original_frames:
+ original_frames = vr.get_batch(video_indexes).permute(3, 0, 1, 2).float()
+ original_frames = self.scale_and_crop((original_frames / 255.0) * 2 - 1)
+ original_frames = (
+ original_frames[:, 1:] if self.need_cond else original_frames
+ )
+ else:
+ original_frames = None
+
+ if not self.use_latent or (self.use_latent and not self.precomputed_latent):
+ frames = self.scale_and_crop((frames / 255.0) * 2 - 1)
+
+ target = frames[:, 1:] if self.need_cond else frames
+ if self.mode == "prediction":
+ if self.use_latent:
+ if self.audio_in_video:
+ clean_cond = (
+ frames_video[0].unsqueeze(0).permute(3, 0, 1, 2).float()
+ )
+ else:
+ clean_cond = (
+ vr[video_indexes[0]].unsqueeze(0).permute(3, 0, 1, 2).float()
+ )
+ original_size = clean_cond.shape[-2:]
+ clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1).squeeze(
+ 0
+ )
+ if self.latent_condition:
+ noisy_cond = frames[:, 0]
+ else:
+ noisy_cond = clean_cond
+ else:
+ clean_cond = frames[:, 0]
+ noisy_cond = clean_cond
+ elif self.mode == "interpolation":
+ if self.use_latent:
+ if self.audio_in_video:
+ clean_cond = frames_video[[0, -1]].permute(3, 0, 1, 2).float()
+ else:
+ clean_cond = (
+ vr.get_batch([video_indexes[0], video_indexes[-1]])
+ .permute(3, 0, 1, 2)
+ .float()
+ )
+ original_size = clean_cond.shape[-2:]
+ clean_cond = self.scale_and_crop((clean_cond / 255.0) * 2 - 1)
+ if self.latent_condition:
+ noisy_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
+ else:
+ noisy_cond = clean_cond
+ else:
+ clean_cond = torch.stack([target[:, 0], target[:, -1]], dim=1)
+ noisy_cond = clean_cond
+
+ # Add noise to conditional frame
+ if self.cond_noise and isinstance(self.cond_noise, ListConfig):
+ cond_noise = (
+ self.cond_noise[0] + self.cond_noise[1] * torch.randn((1,))
+ ).exp()
+ noisy_cond = noisy_cond + cond_noise * torch.randn_like(noisy_cond)
+ else:
+ noisy_cond = noisy_cond + self.cond_noise * torch.randn_like(noisy_cond)
+ cond_noise = self.cond_noise
+
+ if self.get_masks:
+ target_size = (
+ (self.resize_size, self.resize_size)
+ if not self.use_latent
+ else (self.resize_size // 8, self.resize_size // 8)
+ )
+ masks, landmarks = self._load_landmarks(
+ land_file, original_size, target_size, video_indexes
+ )
+
+ landmarks = None
+ masks = (
+ masks.permute(1, 0, 2, 3)[:, 1:]
+ if self.need_cond
+ else masks.permute(1, 0, 2, 3)
+ )
+ else:
+ masks = None
+ landmarks = None
+
+ return (
+ original_frames,
+ clean_cond,
+ noisy_cond,
+ target,
+ audio_frames,
+ raw_audio,
+ cond_noise,
+ emotions,
+ masks,
+ landmarks,
+ )
+
+ def filter_by_length(self, video_filelist, audio_filelist):
+ def with_opencv(filename):
+ video = cv2.VideoCapture(filename)
+ frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
+
+ return int(frame_count)
+
+ filtered_video = []
+ filtered_audio = []
+ min_length = (self.num_frames - 1) * (self.skip_frames + 1) + 1
+ for vid_file, audio_file in tqdm(
+ zip(video_filelist, audio_filelist),
+ total=len(video_filelist),
+ desc="Filtering",
+ ):
+ # vr = decord.VideoReader(vid_file)
+
+ len_video = with_opencv(vid_file)
+ # Short videos
+ if len_video < min_length:
+ continue
+ filtered_video.append(vid_file)
+ filtered_audio.append(audio_file)
+ print(f"New number of files: {len(filtered_video)}")
+ return filtered_video, filtered_audio
+
+ def _get_indexes(self, video_filelist, audio_filelist):
+ indexes = []
+ self.og_shape = None
+ for vid_file, audio_file in zip(video_filelist, audio_filelist):
+ vr = decord.VideoReader(vid_file)
+ if self.og_shape is None:
+ self.og_shape = vr[0].shape[-2]
+ len_video = len(vr)
+ # Short videos
+ if len_video < self.num_frames:
+ continue
+ else:
+ possible_indexes = list(
+ sliding_window(range(len_video), self.num_frames)
+ )[:: self.step]
+ possible_indexes = list(
+ map(lambda x: (x, vid_file, audio_file), possible_indexes)
+ )
+ indexes.extend(possible_indexes)
+ print("Indexes", len(indexes), "\n")
+ return indexes
+
+ def scale_and_crop(self, video):
+ h, w = video.shape[-2], video.shape[-1]
+ # scale shorter side to resolution
+
+ if self.resize_size is not None:
+ scale = self.resize_size / min(h, w)
+ if h < w:
+ target_size = (self.resize_size, math.ceil(w * scale))
+ else:
+ target_size = (math.ceil(h * scale), self.resize_size)
+ video = F.interpolate(
+ video,
+ size=target_size,
+ mode="bilinear",
+ align_corners=False,
+ antialias=True,
+ )
+
+ # center crop
+ h, w = video.shape[-2], video.shape[-1]
+ w_start = (w - self.resize_size) // 2
+ h_start = (h - self.resize_size) // 2
+ video = video[
+ :,
+ :,
+ h_start : h_start + self.resize_size,
+ w_start : w_start + self.resize_size,
+ ]
+ return self.maybe_augment(video)
+
+ def _calculate_weights(self):
+ aa_processed_count = sum(
+ 1
+ for item in self._indexes
+ if "AA_processed" in (item[1] if len(item) == 3 else item[0])
+ )
+ nsv_processed_count = sum(
+ 1
+ for item in self._indexes
+ if "1000actors_nsv" in (item[1] if len(item) == 3 else item[0])
+ )
+ other_count = len(self._indexes) - aa_processed_count - nsv_processed_count
+
+ aa_processed_weight = 1 / aa_processed_count if aa_processed_count > 0 else 0
+ nsv_processed_weight = 1 / nsv_processed_count if nsv_processed_count > 0 else 0
+ other_weight = 1 / other_count if other_count > 0 else 0
+
+ print(
+ f"AA processed count: {aa_processed_count}, NSV processed count: {nsv_processed_count}, other count: {other_count}"
+ )
+ print(f"AA processed weight: {aa_processed_weight}")
+ print(f"NSV processed weight: {nsv_processed_weight}")
+ print(f"Other weight: {other_weight}")
+
+ weights = [
+ aa_processed_weight
+ if "AA_processed" in (item[1] if len(item) == 3 else item[0])
+ else nsv_processed_weight
+ if "1000actors_nsv" in (item[1] if len(item) == 3 else item[0])
+ else other_weight
+ for item in self._indexes
+ ]
+ return weights
+
+ def __getitem__(self, idx):
+ if self.balance_datasets:
+ idx = self.sampler.__iter__().__next__()
+
+ try:
+ (
+ original_frames,
+ clean_cond,
+ noisy_cond,
+ target,
+ audio,
+ raw_audio,
+ cond_noise,
+ emotions,
+ masks,
+ landmarks,
+ ) = self._get_frames_and_audio(idx % len(self._indexes))
+ except Exception as e:
+ print(f"Error with index {idx}: {e}")
+ return self.__getitem__(np.random.randint(0, len(self)))
+ out_data = {}
+
+ if original_frames is not None:
+ out_data["original_frames"] = original_frames
+
+ if audio is not None:
+ out_data["audio_emb"] = audio
+ out_data["raw_audio"] = raw_audio
+
+ if self.use_emotions:
+ out_data["valence"] = emotions[0]
+ out_data["arousal"] = emotions[1]
+ out_data["emo_labels"] = emotions[2]
+ if self.use_latent:
+ input_key = "latents"
+ else:
+ input_key = "frames"
+ out_data[input_key] = target
+ if noisy_cond is not None:
+ out_data["cond_frames"] = noisy_cond
+ out_data["cond_frames_without_noise"] = clean_cond
+ if cond_noise is not None:
+ out_data["cond_aug"] = cond_noise
+
+ if masks is not None:
+ out_data["masks"] = masks
+ out_data["gt"] = target
+ if landmarks is not None:
+ out_data["landmarks"] = landmarks
+
+ out_data["motion_bucket_id"] = torch.tensor([self.motion_id])
+ out_data["fps_id"] = torch.tensor([self.fps - 1])
+ out_data["num_video_frames"] = self.num_frames
+ out_data["image_only_indicator"] = torch.zeros(self.num_frames)
+ return out_data
+
+
+if __name__ == "__main__":
+ import torchvision.transforms as transforms
+ import cv2
+
+ transform = transforms.Compose(transforms=[transforms.Resize((256, 256))])
+ dataset = VideoDataset(
+ "/vol/paramonos2/projects/antoni/datasets/mahnob/filelist_videos_val.txt",
+ transform=transform,
+ num_frames=25,
+ )
+ print(len(dataset))
+ idx = np.random.randint(0, len(dataset))
+
+ for i in range(10):
+ print(dataset[i][0].shape, dataset[i][1].shape)
+
+ image_identity = (dataset[idx][0].permute(1, 2, 0).numpy() + 1) / 2 * 255
+ image_other = (dataset[idx][1][:, -1].permute(1, 2, 0).numpy() + 1) / 2 * 255
+ cv2.imwrite("image_identity.png", image_identity[:, :, ::-1])
+ for i in range(25):
+ image = (dataset[idx][1][:, i].permute(1, 2, 0).numpy() + 1) / 2 * 255
+ cv2.imwrite(f"tmp_vid_dataset/image_{i}.png", image[:, :, ::-1])
diff --git a/sgm/inference/api.py b/sgm/inference/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..a359a67bcd9740acc9e320d2f26dc6a3befb36e0
--- /dev/null
+++ b/sgm/inference/api.py
@@ -0,0 +1,385 @@
+import pathlib
+from dataclasses import asdict, dataclass
+from enum import Enum
+from typing import Optional
+
+from omegaconf import OmegaConf
+
+from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
+ do_sample)
+from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
+ DPMPP2SAncestralSampler,
+ EulerAncestralSampler,
+ EulerEDMSampler,
+ HeunEDMSampler,
+ LinearMultistepSampler)
+from sgm.util import load_model_from_config
+
+
+class ModelArchitecture(str, Enum):
+ SD_2_1 = "stable-diffusion-v2-1"
+ SD_2_1_768 = "stable-diffusion-v2-1-768"
+ SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
+ SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
+ SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
+ SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
+
+
+class Sampler(str, Enum):
+ EULER_EDM = "EulerEDMSampler"
+ HEUN_EDM = "HeunEDMSampler"
+ EULER_ANCESTRAL = "EulerAncestralSampler"
+ DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
+ DPMPP2M = "DPMPP2MSampler"
+ LINEAR_MULTISTEP = "LinearMultistepSampler"
+
+
+class Discretization(str, Enum):
+ LEGACY_DDPM = "LegacyDDPMDiscretization"
+ EDM = "EDMDiscretization"
+
+
+class Guider(str, Enum):
+ VANILLA = "VanillaCFG"
+ IDENTITY = "IdentityGuider"
+
+
+class Thresholder(str, Enum):
+ NONE = "None"
+
+
+@dataclass
+class SamplingParams:
+ width: int = 1024
+ height: int = 1024
+ steps: int = 50
+ sampler: Sampler = Sampler.DPMPP2M
+ discretization: Discretization = Discretization.LEGACY_DDPM
+ guider: Guider = Guider.VANILLA
+ thresholder: Thresholder = Thresholder.NONE
+ scale: float = 6.0
+ aesthetic_score: float = 5.0
+ negative_aesthetic_score: float = 5.0
+ img2img_strength: float = 1.0
+ orig_width: int = 1024
+ orig_height: int = 1024
+ crop_coords_top: int = 0
+ crop_coords_left: int = 0
+ sigma_min: float = 0.0292
+ sigma_max: float = 14.6146
+ rho: float = 3.0
+ s_churn: float = 0.0
+ s_tmin: float = 0.0
+ s_tmax: float = 999.0
+ s_noise: float = 1.0
+ eta: float = 1.0
+ order: int = 4
+
+
+@dataclass
+class SamplingSpec:
+ width: int
+ height: int
+ channels: int
+ factor: int
+ is_legacy: bool
+ config: str
+ ckpt: str
+ is_guided: bool
+
+
+model_specs = {
+ ModelArchitecture.SD_2_1: SamplingSpec(
+ height=512,
+ width=512,
+ channels=4,
+ factor=8,
+ is_legacy=True,
+ config="sd_2_1.yaml",
+ ckpt="v2-1_512-ema-pruned.safetensors",
+ is_guided=True,
+ ),
+ ModelArchitecture.SD_2_1_768: SamplingSpec(
+ height=768,
+ width=768,
+ channels=4,
+ factor=8,
+ is_legacy=True,
+ config="sd_2_1_768.yaml",
+ ckpt="v2-1_768-ema-pruned.safetensors",
+ is_guided=True,
+ ),
+ ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
+ height=1024,
+ width=1024,
+ channels=4,
+ factor=8,
+ is_legacy=False,
+ config="sd_xl_base.yaml",
+ ckpt="sd_xl_base_0.9.safetensors",
+ is_guided=True,
+ ),
+ ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
+ height=1024,
+ width=1024,
+ channels=4,
+ factor=8,
+ is_legacy=True,
+ config="sd_xl_refiner.yaml",
+ ckpt="sd_xl_refiner_0.9.safetensors",
+ is_guided=True,
+ ),
+ ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
+ height=1024,
+ width=1024,
+ channels=4,
+ factor=8,
+ is_legacy=False,
+ config="sd_xl_base.yaml",
+ ckpt="sd_xl_base_1.0.safetensors",
+ is_guided=True,
+ ),
+ ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
+ height=1024,
+ width=1024,
+ channels=4,
+ factor=8,
+ is_legacy=True,
+ config="sd_xl_refiner.yaml",
+ ckpt="sd_xl_refiner_1.0.safetensors",
+ is_guided=True,
+ ),
+}
+
+
+class SamplingPipeline:
+ def __init__(
+ self,
+ model_id: ModelArchitecture,
+ model_path="checkpoints",
+ config_path="configs/inference",
+ device="cuda",
+ use_fp16=True,
+ ) -> None:
+ if model_id not in model_specs:
+ raise ValueError(f"Model {model_id} not supported")
+ self.model_id = model_id
+ self.specs = model_specs[self.model_id]
+ self.config = str(pathlib.Path(config_path, self.specs.config))
+ self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
+ self.device = device
+ self.model = self._load_model(device=device, use_fp16=use_fp16)
+
+ def _load_model(self, device="cuda", use_fp16=True):
+ config = OmegaConf.load(self.config)
+ model = load_model_from_config(config, self.ckpt)
+ if model is None:
+ raise ValueError(f"Model {self.model_id} could not be loaded")
+ model.to(device)
+ if use_fp16:
+ model.conditioner.half()
+ model.model.half()
+ return model
+
+ def text_to_image(
+ self,
+ params: SamplingParams,
+ prompt: str,
+ negative_prompt: str = "",
+ samples: int = 1,
+ return_latents: bool = False,
+ ):
+ sampler = get_sampler_config(params)
+ value_dict = asdict(params)
+ value_dict["prompt"] = prompt
+ value_dict["negative_prompt"] = negative_prompt
+ value_dict["target_width"] = params.width
+ value_dict["target_height"] = params.height
+ return do_sample(
+ self.model,
+ sampler,
+ value_dict,
+ samples,
+ params.height,
+ params.width,
+ self.specs.channels,
+ self.specs.factor,
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
+ return_latents=return_latents,
+ filter=None,
+ )
+
+ def image_to_image(
+ self,
+ params: SamplingParams,
+ image,
+ prompt: str,
+ negative_prompt: str = "",
+ samples: int = 1,
+ return_latents: bool = False,
+ ):
+ sampler = get_sampler_config(params)
+
+ if params.img2img_strength < 1.0:
+ sampler.discretization = Img2ImgDiscretizationWrapper(
+ sampler.discretization,
+ strength=params.img2img_strength,
+ )
+ height, width = image.shape[2], image.shape[3]
+ value_dict = asdict(params)
+ value_dict["prompt"] = prompt
+ value_dict["negative_prompt"] = negative_prompt
+ value_dict["target_width"] = width
+ value_dict["target_height"] = height
+ return do_img2img(
+ image,
+ self.model,
+ sampler,
+ value_dict,
+ samples,
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
+ return_latents=return_latents,
+ filter=None,
+ )
+
+ def refiner(
+ self,
+ params: SamplingParams,
+ image,
+ prompt: str,
+ negative_prompt: Optional[str] = None,
+ samples: int = 1,
+ return_latents: bool = False,
+ ):
+ sampler = get_sampler_config(params)
+ value_dict = {
+ "orig_width": image.shape[3] * 8,
+ "orig_height": image.shape[2] * 8,
+ "target_width": image.shape[3] * 8,
+ "target_height": image.shape[2] * 8,
+ "prompt": prompt,
+ "negative_prompt": negative_prompt,
+ "crop_coords_top": 0,
+ "crop_coords_left": 0,
+ "aesthetic_score": 6.0,
+ "negative_aesthetic_score": 2.5,
+ }
+
+ return do_img2img(
+ image,
+ self.model,
+ sampler,
+ value_dict,
+ samples,
+ skip_encode=True,
+ return_latents=return_latents,
+ filter=None,
+ )
+
+
+def get_guider_config(params: SamplingParams):
+ if params.guider == Guider.IDENTITY:
+ guider_config = {
+ "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
+ }
+ elif params.guider == Guider.VANILLA:
+ scale = params.scale
+
+ thresholder = params.thresholder
+
+ if thresholder == Thresholder.NONE:
+ dyn_thresh_config = {
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
+ }
+ else:
+ raise NotImplementedError
+
+ guider_config = {
+ "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
+ "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
+ }
+ else:
+ raise NotImplementedError
+ return guider_config
+
+
+def get_discretization_config(params: SamplingParams):
+ if params.discretization == Discretization.LEGACY_DDPM:
+ discretization_config = {
+ "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
+ }
+ elif params.discretization == Discretization.EDM:
+ discretization_config = {
+ "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
+ "params": {
+ "sigma_min": params.sigma_min,
+ "sigma_max": params.sigma_max,
+ "rho": params.rho,
+ },
+ }
+ else:
+ raise ValueError(f"unknown discretization {params.discretization}")
+ return discretization_config
+
+
+def get_sampler_config(params: SamplingParams):
+ discretization_config = get_discretization_config(params)
+ guider_config = get_guider_config(params)
+ sampler = None
+ if params.sampler == Sampler.EULER_EDM:
+ return EulerEDMSampler(
+ num_steps=params.steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ s_churn=params.s_churn,
+ s_tmin=params.s_tmin,
+ s_tmax=params.s_tmax,
+ s_noise=params.s_noise,
+ verbose=True,
+ )
+ if params.sampler == Sampler.HEUN_EDM:
+ return HeunEDMSampler(
+ num_steps=params.steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ s_churn=params.s_churn,
+ s_tmin=params.s_tmin,
+ s_tmax=params.s_tmax,
+ s_noise=params.s_noise,
+ verbose=True,
+ )
+ if params.sampler == Sampler.EULER_ANCESTRAL:
+ return EulerAncestralSampler(
+ num_steps=params.steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ eta=params.eta,
+ s_noise=params.s_noise,
+ verbose=True,
+ )
+ if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
+ return DPMPP2SAncestralSampler(
+ num_steps=params.steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ eta=params.eta,
+ s_noise=params.s_noise,
+ verbose=True,
+ )
+ if params.sampler == Sampler.DPMPP2M:
+ return DPMPP2MSampler(
+ num_steps=params.steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ verbose=True,
+ )
+ if params.sampler == Sampler.LINEAR_MULTISTEP:
+ return LinearMultistepSampler(
+ num_steps=params.steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ order=params.order,
+ verbose=True,
+ )
+
+ raise ValueError(f"unknown sampler {params.sampler}!")
diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..31b0ec3dc414bf522261e35f73805810cd35582d
--- /dev/null
+++ b/sgm/inference/helpers.py
@@ -0,0 +1,305 @@
+import math
+import os
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from einops import rearrange
+from imwatermark import WatermarkEncoder
+from omegaconf import ListConfig
+from PIL import Image
+from torch import autocast
+
+from sgm.util import append_dims
+
+
+class WatermarkEmbedder:
+ def __init__(self, watermark):
+ self.watermark = watermark
+ self.num_bits = len(WATERMARK_BITS)
+ self.encoder = WatermarkEncoder()
+ self.encoder.set_watermark("bits", self.watermark)
+
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
+ """
+ Adds a predefined watermark to the input image
+
+ Args:
+ image: ([N,] B, RGB, H, W) in range [0, 1]
+
+ Returns:
+ same as input but watermarked
+ """
+ squeeze = len(image.shape) == 4
+ if squeeze:
+ image = image[None, ...]
+ n = image.shape[0]
+ image_np = rearrange(
+ (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
+ ).numpy()[:, :, :, ::-1]
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
+ # watermarking libary expects input as cv2 BGR format
+ for k in range(image_np.shape[0]):
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
+ image = torch.from_numpy(
+ rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
+ ).to(image.device)
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
+ if squeeze:
+ image = image[0]
+ return image
+
+
+# A fixed 48-bit message that was choosen at random
+# WATERMARK_MESSAGE = 0xB3EC907BB19E
+WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
+# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
+WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
+embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
+
+
+def get_unique_embedder_keys_from_conditioner(conditioner):
+ return list({x.input_key for x in conditioner.embedders})
+
+
+def perform_save_locally(save_path, samples):
+ os.makedirs(os.path.join(save_path), exist_ok=True)
+ base_count = len(os.listdir(os.path.join(save_path)))
+ samples = embed_watermark(samples)
+ for sample in samples:
+ sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
+ Image.fromarray(sample.astype(np.uint8)).save(
+ os.path.join(save_path, f"{base_count:09}.png")
+ )
+ base_count += 1
+
+
+class Img2ImgDiscretizationWrapper:
+ """
+ wraps a discretizer, and prunes the sigmas
+ params:
+ strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
+ """
+
+ def __init__(self, discretization, strength: float = 1.0):
+ self.discretization = discretization
+ self.strength = strength
+ assert 0.0 <= self.strength <= 1.0
+
+ def __call__(self, *args, **kwargs):
+ # sigmas start large first, and decrease then
+ sigmas = self.discretization(*args, **kwargs)
+ print(f"sigmas after discretization, before pruning img2img: ", sigmas)
+ sigmas = torch.flip(sigmas, (0,))
+ sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
+ print("prune index:", max(int(self.strength * len(sigmas)), 1))
+ sigmas = torch.flip(sigmas, (0,))
+ print(f"sigmas after pruning: ", sigmas)
+ return sigmas
+
+
+def do_sample(
+ model,
+ sampler,
+ value_dict,
+ num_samples,
+ H,
+ W,
+ C,
+ F,
+ force_uc_zero_embeddings: Optional[List] = None,
+ batch2model_input: Optional[List] = None,
+ return_latents=False,
+ filter=None,
+ device="cuda",
+):
+ if force_uc_zero_embeddings is None:
+ force_uc_zero_embeddings = []
+ if batch2model_input is None:
+ batch2model_input = []
+
+ with torch.no_grad():
+ with autocast(device) as precision_scope:
+ with model.ema_scope():
+ num_samples = [num_samples]
+ batch, batch_uc = get_batch(
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
+ value_dict,
+ num_samples,
+ )
+ for key in batch:
+ if isinstance(batch[key], torch.Tensor):
+ print(key, batch[key].shape)
+ elif isinstance(batch[key], list):
+ print(key, [len(l) for l in batch[key]])
+ else:
+ print(key, batch[key])
+ c, uc = model.conditioner.get_unconditional_conditioning(
+ batch,
+ batch_uc=batch_uc,
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
+ )
+
+ for k in c:
+ if not k == "crossattn":
+ c[k], uc[k] = map(
+ lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
+ )
+
+ additional_model_inputs = {}
+ for k in batch2model_input:
+ additional_model_inputs[k] = batch[k]
+
+ shape = (math.prod(num_samples), C, H // F, W // F)
+ randn = torch.randn(shape).to(device)
+
+ def denoiser(input, sigma, c):
+ return model.denoiser(
+ model.model, input, sigma, c, **additional_model_inputs
+ )
+
+ samples_z = sampler(denoiser, randn, cond=c, uc=uc)
+ samples_x = model.decode_first_stage(samples_z)
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+ if filter is not None:
+ samples = filter(samples)
+
+ if return_latents:
+ return samples, samples_z
+ return samples
+
+
+def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
+ # Hardcoded demo setups; might undergo some changes in the future
+
+ batch = {}
+ batch_uc = {}
+
+ for key in keys:
+ if key == "txt":
+ batch["txt"] = (
+ np.repeat([value_dict["prompt"]], repeats=math.prod(N))
+ .reshape(N)
+ .tolist()
+ )
+ batch_uc["txt"] = (
+ np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
+ .reshape(N)
+ .tolist()
+ )
+ elif key == "original_size_as_tuple":
+ batch["original_size_as_tuple"] = (
+ torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
+ .to(device)
+ .repeat(*N, 1)
+ )
+ elif key == "crop_coords_top_left":
+ batch["crop_coords_top_left"] = (
+ torch.tensor(
+ [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
+ )
+ .to(device)
+ .repeat(*N, 1)
+ )
+ elif key == "aesthetic_score":
+ batch["aesthetic_score"] = (
+ torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
+ )
+ batch_uc["aesthetic_score"] = (
+ torch.tensor([value_dict["negative_aesthetic_score"]])
+ .to(device)
+ .repeat(*N, 1)
+ )
+
+ elif key == "target_size_as_tuple":
+ batch["target_size_as_tuple"] = (
+ torch.tensor([value_dict["target_height"], value_dict["target_width"]])
+ .to(device)
+ .repeat(*N, 1)
+ )
+ else:
+ batch[key] = value_dict[key]
+
+ for key in batch.keys():
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
+ batch_uc[key] = torch.clone(batch[key])
+ return batch, batch_uc
+
+
+def get_input_image_tensor(image: Image.Image, device="cuda"):
+ w, h = image.size
+ print(f"loaded input image of size ({w}, {h})")
+ width, height = map(
+ lambda x: x - x % 64, (w, h)
+ ) # resize to integer multiple of 64
+ image = image.resize((width, height))
+ image_array = np.array(image.convert("RGB"))
+ image_array = image_array[None].transpose(0, 3, 1, 2)
+ image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
+ return image_tensor.to(device)
+
+
+def do_img2img(
+ img,
+ model,
+ sampler,
+ value_dict,
+ num_samples,
+ force_uc_zero_embeddings=[],
+ additional_kwargs={},
+ offset_noise_level: float = 0.0,
+ return_latents=False,
+ skip_encode=False,
+ filter=None,
+ device="cuda",
+):
+ with torch.no_grad():
+ with autocast(device) as precision_scope:
+ with model.ema_scope():
+ batch, batch_uc = get_batch(
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
+ value_dict,
+ [num_samples],
+ )
+ c, uc = model.conditioner.get_unconditional_conditioning(
+ batch,
+ batch_uc=batch_uc,
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
+ )
+
+ for k in c:
+ c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
+
+ for k in additional_kwargs:
+ c[k] = uc[k] = additional_kwargs[k]
+ if skip_encode:
+ z = img
+ else:
+ z = model.encode_first_stage(img)
+ noise = torch.randn_like(z)
+ sigmas = sampler.discretization(sampler.num_steps)
+ sigma = sigmas[0].to(z.device)
+
+ if offset_noise_level > 0.0:
+ noise = noise + offset_noise_level * append_dims(
+ torch.randn(z.shape[0], device=z.device), z.ndim
+ )
+ noised_z = z + noise * append_dims(sigma, z.ndim)
+ noised_z = noised_z / torch.sqrt(
+ 1.0 + sigmas[0] ** 2.0
+ ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
+
+ def denoiser(x, sigma, c):
+ return model.denoiser(model.model, x, sigma, c)
+
+ samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
+ samples_x = model.decode_first_stage(samples_z)
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+ if filter is not None:
+ samples = filter(samples)
+
+ if return_latents:
+ return samples, samples_z
+ return samples
diff --git a/sgm/lr_scheduler.py b/sgm/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2f4d384c1fcaff0df13e0564450d3fa972ace42
--- /dev/null
+++ b/sgm/lr_scheduler.py
@@ -0,0 +1,135 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+
+ def __init__(
+ self,
+ warm_up_steps,
+ lr_min,
+ lr_max,
+ lr_start,
+ max_decay_steps,
+ verbosity_interval=0,
+ ):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.0
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n, **kwargs):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (
+ self.lr_max - self.lr_start
+ ) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (
+ self.lr_max_decay_steps - self.lr_warm_up_steps
+ )
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi)
+ )
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaWarmUpCosineScheduler2:
+ """
+ supports repeated iterations, configurable via lists
+ note: use with a base_lr of 1.0.
+ """
+
+ def __init__(
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
+ ):
+ assert (
+ len(warm_up_steps)
+ == len(f_min)
+ == len(f_max)
+ == len(f_start)
+ == len(cycle_lengths)
+ )
+ self.lr_warm_up_steps = warm_up_steps
+ self.f_start = f_start
+ self.f_min = f_min
+ self.f_max = f_max
+ self.cycle_lengths = cycle_lengths
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+ self.last_f = 0.0
+ self.verbosity_interval = verbosity_interval
+
+ def find_in_interval(self, n):
+ interval = 0
+ for cl in self.cum_cycles[1:]:
+ if n <= cl:
+ return interval
+ interval += 1
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}"
+ )
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
+ cycle
+ ] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ t = (n - self.lr_warm_up_steps[cycle]) / (
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
+ )
+ t = min(t, 1.0)
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+ 1 + np.cos(t * np.pi)
+ )
+ self.last_f = f
+ return f
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}"
+ )
+
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
+ cycle
+ ] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
+ self.cycle_lengths[cycle] - n
+ ) / (self.cycle_lengths[cycle])
+ self.last_f = f
+ return f
diff --git a/sgm/models/__init__.py b/sgm/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c410b3747afc208e4204c8f140170e0a7808eace
--- /dev/null
+++ b/sgm/models/__init__.py
@@ -0,0 +1,2 @@
+from .autoencoder import AutoencodingEngine
+from .diffusion import DiffusionEngine
diff --git a/sgm/models/__pycache__/__init__.cpython-311.pyc b/sgm/models/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..25cd0a406e03365fc11a32e2f1876ec2c4fdf78b
Binary files /dev/null and b/sgm/models/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/models/__pycache__/autoencoder.cpython-311.pyc b/sgm/models/__pycache__/autoencoder.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..650163f210cf872ab7fc18df8eff91fc39ae604f
Binary files /dev/null and b/sgm/models/__pycache__/autoencoder.cpython-311.pyc differ
diff --git a/sgm/models/__pycache__/diffusion.cpython-311.pyc b/sgm/models/__pycache__/diffusion.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f27540136e40a36aca31c003640502668d801490
Binary files /dev/null and b/sgm/models/__pycache__/diffusion.cpython-311.pyc differ
diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2949b91011a2be7a6b8ca17ce260812f20ce8b75
--- /dev/null
+++ b/sgm/models/autoencoder.py
@@ -0,0 +1,615 @@
+import logging
+import math
+import re
+from abc import abstractmethod
+from contextlib import contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+from einops import rearrange
+from packaging import version
+
+from ..modules.autoencoding.regularizers import AbstractRegularizer
+from ..modules.ema import LitEma
+from ..util import (default, get_nested_attribute, get_obj_from_str,
+ instantiate_from_config)
+
+logpy = logging.getLogger(__name__)
+
+
+class AbstractAutoencoder(pl.LightningModule):
+ """
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
+ unCLIP models, etc. Hence, it is fairly general, and specific features
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
+ """
+
+ def __init__(
+ self,
+ ema_decay: Union[None, float] = None,
+ monitor: Union[None, str] = None,
+ input_key: str = "jpg",
+ ):
+ super().__init__()
+
+ self.input_key = input_key
+ self.use_ema = ema_decay is not None
+ if monitor is not None:
+ self.monitor = monitor
+
+ if self.use_ema:
+ self.model_ema = LitEma(self, decay=ema_decay)
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ self.automatic_optimization = False
+
+ def apply_ckpt(self, ckpt: Union[None, str, dict]):
+ if ckpt is None:
+ return
+ if isinstance(ckpt, str):
+ ckpt = {
+ "target": "sgm.modules.checkpoint.CheckpointEngine",
+ "params": {"ckpt_path": ckpt},
+ }
+ engine = instantiate_from_config(ckpt)
+ engine(self)
+
+ @abstractmethod
+ def get_input(self, batch) -> Any:
+ raise NotImplementedError()
+
+ def on_train_batch_end(self, *args, **kwargs):
+ # for EMA computation
+ if self.use_ema:
+ self.model_ema(self)
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ logpy.info(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ logpy.info(f"{context}: Restored training weights")
+
+ @abstractmethod
+ def encode(self, *args, **kwargs) -> torch.Tensor:
+ raise NotImplementedError("encode()-method of abstract base class called")
+
+ @abstractmethod
+ def decode(self, *args, **kwargs) -> torch.Tensor:
+ raise NotImplementedError("decode()-method of abstract base class called")
+
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
+ return get_obj_from_str(cfg["target"])(
+ params, lr=lr, **cfg.get("params", dict())
+ )
+
+ def configure_optimizers(self) -> Any:
+ raise NotImplementedError()
+
+
+class AutoencodingEngine(AbstractAutoencoder):
+ """
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
+ (we also restore them explicitly as special cases for legacy reasons).
+ Regularizations such as KL or VQ are moved to the regularizer class.
+ """
+
+ def __init__(
+ self,
+ *args,
+ encoder_config: Dict,
+ decoder_config: Dict,
+ loss_config: Dict,
+ regularizer_config: Dict,
+ optimizer_config: Union[Dict, None] = None,
+ lr_g_factor: float = 1.0,
+ trainable_ae_params: Optional[List[List[str]]] = None,
+ ae_optimizer_args: Optional[List[dict]] = None,
+ trainable_disc_params: Optional[List[List[str]]] = None,
+ disc_optimizer_args: Optional[List[dict]] = None,
+ disc_start_iter: int = 0,
+ diff_boost_factor: float = 3.0,
+ ckpt_engine: Union[None, str, dict] = None,
+ ckpt_path: Optional[str] = None,
+ additional_decode_keys: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.automatic_optimization = False # pytorch lightning
+
+ self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
+ self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
+ self.loss: torch.nn.Module = instantiate_from_config(loss_config)
+ self.regularization: AbstractRegularizer = instantiate_from_config(
+ regularizer_config
+ )
+ self.optimizer_config = default(
+ optimizer_config, {"target": "torch.optim.Adam"}
+ )
+ self.diff_boost_factor = diff_boost_factor
+ self.disc_start_iter = disc_start_iter
+ self.lr_g_factor = lr_g_factor
+ self.trainable_ae_params = trainable_ae_params
+ if self.trainable_ae_params is not None:
+ self.ae_optimizer_args = default(
+ ae_optimizer_args,
+ [{} for _ in range(len(self.trainable_ae_params))],
+ )
+ assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
+ else:
+ self.ae_optimizer_args = [{}] # makes type consitent
+
+ self.trainable_disc_params = trainable_disc_params
+ if self.trainable_disc_params is not None:
+ self.disc_optimizer_args = default(
+ disc_optimizer_args,
+ [{} for _ in range(len(self.trainable_disc_params))],
+ )
+ assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
+ else:
+ self.disc_optimizer_args = [{}] # makes type consitent
+
+ if ckpt_path is not None:
+ assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
+ logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
+ self.additional_decode_keys = set(default(additional_decode_keys, []))
+
+ def get_input(self, batch: Dict) -> torch.Tensor:
+ # assuming unified data format, dataloader returns a dict.
+ # image tensors should be scaled to -1 ... 1 and in channels-first
+ # format (e.g., bchw instead if bhwc)
+ return batch[self.input_key]
+
+ def get_autoencoder_params(self) -> list:
+ params = []
+ if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
+ params += list(self.loss.get_trainable_autoencoder_parameters())
+ if hasattr(self.regularization, "get_trainable_parameters"):
+ params += list(self.regularization.get_trainable_parameters())
+ params = params + list(self.encoder.parameters())
+ params = params + list(self.decoder.parameters())
+ return params
+
+ def get_discriminator_params(self) -> list:
+ if hasattr(self.loss, "get_trainable_parameters"):
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
+ else:
+ params = []
+ return params
+
+ def get_last_layer(self):
+ return self.decoder.get_last_layer()
+
+ def encode(
+ self,
+ x: torch.Tensor,
+ return_reg_log: bool = False,
+ unregularized: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+ z = self.encoder(x)
+ if unregularized:
+ return z, dict()
+ z, reg_log = self.regularization(z)
+ if return_reg_log:
+ return z, reg_log
+ return z
+
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
+ x = self.decoder(z, **kwargs)
+ return x
+
+ def forward(
+ self, x: torch.Tensor, **additional_decode_kwargs
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+ z, reg_log = self.encode(x, return_reg_log=True)
+ dec = self.decode(z, **additional_decode_kwargs)
+ return z, dec, reg_log
+
+ def inner_training_step(
+ self, batch: dict, batch_idx: int, optimizer_idx: int = 0
+ ) -> torch.Tensor:
+ x = self.get_input(batch)
+ additional_decode_kwargs = {
+ key: batch[key] for key in self.additional_decode_keys.intersection(batch)
+ }
+ z, xrec, regularization_log = self(x, **additional_decode_kwargs)
+ if hasattr(self.loss, "forward_keys"):
+ extra_info = {
+ "z": z,
+ "optimizer_idx": optimizer_idx,
+ "global_step": self.global_step,
+ "last_layer": self.get_last_layer(),
+ "split": "train",
+ "regularization_log": regularization_log,
+ "autoencoder": self,
+ }
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
+ else:
+ extra_info = dict()
+
+ if optimizer_idx == 0:
+ # autoencode
+ out_loss = self.loss(x, xrec, **extra_info)
+ if isinstance(out_loss, tuple):
+ aeloss, log_dict_ae = out_loss
+ else:
+ # simple loss function
+ aeloss = out_loss
+ log_dict_ae = {"train/loss/rec": aeloss.detach()}
+
+ self.log_dict(
+ log_dict_ae,
+ prog_bar=False,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ sync_dist=False,
+ )
+ self.log(
+ "loss",
+ aeloss.mean().detach(),
+ prog_bar=True,
+ logger=False,
+ on_epoch=False,
+ on_step=True,
+ )
+ return aeloss
+ elif optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
+ # -> discriminator always needs to return a tuple
+ self.log_dict(
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
+ )
+ return discloss
+ else:
+ raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
+
+ def training_step(self, batch: dict, batch_idx: int):
+ opts = self.optimizers()
+ if not isinstance(opts, list):
+ # Non-adversarial case
+ opts = [opts]
+ optimizer_idx = batch_idx % len(opts)
+ if self.global_step < self.disc_start_iter:
+ optimizer_idx = 0
+ opt = opts[optimizer_idx]
+ opt.zero_grad()
+ with opt.toggle_model():
+ loss = self.inner_training_step(
+ batch, batch_idx, optimizer_idx=optimizer_idx
+ )
+ self.manual_backward(loss)
+ opt.step()
+
+ def validation_step(self, batch: dict, batch_idx: int) -> Dict:
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+ log_dict.update(log_dict_ema)
+ return log_dict
+
+ def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
+ x = self.get_input(batch)
+
+ z, xrec, regularization_log = self(x)
+ if hasattr(self.loss, "forward_keys"):
+ extra_info = {
+ "z": z,
+ "optimizer_idx": 0,
+ "global_step": self.global_step,
+ "last_layer": self.get_last_layer(),
+ "split": "val" + postfix,
+ "regularization_log": regularization_log,
+ "autoencoder": self,
+ }
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
+ else:
+ extra_info = dict()
+ out_loss = self.loss(x, xrec, **extra_info)
+ if isinstance(out_loss, tuple):
+ aeloss, log_dict_ae = out_loss
+ else:
+ # simple loss function
+ aeloss = out_loss
+ log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
+ full_log_dict = log_dict_ae
+
+ if "optimizer_idx" in extra_info:
+ extra_info["optimizer_idx"] = 1
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
+ full_log_dict.update(log_dict_disc)
+ self.log(
+ f"val{postfix}/loss/rec",
+ log_dict_ae[f"val{postfix}/loss/rec"],
+ sync_dist=True,
+ )
+ self.log_dict(full_log_dict, sync_dist=True)
+ return full_log_dict
+
+ def get_param_groups(
+ self, parameter_names: List[List[str]], optimizer_args: List[dict]
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ groups = []
+ num_params = 0
+ for names, args in zip(parameter_names, optimizer_args):
+ params = []
+ for pattern_ in names:
+ pattern_params = []
+ pattern = re.compile(pattern_)
+ for p_name, param in self.named_parameters():
+ if re.match(pattern, p_name):
+ pattern_params.append(param)
+ num_params += param.numel()
+ if len(pattern_params) == 0:
+ logpy.warn(f"Did not find parameters for pattern {pattern_}")
+ params.extend(pattern_params)
+ groups.append({"params": params, **args})
+ return groups, num_params
+
+ def configure_optimizers(self) -> List[torch.optim.Optimizer]:
+ if self.trainable_ae_params is None:
+ ae_params = self.get_autoencoder_params()
+ else:
+ ae_params, num_ae_params = self.get_param_groups(
+ self.trainable_ae_params, self.ae_optimizer_args
+ )
+ logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
+ if self.trainable_disc_params is None:
+ disc_params = self.get_discriminator_params()
+ else:
+ disc_params, num_disc_params = self.get_param_groups(
+ self.trainable_disc_params, self.disc_optimizer_args
+ )
+ logpy.info(
+ f"Number of trainable discriminator parameters: {num_disc_params:,}"
+ )
+ opt_ae = self.instantiate_optimizer_from_config(
+ ae_params,
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
+ self.optimizer_config,
+ )
+ opts = [opt_ae]
+ if len(disc_params) > 0:
+ opt_disc = self.instantiate_optimizer_from_config(
+ disc_params, self.learning_rate, self.optimizer_config
+ )
+ opts.append(opt_disc)
+
+ return opts
+
+ @torch.no_grad()
+ def log_images(
+ self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
+ ) -> dict:
+ log = dict()
+ additional_decode_kwargs = {}
+ x = self.get_input(batch)
+ additional_decode_kwargs.update(
+ {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
+ )
+
+ _, xrec, _ = self(x, **additional_decode_kwargs)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
+ diff.clamp_(0, 1.0)
+ log["diff"] = 2.0 * diff - 1.0
+ # diff_boost shows location of small errors, by boosting their
+ # brightness.
+ log["diff_boost"] = (
+ 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
+ )
+ if hasattr(self.loss, "log_images"):
+ log.update(self.loss.log_images(x, xrec))
+ with self.ema_scope():
+ _, xrec_ema, _ = self(x, **additional_decode_kwargs)
+ log["reconstructions_ema"] = xrec_ema
+ diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
+ diff_ema.clamp_(0, 1.0)
+ log["diff_ema"] = 2.0 * diff_ema - 1.0
+ log["diff_boost_ema"] = (
+ 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
+ )
+ if additional_log_kwargs:
+ additional_decode_kwargs.update(additional_log_kwargs)
+ _, xrec_add, _ = self(x, **additional_decode_kwargs)
+ log_str = "reconstructions-" + "-".join(
+ [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
+ )
+ log[log_str] = xrec_add
+ return log
+
+
+class AutoencodingEngineLegacy(AutoencodingEngine):
+ def __init__(self, embed_dim: int, **kwargs):
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
+ ddconfig = kwargs.pop("ddconfig")
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ckpt_engine = kwargs.pop("ckpt_engine", None)
+ super().__init__(
+ encoder_config={
+ "target": "sgm.modules.diffusionmodules.model.Encoder",
+ "params": ddconfig,
+ },
+ decoder_config={
+ "target": "sgm.modules.diffusionmodules.model.Decoder",
+ "params": ddconfig,
+ },
+ **kwargs,
+ )
+ self.quant_conv = torch.nn.Conv2d(
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
+ (1 + ddconfig["double_z"]) * embed_dim,
+ 1,
+ )
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
+
+ def get_autoencoder_params(self) -> list:
+ params = super().get_autoencoder_params()
+ return params
+
+ def encode(
+ self, x: torch.Tensor, return_reg_log: bool = False
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+ if self.max_batch_size is None:
+ z = self.encoder(x)
+ z = self.quant_conv(z)
+ else:
+ N = x.shape[0]
+ bs = self.max_batch_size
+ n_batches = int(math.ceil(N / bs))
+ z = list()
+ for i_batch in range(n_batches):
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
+ z_batch = self.quant_conv(z_batch)
+ z.append(z_batch)
+ z = torch.cat(z, 0)
+
+ z, reg_log = self.regularization(z)
+ if return_reg_log:
+ return z, reg_log
+ return z
+
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
+ if self.max_batch_size is None:
+ dec = self.post_quant_conv(z)
+ dec = self.decoder(dec, **decoder_kwargs)
+ else:
+ N = z.shape[0]
+ bs = self.max_batch_size
+ n_batches = int(math.ceil(N / bs))
+ dec = list()
+ for i_batch in range(n_batches):
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
+ dec.append(dec_batch)
+ dec = torch.cat(dec, 0)
+
+ return dec
+
+
+class AutoencoderKL(AutoencodingEngineLegacy):
+ def __init__(self, **kwargs):
+ if "lossconfig" in kwargs:
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
+ super().__init__(
+ regularizer_config={
+ "target": (
+ "sgm.modules.autoencoding.regularizers"
+ ".DiagonalGaussianRegularizer"
+ )
+ },
+ **kwargs,
+ )
+
+
+class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
+ def __init__(
+ self,
+ embed_dim: int,
+ n_embed: int,
+ sane_index_shape: bool = False,
+ **kwargs,
+ ):
+ if "lossconfig" in kwargs:
+ logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
+ super().__init__(
+ regularizer_config={
+ "target": (
+ "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
+ ),
+ "params": {
+ "n_e": n_embed,
+ "e_dim": embed_dim,
+ "sane_index_shape": sane_index_shape,
+ },
+ },
+ **kwargs,
+ )
+
+
+class IdentityFirstStage(AbstractAutoencoder):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def get_input(self, x: Any) -> Any:
+ return x
+
+ def encode(self, x: Any, *args, **kwargs) -> Any:
+ return x
+
+ def decode(self, x: Any, *args, **kwargs) -> Any:
+ return x
+
+
+class AEIntegerWrapper(nn.Module):
+ def __init__(
+ self,
+ model: nn.Module,
+ shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
+ regularization_key: str = "regularization",
+ encoder_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ super().__init__()
+ self.model = model
+ assert hasattr(model, "encode") and hasattr(
+ model, "decode"
+ ), "Need AE interface"
+ self.regularization = get_nested_attribute(model, regularization_key)
+ self.shape = shape
+ self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
+
+ def encode(self, x) -> torch.Tensor:
+ assert (
+ not self.training
+ ), f"{self.__class__.__name__} only supports inference currently"
+ _, log = self.model.encode(x, **self.encoder_kwargs)
+ assert isinstance(log, dict)
+ inds = log["min_encoding_indices"]
+ return rearrange(inds, "b ... -> b (...)")
+
+ def decode(
+ self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
+ ) -> torch.Tensor:
+ # expect inds shape (b, s) with s = h*w
+ shape = default(shape, self.shape) # Optional[(h, w)]
+ if shape is not None:
+ assert len(shape) == 2, f"Unhandeled shape {shape}"
+ inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
+ h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
+ h = rearrange(h, "b h w c -> b c h w")
+ return self.model.decode(h)
+
+
+class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
+ def __init__(self, **kwargs):
+ if "lossconfig" in kwargs:
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
+ super().__init__(
+ regularizer_config={
+ "target": (
+ "sgm.modules.autoencoding.regularizers"
+ ".DiagonalGaussianRegularizer"
+ ),
+ "params": {"sample": False},
+ },
+ **kwargs,
+ )
diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aa46e869478abb981595cc0e81ee3c2ecb37793
--- /dev/null
+++ b/sgm/models/diffusion.py
@@ -0,0 +1,747 @@
+import os
+import math
+from contextlib import contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Union
+import re
+import pytorch_lightning as pl
+import torch
+from omegaconf import ListConfig, OmegaConf
+from safetensors.torch import load_file as load_safetensors
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange
+from diffusers.models.attention_processor import IPAdapterAttnProcessor2_0
+
+from ..modules import UNCONDITIONAL_CONFIG
+from ..modules.autoencoding.temporal_ae import VideoDecoder
+from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
+from ..modules.ema import LitEma
+from ..util import (
+ default,
+ disabled_train,
+ get_obj_from_str,
+ instantiate_from_config,
+ log_txt_as_img,
+)
+
+
+class DiffusionEngine(pl.LightningModule):
+ def __init__(
+ self,
+ network_config,
+ denoiser_config,
+ first_stage_config,
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ network_wrapper: Union[None, str, Dict, ListConfig, OmegaConf] = None,
+ ckpt_path: Union[None, str] = None,
+ remove_keys_from_weights: Union[None, List, Tuple] = None,
+ pattern_to_remove: Union[None, str] = None,
+ remove_keys_from_unet_weights: Union[None, List, Tuple] = None,
+ use_ema: bool = False,
+ ema_decay_rate: float = 0.9999,
+ scale_factor: float = 1.0,
+ disable_first_stage_autocast=False,
+ input_key: str = "jpg",
+ log_keys: Union[List, None] = None,
+ no_log_keys: Union[List, None] = None,
+ no_cond_log: bool = False,
+ compile_model: bool = False,
+ en_and_decode_n_samples_a_time: Optional[int] = None,
+ only_train_ipadapter: Optional[bool] = False,
+ to_unfreeze: Optional[List[str]] = [],
+ to_freeze: Optional[List[str]] = [],
+ separate_unet_ckpt: Optional[str] = None,
+ use_thunder: Optional[bool] = False,
+ is_dubbing: Optional[bool] = False,
+ bad_model_path: Optional[str] = None,
+ bad_model_config: Optional[Dict] = None,
+ ):
+ super().__init__()
+
+ # self.automatic_optimization = False
+ self.log_keys = log_keys
+ self.no_log_keys = no_log_keys
+ self.input_key = input_key
+ self.is_dubbing = is_dubbing
+ self.optimizer_config = default(
+ optimizer_config, {"target": "torch.optim.AdamW"}
+ )
+ self.model = self.initialize_network(
+ network_config, network_wrapper, compile_model=compile_model
+ )
+
+ self.denoiser = instantiate_from_config(denoiser_config)
+
+ self.sampler = (
+ instantiate_from_config(sampler_config)
+ if sampler_config is not None
+ else None
+ )
+ self.is_guided = True
+ if (
+ self.sampler
+ and "IdentityGuider" in sampler_config["params"]["guider_config"]["target"]
+ ):
+ self.is_guided = False
+ if self.sampler is not None:
+ config_guider = sampler_config["params"]["guider_config"]
+ sampler_config["params"]["guider_config"] = None
+ self.sampler_no_guidance = instantiate_from_config(sampler_config)
+ sampler_config["params"]["guider_config"] = config_guider
+ self.conditioner = instantiate_from_config(
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
+ )
+ self.scheduler_config = scheduler_config
+ self._init_first_stage(first_stage_config)
+
+ self.loss_fn = (
+ instantiate_from_config(loss_fn_config)
+ if loss_fn_config is not None
+ else None
+ )
+
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.scale_factor = scale_factor
+ self.disable_first_stage_autocast = disable_first_stage_autocast
+ self.no_cond_log = no_cond_log
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(
+ ckpt_path,
+ remove_keys_from_weights=remove_keys_from_weights,
+ pattern_to_remove=pattern_to_remove,
+ )
+ if separate_unet_ckpt is not None:
+ sd = torch.load(separate_unet_ckpt)["state_dict"]
+ if remove_keys_from_unet_weights is not None:
+ for k in list(sd.keys()):
+ for remove_key in remove_keys_from_unet_weights:
+ if remove_key in k:
+ del sd[k]
+ self.model.diffusion_model.load_state_dict(sd, strict=False)
+
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
+ print(
+ "Using",
+ self.en_and_decode_n_samples_a_time,
+ "samples at a time for encoding and decoding",
+ )
+
+ if to_freeze:
+ for name, p in self.model.diffusion_model.named_parameters():
+ for layer in to_freeze:
+ if layer[0] == "!":
+ if layer[1:] not in name:
+ # print("Freezing", name)
+ p.requires_grad = False
+ else:
+ if layer in name:
+ # print("Freezing", name)
+ p.requires_grad = False
+ # if "time_" in name:
+ # print("Freezing", name)
+ # p.requires_grad = False
+
+ if only_train_ipadapter:
+ # Freeze the model
+ for p in self.model.parameters():
+ p.requires_grad = False
+ # Unfreeze the adapter projection layer
+ for p in self.model.diffusion_model.encoder_hid_proj.parameters():
+ p.requires_grad = True
+ # Unfreeze the cross-attention layer
+ for att_layer in self.model.diffusion_model.attn_processors.values():
+ if isinstance(att_layer, IPAdapterAttnProcessor2_0):
+ for p in att_layer.parameters():
+ p.requires_grad = True
+
+ # for name, p in self.named_parameters():
+ # if p.requires_grad:
+ # print(name)
+
+ if to_unfreeze:
+ for name in to_unfreeze:
+ for p in getattr(self.model.diffusion_model, name).parameters():
+ p.requires_grad = True
+
+ if use_thunder:
+ import thunder
+
+ self.model.diffusion_model = thunder.jit(self.model.diffusion_model)
+
+ if "Karras" in denoiser_config.target:
+ assert bad_model_path is not None, (
+ "bad_model_path must be provided for KarrasGuidanceDenoiser"
+ )
+ karras_config = default(bad_model_config, network_config)
+ bad_model = self.initialize_network(
+ karras_config, network_wrapper, compile_model=compile_model
+ )
+ state_dict = self.load_bad_model_weights(bad_model_path)
+ bad_model.load_state_dict(state_dict)
+ self.denoiser.set_bad_network(bad_model)
+
+ def load_bad_model_weights(self, path: str) -> None:
+ print(f"Restoring bad model from {path}")
+ state_dict = torch.load(path, map_location="cpu")
+ new_dict = {}
+ for k, v in state_dict["module"].items():
+ if "learned_mask" in k:
+ new_dict[k.replace("_forward_module.", "").replace("model.", "")] = v
+ if "diffusion_model" in k:
+ new_dict["diffusion_model" + k.split("diffusion_model")[1]] = v
+ return new_dict
+
+ def initialize_network(self, network_config, network_wrapper, compile_model=False):
+ model = instantiate_from_config(network_config)
+ if isinstance(network_wrapper, str) or network_wrapper is None:
+ model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
+ model, compile_model=compile_model
+ )
+ else:
+ target = network_wrapper["target"]
+ params = network_wrapper.get("params", dict())
+ model = get_obj_from_str(target)(
+ model, compile_model=compile_model, **params
+ )
+ return model
+
+ def init_from_ckpt(
+ self,
+ path: str,
+ remove_keys_from_weights: Optional[Union[List, Tuple]] = None,
+ pattern_to_remove: str = None,
+ ) -> None:
+ print(f"Restoring from {path}")
+ if path.endswith("ckpt"):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ elif path.endswith("pt"):
+ sd = torch.load(path, map_location="cpu")["module"]
+ # Remove leading _forward_module from keys
+ sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
+ elif path.endswith("bin"):
+ sd = torch.load(path, map_location="cpu")
+ # Remove leading _forward_module from keys
+ sd = {k.replace("_forward_module.", ""): v for k, v in sd.items()}
+ elif path.endswith("safetensors"):
+ sd = load_safetensors(path)
+ else:
+ raise NotImplementedError
+
+ print(f"Loaded state dict from {path} with {len(sd)} keys")
+
+ # if remove_keys_from_weights is not None:
+ # for k in list(sd.keys()):
+ # for remove_key in remove_keys_from_weights:
+ # if remove_key in k:
+ # del sd[k]
+ if pattern_to_remove is not None or remove_keys_from_weights is not None:
+ sd = self.remove_mismatched_keys(
+ sd, pattern_to_remove, remove_keys_from_weights
+ )
+
+ missing, unexpected = self.load_state_dict(sd, strict=False)
+ print(
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+ )
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def remove_mismatched_keys(self, state_dict, pattern=None, additional_keys=None):
+ """Remove keys from the state dictionary based on a pattern and a list of additional specific keys."""
+ # Find keys that match the pattern
+ if pattern is not None:
+ mismatched_keys = [key for key in state_dict if re.search(pattern, key)]
+ else:
+ mismatched_keys = []
+
+ print(f"Removing {len(mismatched_keys)} keys based on pattern {pattern}")
+ print(mismatched_keys)
+
+ # Add specific keys to be removed
+ if additional_keys:
+ mismatched_keys.extend(
+ [key for key in additional_keys if key in state_dict]
+ )
+
+ # Remove all identified keys
+ for key in mismatched_keys:
+ if key in state_dict:
+ del state_dict[key]
+ return state_dict
+
+ def _init_first_stage(self, config):
+ model = instantiate_from_config(config).eval()
+ model.train = disabled_train
+ for param in model.parameters():
+ param.requires_grad = False
+ self.first_stage_model = model
+ if self.input_key == "latents":
+ # Remove encoder to save memory
+ self.first_stage_model.encoder = None
+ torch.cuda.empty_cache()
+
+ def get_input(self, batch):
+ # assuming unified data format, dataloader returns a dict.
+ # image tensors should be scaled to -1 ... 1 and in bchw format
+ return batch[self.input_key]
+
+ @torch.no_grad()
+ def decode_first_stage(self, z):
+ is_video = False
+ if len(z.shape) == 5:
+ is_video = True
+ T = z.shape[2]
+ z = rearrange(z, "b c t h w -> (b t) c h w")
+
+ z = 1.0 / self.scale_factor * z
+ n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
+
+ n_rounds = math.ceil(z.shape[0] / n_samples)
+ all_out = []
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ for n in range(n_rounds):
+ if isinstance(self.first_stage_model.decoder, VideoDecoder):
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
+ else:
+ kwargs = {}
+ out = self.first_stage_model.decode(
+ z[n * n_samples : (n + 1) * n_samples], **kwargs
+ )
+ all_out.append(out)
+ out = torch.cat(all_out, dim=0)
+ if is_video:
+ out = rearrange(out, "(b t) c h w -> b c t h w", t=T)
+ torch.cuda.empty_cache()
+ return out
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ is_video = False
+ if len(x.shape) == 5:
+ is_video = True
+ T = x.shape[2]
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
+ n_rounds = math.ceil(x.shape[0] / n_samples)
+ all_out = []
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ for n in range(n_rounds):
+ out = self.first_stage_model.encode(
+ x[n * n_samples : (n + 1) * n_samples]
+ )
+ all_out.append(out)
+ z = torch.cat(all_out, dim=0)
+ z = self.scale_factor * z
+ if is_video:
+ z = rearrange(z, "(b t) c h w -> b c t h w", t=T)
+ return z
+
+ def forward(self, x, batch):
+ loss_dict = self.loss_fn(
+ self.model,
+ self.denoiser,
+ self.conditioner,
+ x,
+ batch,
+ self.first_stage_model,
+ )
+ # loss_mean = loss.mean()
+ for k in loss_dict:
+ loss_dict[k] = loss_dict[k].mean()
+ # loss_dict = {"loss": loss_mean}
+ return loss_dict["loss"], loss_dict
+
+ def shared_step(self, batch: Dict) -> Any:
+ x = self.get_input(batch)
+ if self.input_key != "latents":
+ x = self.encode_first_stage(x)
+ batch["global_step"] = self.global_step
+ loss, loss_dict = self(x, batch)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ loss, loss_dict = self.shared_step(batch)
+ # debugging_message = "Training step"
+ # print(f"RANK - {self.trainer.global_rank}: {debugging_message}")
+
+ self.log_dict(
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
+ )
+
+ self.log(
+ "global_step",
+ self.global_step,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=False,
+ )
+
+ # debugging_message = "Training step - log"
+ # print(f"RANK - {self.trainer.global_rank}: {debugging_message}")
+
+ if self.scheduler_config is not None:
+ lr = self.optimizers().param_groups[0]["lr"]
+ self.log(
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
+ )
+
+ # # to prevent other processes from moving forward until all processes are in sync
+ # self.trainer.strategy.barrier()
+
+ return loss
+
+ # def validation_step(self, batch, batch_idx):
+ # # loss, loss_dict = self.shared_step(batch)
+ # # self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+ # self.log(
+ # "global_step",
+ # self.global_step,
+ # prog_bar=True,
+ # logger=True,
+ # on_step=True,
+ # on_epoch=False,
+ # )
+ # return 0
+
+ # def on_train_epoch_start(self, *args, **kwargs):
+ # print(f"RANK - {self.trainer.global_rank}: on_train_epoch_start")
+
+ def on_train_start(self, *args, **kwargs):
+ # os.environ["CUDA_VISIBLE_DEVICES"] = str(self.trainer.global_rank)
+ # torch.cuda.set_device(self.trainer.global_rank)
+ # torch.cuda.empty_cache()
+ if self.sampler is None or self.loss_fn is None:
+ raise ValueError("Sampler and loss function need to be set for training.")
+
+ # def on_before_batch_transfer(self, batch, dataloader_idx):
+ # print(f"RANK - {self.trainer.global_rank}: on_before_batch_transfer - {dataloader_idx}")
+ # return batch
+
+ # def on_after_batch_transfer(self, batch, dataloader_idx):
+ # print(f"RANK - {self.trainer.global_rank}: on_after_batch_transfer - {dataloader_idx}")
+ # return batch
+
+ def on_train_batch_end(self, *args, **kwargs):
+ # print(f"RANK - {self.trainer.global_rank}: on_train_batch_end")
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
+ return get_obj_from_str(cfg["target"])(
+ params, lr=lr, **cfg.get("params", dict())
+ )
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ for embedder in self.conditioner.embedders:
+ if embedder.is_trainable:
+ params = params + list(embedder.parameters())
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
+ if self.scheduler_config is not None:
+ scheduler = instantiate_from_config(self.scheduler_config)
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
+ "interval": "step",
+ "frequency": 1,
+ }
+ ]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def sample(
+ self,
+ cond: Dict,
+ uc: Union[Dict, None] = None,
+ batch_size: int = 16,
+ shape: Union[None, Tuple, List] = None,
+ **kwargs,
+ ):
+ randn = torch.randn(batch_size, *shape).to(self.device)
+
+ denoiser = lambda input, sigma, c: self.denoiser(
+ self.model, input, sigma, c, **kwargs
+ )
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
+
+ return samples
+
+ @torch.no_grad()
+ def sample_no_guider(
+ self,
+ cond: Dict,
+ uc: Union[Dict, None] = None,
+ batch_size: int = 16,
+ shape: Union[None, Tuple, List] = None,
+ **kwargs,
+ ):
+ randn = torch.randn(batch_size, *shape).to(self.device)
+
+ denoiser = lambda input, sigma, c: self.denoiser(
+ self.model, input, sigma, c, **kwargs
+ )
+ samples = self.sampler_no_guidance(denoiser, randn, cond, uc=uc)
+
+ return samples
+
+ @torch.no_grad()
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
+ """
+ Defines heuristics to log different conditionings.
+ These can be lists of strings (text-to-image), tensors, ints, ...
+ """
+ image_h, image_w = batch[self.input_key].shape[-2:]
+ log = dict()
+
+ for embedder in self.conditioner.embedders:
+ if (
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
+ ) and not self.no_cond_log:
+ if embedder.input_key in self.no_log_keys:
+ continue
+ x = batch[embedder.input_key][:n]
+ if isinstance(x, torch.Tensor):
+ if x.dim() == 1:
+ # class-conditional, convert integer to string
+ x = [str(x[i].item()) for i in range(x.shape[0])]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
+ elif x.dim() == 2:
+ # size and crop cond and the like
+ x = [
+ "x".join([str(xx) for xx in x[i].tolist()])
+ for i in range(x.shape[0])
+ ]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ elif x.dim() == 4: # already an image
+ xc = x
+ elif x.dim() == 5:
+ xc = torch.cat([x[:, :, i] for i in range(x.shape[2])], dim=-1)
+ else:
+ print(x.shape, embedder.input_key)
+ raise NotImplementedError()
+ elif isinstance(x, (List, ListConfig)):
+ if isinstance(x[0], str):
+ # strings
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ else:
+ raise NotImplementedError()
+ else:
+ raise NotImplementedError()
+ log[embedder.input_key] = xc
+ return log
+
+ @torch.no_grad()
+ def log_images(
+ self,
+ batch: Dict,
+ N: int = 8,
+ sample: bool = True,
+ ucg_keys: List[str] = None,
+ **kwargs,
+ ) -> Dict:
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
+ if ucg_keys:
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
+ )
+ else:
+ ucg_keys = conditioner_input_keys
+ log = dict()
+
+ x = self.get_input(batch)
+
+ c, uc = self.conditioner.get_unconditional_conditioning(
+ batch,
+ force_uc_zero_embeddings=ucg_keys
+ if len(self.conditioner.embedders) > 0
+ else [],
+ )
+
+ sampling_kwargs = {}
+
+ N = min(x.shape[0], N)
+ x = x.to(self.device)[:N]
+ if self.input_key != "latents":
+ log["inputs"] = x
+ z = self.encode_first_stage(x)
+ else:
+ z = x
+ log["reconstructions"] = self.decode_first_stage(z)
+ log.update(self.log_conditionings(batch, N))
+
+ for k in c:
+ if isinstance(c[k], torch.Tensor):
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
+
+ if sample:
+ with self.ema_scope("Plotting"):
+ samples = self.sample(
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
+ )
+ samples = self.decode_first_stage(samples)
+
+ log["samples"] = samples
+
+ with self.ema_scope("Plotting"):
+ samples = self.sample_no_guider(
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
+ )
+ samples = self.decode_first_stage(samples)
+
+ log["samples_no_guidance"] = samples
+ return log
+
+ @torch.no_grad()
+ def log_videos(
+ self,
+ batch: Dict,
+ N: int = 8,
+ sample: bool = True,
+ ucg_keys: List[str] = None,
+ **kwargs,
+ ) -> Dict:
+ # conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
+ # if ucg_keys:
+ # assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
+ # "Each defined ucg key for sampling must be in the provided conditioner input keys,"
+ # f"but we have {ucg_keys} vs. {conditioner_input_keys}"
+ # )
+ # else:
+ # ucg_keys = conditioner_input_keys
+ log = dict()
+ batch_uc = {}
+
+ x = self.get_input(batch)
+ num_frames = x.shape[2] # assuming bcthw format
+
+ for key in batch.keys():
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
+ batch_uc[key] = torch.clone(batch[key])
+
+ c, uc = self.conditioner.get_unconditional_conditioning(
+ batch,
+ batch_uc=batch_uc,
+ force_uc_zero_embeddings=ucg_keys
+ if ucg_keys is not None
+ else [
+ "cond_frames",
+ "cond_frames_without_noise",
+ ],
+ )
+
+ # for k in ["crossattn", "concat"]:
+ # uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
+ # uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
+ # c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
+ # c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
+
+ sampling_kwargs = {}
+
+ N = min(x.shape[0], N)
+ x = x.to(self.device)[:N]
+
+ if self.input_key != "latents":
+ log["inputs"] = x
+ z = self.encode_first_stage(x)
+ else:
+ z = x
+ log["reconstructions"] = self.decode_first_stage(z)
+ log.update(self.log_conditionings(batch, N))
+
+ if c.get("masks", None) is not None:
+ # Create a mask reconstruction
+ masks = 1 - c["masks"]
+ t = masks.shape[2]
+ masks = rearrange(masks, "b c t h w -> (b t) c h w")
+ target_size = (
+ log["reconstructions"].shape[-2],
+ log["reconstructions"].shape[-1],
+ )
+ masks = torch.nn.functional.interpolate(
+ masks, size=target_size, mode="nearest"
+ )
+ masks = rearrange(masks, "(b t) c h w -> b c t h w", t=t)
+ log["mask_reconstructions"] = log["reconstructions"] * masks
+
+ for k in c:
+ if isinstance(c[k], torch.Tensor):
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
+ elif isinstance(c[k], list):
+ for i in range(len(c[k])):
+ c[k][i], uc[k][i] = map(
+ lambda y: y[k][i][:N].to(self.device), (c, uc)
+ )
+
+ if sample:
+ n = 2 if self.is_guided else 1
+ # if num_frames == 1:
+ # sampling_kwargs["image_only_indicator"] = torch.ones(n, num_frames).to(self.device)
+ # else:
+ sampling_kwargs["image_only_indicator"] = torch.zeros(n, num_frames).to(
+ self.device
+ )
+ sampling_kwargs["num_video_frames"] = batch["num_video_frames"]
+
+ with self.ema_scope("Plotting"):
+ samples = self.sample(
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
+ )
+ samples = self.decode_first_stage(samples)
+ if self.is_dubbing:
+ samples[:, :, :, : samples.shape[-2] // 2] = log["reconstructions"][
+ :, :, :, : samples.shape[-2] // 2
+ ]
+ log["samples"] = samples
+
+ # Without guidance
+ # if num_frames == 1:
+ # sampling_kwargs["image_only_indicator"] = torch.ones(1, num_frames).to(self.device)
+ # else:
+ sampling_kwargs["image_only_indicator"] = torch.zeros(1, num_frames).to(
+ self.device
+ )
+ sampling_kwargs["num_video_frames"] = batch["num_video_frames"]
+
+ with self.ema_scope("Plotting"):
+ samples = self.sample_no_guider(
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
+ )
+ samples = self.decode_first_stage(samples)
+ if self.is_dubbing:
+ samples[:, :, :, : samples.shape[-2] // 2] = log["reconstructions"][
+ :, :, :, : samples.shape[-2] // 2
+ ]
+ log["samples_no_guidance"] = samples
+
+ torch.cuda.empty_cache()
+ return log
diff --git a/sgm/modules/__init__.py b/sgm/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0db1d7716a6e48f77b86a4b59c9289d6fb76b50b
--- /dev/null
+++ b/sgm/modules/__init__.py
@@ -0,0 +1,6 @@
+from .encoders.modules import GeneralConditioner
+
+UNCONDITIONAL_CONFIG = {
+ "target": "sgm.modules.GeneralConditioner",
+ "params": {"emb_models": []},
+}
diff --git a/sgm/modules/__pycache__/__init__.cpython-311.pyc b/sgm/modules/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79af416b843aa53edaa199a1e2ef81585e51c26b
Binary files /dev/null and b/sgm/modules/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/__pycache__/attention.cpython-311.pyc b/sgm/modules/__pycache__/attention.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..086eb9d87b7bc5cf3b77cea581f26027c1a9a8ed
Binary files /dev/null and b/sgm/modules/__pycache__/attention.cpython-311.pyc differ
diff --git a/sgm/modules/__pycache__/ema.cpython-311.pyc b/sgm/modules/__pycache__/ema.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e20423c6026897afd7c0d41561a356ab85a093ae
Binary files /dev/null and b/sgm/modules/__pycache__/ema.cpython-311.pyc differ
diff --git a/sgm/modules/__pycache__/video_attention.cpython-311.pyc b/sgm/modules/__pycache__/video_attention.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bdf6e3c5fb08fd74a4be790d112a664647329e59
Binary files /dev/null and b/sgm/modules/__pycache__/video_attention.cpython-311.pyc differ
diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f656fb873d5ede32022a6aa87b9cbfe2e802a71
--- /dev/null
+++ b/sgm/modules/attention.py
@@ -0,0 +1,889 @@
+import logging
+import math
+from inspect import isfunction
+from typing import Any, Optional
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from packaging import version
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+logpy = logging.getLogger(__name__)
+
+if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ SDP_IS_AVAILABLE = True
+ from torch.backends.cuda import SDPBackend, sdp_kernel
+
+ BACKEND_MAP = {
+ SDPBackend.MATH: {
+ "enable_math": True,
+ "enable_flash": False,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.FLASH_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": True,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.EFFICIENT_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": False,
+ "enable_mem_efficient": True,
+ },
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
+ }
+else:
+ from contextlib import nullcontext
+
+ SDP_IS_AVAILABLE = False
+ sdp_kernel = nullcontext
+ BACKEND_MAP = {}
+ logpy.warn(
+ f"No SDP backend available, likely because you are running in pytorch "
+ f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
+ f"You might want to consider upgrading."
+ )
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILABLE = True
+except:
+ XFORMERS_IS_AVAILABLE = False
+ logpy.warn("no module 'xformers'. Processing without...")
+
+# from .diffusionmodules.util import mixed_checkpoint as checkpoint
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+ )
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
+ )
+ return self.to_out(out)
+
+
+class SelfAttention(nn.Module):
+ ATTENTION_MODES = ("xformers", "torch", "math")
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_scale: Optional[float] = None,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ attn_mode: str = "xformers",
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ assert attn_mode in self.ATTENTION_MODES
+ self.attn_mode = attn_mode
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, L, C = x.shape
+
+ qkv = self.qkv(x)
+ if self.attn_mode == "torch":
+ qkv = rearrange(
+ qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
+ ).float()
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ x = rearrange(x, "B H L D -> B L (H D)")
+ elif self.attn_mode == "xformers":
+ qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
+ x = xformers.ops.memory_efficient_attention(q, k, v)
+ x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
+ elif self.attn_mode == "math":
+ qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
+ else:
+ raise NotImplemented
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b (h w) c")
+ k = rearrange(k, "b c h w -> b c (h w)")
+ w_ = torch.einsum("bij,bjk->bik", q, k)
+
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, "b c h w -> b c (h w)")
+ w_ = rearrange(w_, "b i j -> b j i")
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ backend=None,
+ **kwargs,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+ self.backend = backend
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ skip_attention=None,
+ **kwargs,
+ ):
+ h = self.heads
+
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+
+ # Ensure skip_attention is a B×1 boolean tensor
+ if skip_attention is None:
+ skip_attention = torch.zeros_like(x[:, :1], dtype=torch.bool)
+
+ assert isinstance(skip_attention, torch.Tensor)
+ assert skip_attention.shape[1] == 1 and skip_attention.dtype == torch.bool
+
+ # Split the batch into skip and non-skip parts
+ skip_indices = skip_attention.squeeze(1)
+ non_skip_indices = ~skip_indices
+
+ # Process skip attention samples
+ if skip_indices.any():
+ x_skip = x[skip_indices]
+ out_skip = self.to_v(x_skip)
+ out_skip = rearrange(out_skip, "b n (h d) -> b n (h d)", h=h)
+
+ # If all samples are skipped, return early
+ if not non_skip_indices.any():
+ if additional_tokens is not None:
+ out_skip = out_skip[:, n_tokens_to_mask:]
+ return self.to_out(out_skip)
+
+ # Process non-skip samples with attention
+ x_non_skip = x[non_skip_indices]
+ q = self.to_q(x_non_skip)
+ context = default(context, x_non_skip)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x_non_skip.shape[0] % n_times_crossframe_attn_in_self == 0
+ k = repeat(
+ k[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+ v = repeat(
+ v[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
+
+ # Combine skip and non-skip results
+ combined_out = torch.zeros(
+ (x.shape[0], out.shape[1], out.shape[2]), dtype=out.dtype, device=out.device
+ )
+ combined_out[non_skip_indices] = out
+ if skip_indices.any():
+ combined_out[skip_indices] = out_skip
+
+ if additional_tokens is not None:
+ combined_out = combined_out[:, n_tokens_to_mask:]
+ return self.to_out(combined_out)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ use_reference=False,
+ extra_linear=False,
+ **kwargs,
+ ):
+ super().__init__()
+ logpy.debug(
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
+ f"context_dim is {context_dim} and using {heads} heads with a "
+ f"dimension of {dim_head}."
+ )
+ inner_dim = dim_head * heads
+ self.is_context = context_dim is not None
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+ self.use_reference = use_reference and self.is_context
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ if not self.use_reference:
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+ else:
+ if extra_linear:
+ self.to_k = nn.Linear(inner_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(inner_dim, inner_dim, bias=False)
+ self.extra_linear = extra_linear
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+ self.attention_op: Optional[Any] = None
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ skip_attention=None,
+ ):
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+
+ # Ensure skip_attention is a B×1 boolean tensor
+ if skip_attention is None:
+ skip_attention = torch.zeros(x.shape[0], 1, dtype=torch.bool)
+ # print(x.shape)
+ # print(skip_attention)
+ # print(skip_attention.shape)
+ # print(any(skip_attention))
+ assert isinstance(skip_attention, torch.Tensor)
+ assert skip_attention.shape[1] == 1 and skip_attention.dtype == torch.bool
+
+ # Split the batch into skip and non-skip parts
+ skip_indices = skip_attention.squeeze(1)
+ non_skip_indices = ~skip_indices
+
+ # Process skip attention samples
+ if skip_indices.any():
+ x_skip = x[skip_indices]
+ out_skip = self.to_v(x_skip)
+ out_skip = (
+ out_skip.unsqueeze(0)
+ .reshape(-1, self.heads, out_skip.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(-1, out_skip.shape[1], self.heads * self.dim_head)
+ )
+ # If all samples are skipped, return early
+ if not non_skip_indices.any():
+ if additional_tokens is not None:
+ out_skip = out_skip[:, n_tokens_to_mask:]
+ return self.to_out(out_skip)
+
+ x_non_skip = x[non_skip_indices]
+ q = self.to_q(x_non_skip)
+ if not self.use_reference:
+ context = default(context, x_non_skip)
+ k = self.to_k(context)
+ v = self.to_v(context)
+ else:
+ # Reference has already correct shape
+ assert context is not None
+ if self.extra_linear:
+ k = self.to_k(context)
+ v = self.to_v(context)
+ k, v = context, context
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x_non_skip.shape[0] % n_times_crossframe_attn_in_self == 0
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
+ k = repeat(
+ k[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+ v = repeat(
+ v[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+ if q.dtype != k.dtype:
+ k = k.to(q.dtype)
+ v = v.to(q.dtype)
+
+ # actually compute the attention, what we cannot get enough of
+ if version.parse(xformers.__version__) >= version.parse("0.0.21"):
+ # NOTE: workaround for
+ # https://github.com/facebookresearch/xformers/issues/845
+ max_bs = 32768
+ N = q.shape[0]
+ n_batches = math.ceil(N / max_bs)
+ out = list()
+ for i_batch in range(n_batches):
+ batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
+ out.append(
+ xformers.ops.memory_efficient_attention(
+ q[batch],
+ k[batch],
+ v[batch],
+ attn_bias=None,
+ op=self.attention_op,
+ )
+ )
+ out = torch.cat(out, 0)
+ else:
+ out = xformers.ops.memory_efficient_attention(
+ q, k, v, attn_bias=None, op=self.attention_op
+ )
+
+ # TODO: Use this directly in the attention operation, as a bias
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ # Combine skip and non-skip results
+ combined_out = torch.zeros(
+ (x.shape[0], out.shape[1], out.shape[2]), dtype=out.dtype, device=out.device
+ )
+ combined_out[non_skip_indices] = out
+ if skip_indices.any():
+ combined_out[skip_indices] = out_skip
+ else:
+ combined_out = out
+
+ if additional_tokens is not None:
+ # remove additional token
+ combined_out = combined_out[:, n_tokens_to_mask:]
+ return self.to_out(combined_out)
+
+
+class BasicTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ disable_self_attn=False,
+ attn_mode="softmax",
+ sdp_backend=None,
+ reference_to=None,
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
+ logpy.warn(
+ f"Attention mode '{attn_mode}' is not available. Falling "
+ f"back to native attention. This is not a problem in "
+ f"Pytorch >= 2.0. FYI, you are running with PyTorch "
+ f"version {torch.__version__}."
+ )
+ attn_mode = "softmax"
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
+ logpy.warn(
+ "We do not support vanilla attention anymore, as it is too "
+ "expensive. Sorry."
+ )
+ if not XFORMERS_IS_AVAILABLE:
+ assert False, (
+ "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+ )
+ else:
+ logpy.info("Falling back to xformers efficient attention.")
+ attn_mode = "softmax-xformers"
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
+ else:
+ assert sdp_backend is None
+ self.disable_self_attn = disable_self_attn
+ extra_linear = (reference_to is not None) and ("extra" in reference_to)
+ if extra_linear:
+ reference_to = reference_to.replace("_extra", "")
+ assert reference_to in [None, "self", "cross"]
+ self.reference_to = reference_to
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim
+ if (self.disable_self_attn or reference_to == "self")
+ else None,
+ backend=sdp_backend,
+ use_reference=reference_to == "self",
+ extra_linear=extra_linear,
+ ) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ backend=sdp_backend,
+ use_reference=reference_to == "cross",
+ extra_linear=extra_linear,
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+ if self.checkpoint:
+ logpy.debug(f"{self.__class__.__name__} is using checkpointing")
+
+ def forward(
+ self,
+ x,
+ context=None,
+ reference_context=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ skip_attention=None,
+ ):
+ kwargs = {"x": x}
+
+ if context is not None:
+ kwargs.update({"context": context})
+
+ if reference_context is not None:
+ kwargs.update({"reference_context": reference_context})
+
+ if additional_tokens is not None:
+ kwargs.update({"additional_tokens": additional_tokens})
+
+ if n_times_crossframe_attn_in_self:
+ kwargs.update(
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
+ )
+
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
+ if self.checkpoint:
+ # inputs = {"x": x, "context": context}
+ return checkpoint(
+ self._forward, x, context, reference_context, None, 0, skip_attention
+ )
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
+ else:
+ return self._forward(**kwargs)
+
+ def _forward(
+ self,
+ x,
+ context=None,
+ reference_context=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ skip_attention=None,
+ ):
+ self_context = reference_context if self.reference_to == "self" else context
+ # print(self.reference_to)
+ # print("context: ", context.shape if context is not None else None)
+ # print("reference_context: ", reference_context.shape if reference_context is not None else None)
+ # print("x: ", x.shape)
+
+ x = (
+ self.attn1(
+ self.norm1(x),
+ context=self_context
+ if (self.disable_self_attn or self.reference_to == "self")
+ else None,
+ additional_tokens=additional_tokens,
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
+ if not self.disable_self_attn
+ else 0,
+ skip_attention=skip_attention,
+ )
+ + x
+ )
+ cross_context = reference_context if self.reference_to == "cross" else context
+ x = (
+ self.attn2(
+ self.norm2(x),
+ context=cross_context,
+ additional_tokens=additional_tokens,
+ )
+ + x
+ )
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class BasicTransformerSingleLayerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ attn_mode="softmax",
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ )
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ # inputs = {"x": x, "context": context}
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
+ return checkpoint(self._forward, x, context)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context) + x
+ x = self.ff(self.norm2(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ NEW: use_linear for more efficiency instead of the 1x1 convs
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None,
+ disable_self_attn=False,
+ use_linear=False,
+ attn_type="softmax",
+ use_checkpoint=True,
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
+ sdp_backend=None,
+ reference_to=None,
+ ):
+ super().__init__()
+ logpy.debug(
+ f"constructing {self.__class__.__name__} of depth {depth} w/ "
+ f"{in_channels} channels and {n_heads} heads."
+ )
+
+ if exists(context_dim) and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ if exists(context_dim) and isinstance(context_dim, list):
+ if depth != len(context_dim):
+ logpy.warn(
+ f"{self.__class__.__name__}: Found context dims "
+ f"{context_dim} of depth {len(context_dim)}, which does not "
+ f"match the specified 'depth' of {depth}. Setting context_dim "
+ f"to {depth * [context_dim[0]]} now."
+ )
+ # depth does not match context dims.
+ assert all(map(lambda x: x == context_dim[0], context_dim)), (
+ "need homogenous context_dim to match depth automatically"
+ )
+ context_dim = depth * [context_dim[0]]
+ elif context_dim is None:
+ context_dim = [None] * depth
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn,
+ attn_mode=attn_type,
+ checkpoint=use_checkpoint,
+ sdp_backend=sdp_backend,
+ reference_to=reference_to,
+ )
+ for d in range(depth)
+ ]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ )
+ else:
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None, skip_attention=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ if i > 0 and len(context) == 1:
+ i = 0 # use same context for each block
+ x = block(x, context=context[i], skip_attention=skip_attention)
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
+
+
+class SimpleTransformer(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ depth: int,
+ heads: int,
+ dim_head: int,
+ context_dim: Optional[int] = None,
+ dropout: float = 0.0,
+ checkpoint: bool = True,
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ BasicTransformerBlock(
+ dim,
+ heads,
+ dim_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ attn_mode="softmax-xformers",
+ checkpoint=checkpoint,
+ )
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ for layer in self.layers:
+ x = layer(x, context)
+ return x
diff --git a/sgm/modules/autoencoding/__init__.py b/sgm/modules/autoencoding/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/autoencoding/__pycache__/__init__.cpython-311.pyc b/sgm/modules/autoencoding/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..23335fbac3b06aea22689a4679776bc759a0dae8
Binary files /dev/null and b/sgm/modules/autoencoding/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/autoencoding/__pycache__/temporal_ae.cpython-311.pyc b/sgm/modules/autoencoding/__pycache__/temporal_ae.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93661bc133d2b1146b847dd2b4b6d3b8634939c9
Binary files /dev/null and b/sgm/modules/autoencoding/__pycache__/temporal_ae.cpython-311.pyc differ
diff --git a/sgm/modules/autoencoding/losses/__init__.py b/sgm/modules/autoencoding/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b316c7aa6ea1c5e31a58987aa3b37b2933eb7e2
--- /dev/null
+++ b/sgm/modules/autoencoding/losses/__init__.py
@@ -0,0 +1,7 @@
+__all__ = [
+ "GeneralLPIPSWithDiscriminator",
+ "LatentLPIPS",
+]
+
+from .discriminator_loss import GeneralLPIPSWithDiscriminator
+from .lpips import LatentLPIPS
diff --git a/sgm/modules/autoencoding/losses/discriminator_loss.py b/sgm/modules/autoencoding/losses/discriminator_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..09b6829267bf8e4d98c3f29abdc19e58dcbcbe64
--- /dev/null
+++ b/sgm/modules/autoencoding/losses/discriminator_loss.py
@@ -0,0 +1,306 @@
+from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision
+from einops import rearrange
+from matplotlib import colormaps
+from matplotlib import pyplot as plt
+
+from ....util import default, instantiate_from_config
+from ..lpips.loss.lpips import LPIPS
+from ..lpips.model.model import weights_init
+from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
+
+
+class GeneralLPIPSWithDiscriminator(nn.Module):
+ def __init__(
+ self,
+ disc_start: int,
+ logvar_init: float = 0.0,
+ disc_num_layers: int = 3,
+ disc_in_channels: int = 3,
+ disc_factor: float = 1.0,
+ disc_weight: float = 1.0,
+ perceptual_weight: float = 1.0,
+ disc_loss: str = "hinge",
+ scale_input_to_tgt_size: bool = False,
+ dims: int = 2,
+ learn_logvar: bool = False,
+ regularization_weights: Union[None, Dict[str, float]] = None,
+ additional_log_keys: Optional[List[str]] = None,
+ discriminator_config: Optional[Dict] = None,
+ ):
+ super().__init__()
+ self.dims = dims
+ if self.dims > 2:
+ print(
+ f"running with dims={dims}. This means that for perceptual loss "
+ f"calculation, the LPIPS loss will be applied to each frame "
+ f"independently."
+ )
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
+ assert disc_loss in ["hinge", "vanilla"]
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ # output log variance
+ self.logvar = nn.Parameter(
+ torch.full((), logvar_init), requires_grad=learn_logvar
+ )
+ self.learn_logvar = learn_logvar
+
+ discriminator_config = default(
+ discriminator_config,
+ {
+ "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
+ "params": {
+ "input_nc": disc_in_channels,
+ "n_layers": disc_num_layers,
+ "use_actnorm": False,
+ },
+ },
+ )
+
+ self.discriminator = instantiate_from_config(discriminator_config).apply(
+ weights_init
+ )
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.regularization_weights = default(regularization_weights, {})
+
+ self.forward_keys = [
+ "optimizer_idx",
+ "global_step",
+ "last_layer",
+ "split",
+ "regularization_log",
+ ]
+
+ self.additional_log_keys = set(default(additional_log_keys, []))
+ self.additional_log_keys.update(set(self.regularization_weights.keys()))
+
+ def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
+ return self.discriminator.parameters()
+
+ def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
+ if self.learn_logvar:
+ yield self.logvar
+ yield from ()
+
+ @torch.no_grad()
+ def log_images(
+ self, inputs: torch.Tensor, reconstructions: torch.Tensor
+ ) -> Dict[str, torch.Tensor]:
+ # calc logits of real/fake
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ if len(logits_real.shape) < 4:
+ # Non patch-discriminator
+ return dict()
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ # -> (b, 1, h, w)
+
+ # parameters for colormapping
+ high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
+ cmap = colormaps["PiYG"] # diverging colormap
+
+ def to_colormap(logits: torch.Tensor) -> torch.Tensor:
+ """(b, 1, ...) -> (b, 3, ...)"""
+ logits = (logits + high) / (2 * high)
+ logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
+ # -> (b, 1, ..., 3)
+ logits = torch.from_numpy(logits_np).to(logits.device)
+ return rearrange(logits, "b 1 ... c -> b c ...")
+
+ logits_real = torch.nn.functional.interpolate(
+ logits_real,
+ size=inputs.shape[-2:],
+ mode="nearest",
+ antialias=False,
+ )
+ logits_fake = torch.nn.functional.interpolate(
+ logits_fake,
+ size=reconstructions.shape[-2:],
+ mode="nearest",
+ antialias=False,
+ )
+
+ # alpha value of logits for overlay
+ alpha_real = torch.abs(logits_real) / high
+ alpha_fake = torch.abs(logits_fake) / high
+ # -> (b, 1, h, w) in range [0, 0.5]
+ # alpha value of lines don't really matter, since the values are the same
+ # for both images and logits anyway
+ grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
+ grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
+ grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
+ # -> (1, h, w)
+ # blend logits and images together
+
+ # prepare logits for plotting
+ logits_real = to_colormap(logits_real)
+ logits_fake = to_colormap(logits_fake)
+ # resize logits
+ # -> (b, 3, h, w)
+
+ # make some grids
+ # add all logits to one plot
+ logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
+ logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
+ # I just love how torchvision calls the number of columns `nrow`
+ grid_logits = torch.cat((logits_real, logits_fake), dim=1)
+ # -> (3, h, w)
+
+ grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
+ grid_images_fake = torchvision.utils.make_grid(
+ 0.5 * reconstructions + 0.5, nrow=4
+ )
+ grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
+ # -> (3, h, w) in range [0, 1]
+
+ grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
+
+ # Create labeled colorbar
+ dpi = 100
+ height = 128 / dpi
+ width = grid_logits.shape[2] / dpi
+ fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
+ img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
+ plt.colorbar(
+ img,
+ cax=ax,
+ orientation="horizontal",
+ fraction=0.9,
+ aspect=width / height,
+ pad=0.0,
+ )
+ img.set_visible(False)
+ fig.tight_layout()
+ fig.canvas.draw()
+ # manually convert figure to numpy
+ cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
+ cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
+
+ # Add colorbar to plot
+ annotated_grid = torch.cat((grid_logits, cbar), dim=1)
+ blended_grid = torch.cat((grid_blend, cbar), dim=1)
+ return {
+ "vis_logits": 2 * annotated_grid[None, ...] - 1,
+ "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
+ }
+
+ def calculate_adaptive_weight(
+ self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
+ ) -> torch.Tensor:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(
+ self,
+ inputs: torch.Tensor,
+ reconstructions: torch.Tensor,
+ *, # added because I changed the order here
+ regularization_log: Dict[str, torch.Tensor],
+ optimizer_idx: int,
+ global_step: int,
+ last_layer: torch.Tensor,
+ split: str = "train",
+ weights: Union[None, float, torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, dict]:
+ if self.scale_input_to_tgt_size:
+ inputs = torch.nn.functional.interpolate(
+ inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
+ )
+
+ if self.dims > 2:
+ inputs, reconstructions = map(
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
+ (inputs, reconstructions),
+ )
+
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(
+ inputs.contiguous(), reconstructions.contiguous()
+ )
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+ nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if global_step >= self.discriminator_iter_start or not self.training:
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ g_loss = -torch.mean(logits_fake)
+ if self.training:
+ d_weight = self.calculate_adaptive_weight(
+ nll_loss, g_loss, last_layer=last_layer
+ )
+ else:
+ d_weight = torch.tensor(1.0)
+ else:
+ d_weight = torch.tensor(0.0)
+ g_loss = torch.tensor(0.0, requires_grad=True)
+
+ loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
+ log = dict()
+ for k in regularization_log:
+ if k in self.regularization_weights:
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
+ if k in self.additional_log_keys:
+ log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
+
+ log.update(
+ {
+ f"{split}/loss/total": loss.clone().detach().mean(),
+ f"{split}/loss/nll": nll_loss.detach().mean(),
+ f"{split}/loss/rec": rec_loss.detach().mean(),
+ f"{split}/loss/g": g_loss.detach().mean(),
+ f"{split}/scalars/logvar": self.logvar.detach(),
+ f"{split}/scalars/d_weight": d_weight.detach(),
+ }
+ )
+
+ return loss, log
+ elif optimizer_idx == 1:
+ # second pass for discriminator update
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+
+ if global_step >= self.discriminator_iter_start or not self.training:
+ d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
+ else:
+ d_loss = torch.tensor(0.0, requires_grad=True)
+
+ log = {
+ f"{split}/loss/disc": d_loss.clone().detach().mean(),
+ f"{split}/logits/real": logits_real.detach().mean(),
+ f"{split}/logits/fake": logits_fake.detach().mean(),
+ }
+ return d_loss, log
+ else:
+ raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
+
+ def get_nll_loss(
+ self,
+ rec_loss: torch.Tensor,
+ weights: Optional[Union[float, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights * nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+
+ return nll_loss, weighted_nll_loss
diff --git a/sgm/modules/autoencoding/losses/lpips.py b/sgm/modules/autoencoding/losses/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..b329fcc2ee9477f0122aa7d066866cdfe71ce521
--- /dev/null
+++ b/sgm/modules/autoencoding/losses/lpips.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+
+from ....util import default, instantiate_from_config
+from ..lpips.loss.lpips import LPIPS
+
+
+class LatentLPIPS(nn.Module):
+ def __init__(
+ self,
+ decoder_config,
+ perceptual_weight=1.0,
+ latent_weight=1.0,
+ scale_input_to_tgt_size=False,
+ scale_tgt_to_input_size=False,
+ perceptual_weight_on_inputs=0.0,
+ ):
+ super().__init__()
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
+ self.scale_tgt_to_input_size = scale_tgt_to_input_size
+ self.init_decoder(decoder_config)
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ self.latent_weight = latent_weight
+ self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
+
+ def init_decoder(self, config):
+ self.decoder = instantiate_from_config(config)
+ if hasattr(self.decoder, "encoder"):
+ del self.decoder.encoder
+
+ def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
+ log = dict()
+ loss = (latent_inputs - latent_predictions) ** 2
+ log[f"{split}/latent_l2_loss"] = loss.mean().detach()
+ image_reconstructions = None
+ if self.perceptual_weight > 0.0:
+ image_reconstructions = self.decoder.decode(latent_predictions)
+ image_targets = self.decoder.decode(latent_inputs)
+ perceptual_loss = self.perceptual_loss(
+ image_targets.contiguous(), image_reconstructions.contiguous()
+ )
+ loss = (
+ self.latent_weight * loss.mean()
+ + self.perceptual_weight * perceptual_loss.mean()
+ )
+ log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
+
+ if self.perceptual_weight_on_inputs > 0.0:
+ image_reconstructions = default(
+ image_reconstructions, self.decoder.decode(latent_predictions)
+ )
+ if self.scale_input_to_tgt_size:
+ image_inputs = torch.nn.functional.interpolate(
+ image_inputs,
+ image_reconstructions.shape[2:],
+ mode="bicubic",
+ antialias=True,
+ )
+ elif self.scale_tgt_to_input_size:
+ image_reconstructions = torch.nn.functional.interpolate(
+ image_reconstructions,
+ image_inputs.shape[2:],
+ mode="bicubic",
+ antialias=True,
+ )
+
+ perceptual_loss2 = self.perceptual_loss(
+ image_inputs.contiguous(), image_reconstructions.contiguous()
+ )
+ loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
+ log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
+ return loss, log
diff --git a/sgm/modules/autoencoding/lpips/__init__.py b/sgm/modules/autoencoding/lpips/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/autoencoding/lpips/__pycache__/__init__.cpython-311.pyc b/sgm/modules/autoencoding/lpips/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1febdb34a8b9090bd2e8228a2ff438d12221c380
Binary files /dev/null and b/sgm/modules/autoencoding/lpips/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/autoencoding/lpips/__pycache__/util.cpython-311.pyc b/sgm/modules/autoencoding/lpips/__pycache__/util.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f3682e1abeb963984eb985c53afeafd0e0a395f
Binary files /dev/null and b/sgm/modules/autoencoding/lpips/__pycache__/util.cpython-311.pyc differ
diff --git a/sgm/modules/autoencoding/lpips/loss/.gitignore b/sgm/modules/autoencoding/lpips/loss/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..a92958a1cd4ffe005e1f5448ab3e6fd9c795a43a
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/loss/.gitignore
@@ -0,0 +1 @@
+vgg.pth
\ No newline at end of file
diff --git a/sgm/modules/autoencoding/lpips/loss/LICENSE b/sgm/modules/autoencoding/lpips/loss/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..924cfc85b8d63ef538f5676f830a2a8497932108
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/loss/LICENSE
@@ -0,0 +1,23 @@
+Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/sgm/modules/autoencoding/lpips/loss/__init__.py b/sgm/modules/autoencoding/lpips/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/autoencoding/lpips/loss/__pycache__/__init__.cpython-311.pyc b/sgm/modules/autoencoding/lpips/loss/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7bd583d36b6a9bc738d738ed4ed567a52042677
Binary files /dev/null and b/sgm/modules/autoencoding/lpips/loss/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/autoencoding/lpips/loss/__pycache__/lpips.cpython-311.pyc b/sgm/modules/autoencoding/lpips/loss/__pycache__/lpips.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a69e3234c2d4015fdc7c0fae65ef70b56a335a8a
Binary files /dev/null and b/sgm/modules/autoencoding/lpips/loss/__pycache__/lpips.cpython-311.pyc differ
diff --git a/sgm/modules/autoencoding/lpips/loss/lpips.py b/sgm/modules/autoencoding/lpips/loss/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e34f3d083674f675a5ca024e9bd27fb77e2b6b5
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/loss/lpips.py
@@ -0,0 +1,147 @@
+"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
+
+from collections import namedtuple
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from ..util import get_ckpt_path
+
+
+class LPIPS(nn.Module):
+ # Learned perceptual metric
+ def __init__(self, use_dropout=True):
+ super().__init__()
+ self.scaling_layer = ScalingLayer()
+ self.chns = [64, 128, 256, 512, 512] # vg16 features
+ self.net = vgg16(pretrained=True, requires_grad=False)
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+ self.load_from_pretrained()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def load_from_pretrained(self, name="vgg_lpips"):
+ ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
+ self.load_state_dict(
+ torch.load(ckpt, map_location=torch.device("cpu")), strict=False
+ )
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
+
+ @classmethod
+ def from_pretrained(cls, name="vgg_lpips"):
+ if name != "vgg_lpips":
+ raise NotImplementedError
+ model = cls()
+ ckpt = get_ckpt_path(name)
+ model.load_state_dict(
+ torch.load(ckpt, map_location=torch.device("cpu")), strict=False
+ )
+ return model
+
+ def forward(self, input, target):
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
+ feats0, feats1, diffs = {}, {}, {}
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+ for kk in range(len(self.chns)):
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
+ outs1[kk]
+ )
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+ res = [
+ spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
+ for kk in range(len(self.chns))
+ ]
+ val = res[0]
+ for l in range(1, len(self.chns)):
+ val += res[l]
+ return val
+
+
+class ScalingLayer(nn.Module):
+ def __init__(self):
+ super(ScalingLayer, self).__init__()
+ self.register_buffer(
+ "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
+ )
+ self.register_buffer(
+ "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
+ )
+
+ def forward(self, inp):
+ return (inp - self.shift) / self.scale
+
+
+class NetLinLayer(nn.Module):
+ """A single linear layer which does a 1x1 conv"""
+
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
+ super(NetLinLayer, self).__init__()
+ layers = (
+ [
+ nn.Dropout(),
+ ]
+ if (use_dropout)
+ else []
+ )
+ layers += [
+ nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
+ ]
+ self.model = nn.Sequential(*layers)
+
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(vgg16, self).__init__()
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple(
+ "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
+ )
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+ return out
+
+
+def normalize_tensor(x, eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
+ return x / (norm_factor + eps)
+
+
+def spatial_average(x, keepdim=True):
+ return x.mean([2, 3], keepdim=keepdim)
diff --git a/sgm/modules/autoencoding/lpips/model/LICENSE b/sgm/modules/autoencoding/lpips/model/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..4b356e66b5aa689b339f1a80a9f1b5ba378003bb
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/model/LICENSE
@@ -0,0 +1,58 @@
+Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+--------------------------- LICENSE FOR pix2pix --------------------------------
+BSD License
+
+For pix2pix software
+Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+----------------------------- LICENSE FOR DCGAN --------------------------------
+BSD License
+
+For dcgan.torch software
+
+Copyright (c) 2015, Facebook, Inc. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/sgm/modules/autoencoding/lpips/model/__init__.py b/sgm/modules/autoencoding/lpips/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/autoencoding/lpips/model/model.py b/sgm/modules/autoencoding/lpips/model/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..66357d4e627f9a69a5abbbad15546c96fcd758fe
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/model/model.py
@@ -0,0 +1,88 @@
+import functools
+
+import torch.nn as nn
+
+from ..util import ActNorm
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ elif classname.find("BatchNorm") != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator as in Pix2Pix
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+ """
+
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+ """Construct a PatchGAN discriminator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm_layer -- normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if not use_actnorm:
+ norm_layer = nn.BatchNorm2d
+ else:
+ norm_layer = ActNorm
+ if (
+ type(norm_layer) == functools.partial
+ ): # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func != nn.BatchNorm2d
+ else:
+ use_bias = norm_layer != nn.BatchNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
+ nn.LeakyReLU(0.2, True),
+ ]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n, 8)
+ sequence += [
+ nn.Conv2d(
+ ndf * nf_mult_prev,
+ ndf * nf_mult,
+ kernel_size=kw,
+ stride=2,
+ padding=padw,
+ bias=use_bias,
+ ),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n_layers, 8)
+ sequence += [
+ nn.Conv2d(
+ ndf * nf_mult_prev,
+ ndf * nf_mult,
+ kernel_size=kw,
+ stride=1,
+ padding=padw,
+ bias=use_bias,
+ ),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ sequence += [
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
+ ] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
diff --git a/sgm/modules/autoencoding/lpips/util.py b/sgm/modules/autoencoding/lpips/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..49c76e370bf16888ab61f42844b3c9f14ad9014c
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/util.py
@@ -0,0 +1,128 @@
+import hashlib
+import os
+
+import requests
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
+
+CKPT_MAP = {"vgg_lpips": "vgg.pth"}
+
+MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
+
+
+def download(url, local_path, chunk_size=1024):
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+ with requests.get(url, stream=True) as r:
+ total_size = int(r.headers.get("content-length", 0))
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+ with open(local_path, "wb") as f:
+ for data in r.iter_content(chunk_size=chunk_size):
+ if data:
+ f.write(data)
+ pbar.update(chunk_size)
+
+
+def md5_hash(path):
+ with open(path, "rb") as f:
+ content = f.read()
+ return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+ assert name in URL_MAP
+ path = os.path.join(root, CKPT_MAP[name])
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+ download(URL_MAP[name], path)
+ md5 = md5_hash(path)
+ assert md5 == MD5_MAP[name], md5
+ return path
+
+
+class ActNorm(nn.Module):
+ def __init__(
+ self, num_features, logdet=False, affine=True, allow_reverse_init=False
+ ):
+ assert affine
+ super().__init__()
+ self.logdet = logdet
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+ self.allow_reverse_init = allow_reverse_init
+
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
+
+ def initialize(self, input):
+ with torch.no_grad():
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+ mean = (
+ flatten.mean(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+ std = (
+ flatten.std(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+
+ self.loc.data.copy_(-mean)
+ self.scale.data.copy_(1 / (std + 1e-6))
+
+ def forward(self, input, reverse=False):
+ if reverse:
+ return self.reverse(input)
+ if len(input.shape) == 2:
+ input = input[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ _, _, height, width = input.shape
+
+ if self.training and self.initialized.item() == 0:
+ self.initialize(input)
+ self.initialized.fill_(1)
+
+ h = self.scale * (input + self.loc)
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+
+ if self.logdet:
+ log_abs = torch.log(torch.abs(self.scale))
+ logdet = height * width * torch.sum(log_abs)
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
+ return h, logdet
+
+ return h
+
+ def reverse(self, output):
+ if self.training and self.initialized.item() == 0:
+ if not self.allow_reverse_init:
+ raise RuntimeError(
+ "Initializing ActNorm in reverse direction is "
+ "disabled by default. Use allow_reverse_init=True to enable."
+ )
+ else:
+ self.initialize(output)
+ self.initialized.fill_(1)
+
+ if len(output.shape) == 2:
+ output = output[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ h = output / self.scale - self.loc
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+ return h
diff --git a/sgm/modules/autoencoding/lpips/vqperceptual.py b/sgm/modules/autoencoding/lpips/vqperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..6195f0a6ed7ee6fd32c1bccea071e6075e95ee43
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/vqperceptual.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn.functional as F
+
+
+def hinge_d_loss(logits_real, logits_fake):
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def vanilla_d_loss(logits_real, logits_fake):
+ d_loss = 0.5 * (
+ torch.mean(torch.nn.functional.softplus(-logits_real))
+ + torch.mean(torch.nn.functional.softplus(logits_fake))
+ )
+ return d_loss
diff --git a/sgm/modules/autoencoding/regularizers/__init__.py b/sgm/modules/autoencoding/regularizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff2b1815a5ba88892375e8ec9bedacea49024113
--- /dev/null
+++ b/sgm/modules/autoencoding/regularizers/__init__.py
@@ -0,0 +1,31 @@
+from abc import abstractmethod
+from typing import Any, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ....modules.distributions.distributions import \
+ DiagonalGaussianDistribution
+from .base import AbstractRegularizer
+
+
+class DiagonalGaussianRegularizer(AbstractRegularizer):
+ def __init__(self, sample: bool = True):
+ super().__init__()
+ self.sample = sample
+
+ def get_trainable_parameters(self) -> Any:
+ yield from ()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ log = dict()
+ posterior = DiagonalGaussianDistribution(z)
+ if self.sample:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ kl_loss = posterior.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+ log["kl_loss"] = kl_loss
+ return z, log
diff --git a/sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-311.pyc b/sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..afa540d624f96103b045ead1a140348a9723a568
Binary files /dev/null and b/sgm/modules/autoencoding/regularizers/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/autoencoding/regularizers/__pycache__/base.cpython-311.pyc b/sgm/modules/autoencoding/regularizers/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3120d4a32cb4526d8f7a45db755d788e50e61dd
Binary files /dev/null and b/sgm/modules/autoencoding/regularizers/__pycache__/base.cpython-311.pyc differ
diff --git a/sgm/modules/autoencoding/regularizers/base.py b/sgm/modules/autoencoding/regularizers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..fca681bb3c1f4818b57e956e31b98f76077ccb67
--- /dev/null
+++ b/sgm/modules/autoencoding/regularizers/base.py
@@ -0,0 +1,40 @@
+from abc import abstractmethod
+from typing import Any, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class AbstractRegularizer(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_trainable_parameters(self) -> Any:
+ raise NotImplementedError()
+
+
+class IdentityRegularizer(AbstractRegularizer):
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ return z, dict()
+
+ def get_trainable_parameters(self) -> Any:
+ yield from ()
+
+
+def measure_perplexity(
+ predicted_indices: torch.Tensor, num_centroids: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
+ encodings = (
+ F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
+ )
+ avg_probs = encodings.mean(0)
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+ cluster_use = torch.sum(avg_probs > 0)
+ return perplexity, cluster_use
diff --git a/sgm/modules/autoencoding/regularizers/quantize.py b/sgm/modules/autoencoding/regularizers/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..86a4dbdd10101b24f03bba134c4f8d2ab007f0db
--- /dev/null
+++ b/sgm/modules/autoencoding/regularizers/quantize.py
@@ -0,0 +1,487 @@
+import logging
+from abc import abstractmethod
+from typing import Dict, Iterator, Literal, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch import einsum
+
+from .base import AbstractRegularizer, measure_perplexity
+
+logpy = logging.getLogger(__name__)
+
+
+class AbstractQuantizer(AbstractRegularizer):
+ def __init__(self):
+ super().__init__()
+ # Define these in your init
+ # shape (N,)
+ self.used: Optional[torch.Tensor]
+ self.re_embed: int
+ self.unknown_index: Union[Literal["random"], int]
+
+ def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
+ assert self.used is not None, "You need to define used indices for remap"
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
+ device=new.device
+ )
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
+ assert self.used is not None, "You need to define used indices for remap"
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ @abstractmethod
+ def get_codebook_entry(
+ self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
+ ) -> torch.Tensor:
+ raise NotImplementedError()
+
+ def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
+ yield from self.parameters()
+
+
+class GumbelQuantizer(AbstractQuantizer):
+ """
+ credit to @karpathy:
+ https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
+ Gumbel Softmax trick quantizer
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
+ https://arxiv.org/abs/1611.01144
+ """
+
+ def __init__(
+ self,
+ num_hiddens: int,
+ embedding_dim: int,
+ n_embed: int,
+ straight_through: bool = True,
+ kl_weight: float = 5e-4,
+ temp_init: float = 1.0,
+ remap: Optional[str] = None,
+ unknown_index: str = "random",
+ loss_key: str = "loss/vq",
+ ) -> None:
+ super().__init__()
+
+ self.loss_key = loss_key
+ self.embedding_dim = embedding_dim
+ self.n_embed = n_embed
+
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
+ self.embed = nn.Embedding(n_embed, embedding_dim)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ else:
+ self.used = None
+ self.re_embed = n_embed
+ if unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ else:
+ assert unknown_index == "random" or isinstance(
+ unknown_index, int
+ ), "unknown index needs to be 'random', 'extra' or any integer"
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.remap is not None:
+ logpy.info(
+ f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+
+ def forward(
+ self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False
+ ) -> Tuple[torch.Tensor, Dict]:
+ # force hard = True when we are in eval mode, as we must quantize.
+ # actually, always true seems to work
+ hard = self.straight_through if self.training else True
+ temp = self.temperature if temp is None else temp
+ out_dict = {}
+ logits = self.proj(z)
+ if self.remap is not None:
+ # continue only with used logits
+ full_zeros = torch.zeros_like(logits)
+ logits = logits[:, self.used, ...]
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+ if self.remap is not None:
+ # go back to all entries but unused set to zero
+ full_zeros[:, self.used, ...] = soft_one_hot
+ soft_one_hot = full_zeros
+ z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = (
+ self.kl_weight
+ * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
+ )
+ out_dict[self.loss_key] = diff
+
+ ind = soft_one_hot.argmax(dim=1)
+ out_dict["indices"] = ind
+ if self.remap is not None:
+ ind = self.remap_to_used(ind)
+
+ if return_logits:
+ out_dict["logits"] = logits
+
+ return z_q, out_dict
+
+ def get_codebook_entry(self, indices, shape):
+ # TODO: shape not yet optional
+ b, h, w, c = shape
+ assert b * h * w == indices.shape[0]
+ indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
+ if self.remap is not None:
+ indices = self.unmap_to_all(indices)
+ one_hot = (
+ F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
+ )
+ z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
+ return z_q
+
+
+class VectorQuantizer(AbstractQuantizer):
+ """
+ ____________________________________________
+ Discretization bottleneck part of the VQ-VAE.
+ Inputs:
+ - n_e : number of embeddings
+ - e_dim : dimension of embedding
+ - beta : commitment cost used in loss term,
+ beta * ||z_e(x)-sg[e]||^2
+ _____________________________________________
+ """
+
+ def __init__(
+ self,
+ n_e: int,
+ e_dim: int,
+ beta: float = 0.25,
+ remap: Optional[str] = None,
+ unknown_index: str = "random",
+ sane_index_shape: bool = False,
+ log_perplexity: bool = False,
+ embedding_weight_norm: bool = False,
+ loss_key: str = "loss/vq",
+ ):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.loss_key = loss_key
+
+ if not embedding_weight_norm:
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+ else:
+ self.embedding = torch.nn.utils.weight_norm(
+ nn.Embedding(self.n_e, self.e_dim), dim=1
+ )
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ else:
+ self.used = None
+ self.re_embed = n_e
+ if unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ else:
+ assert unknown_index == "random" or isinstance(
+ unknown_index, int
+ ), "unknown index needs to be 'random', 'extra' or any integer"
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.remap is not None:
+ logpy.info(
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+
+ self.sane_index_shape = sane_index_shape
+ self.log_perplexity = log_perplexity
+
+ def forward(
+ self,
+ z: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict]:
+ do_reshape = z.ndim == 4
+ if do_reshape:
+ # # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, "b c h w -> b h w c").contiguous()
+
+ else:
+ assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
+ z = z.contiguous()
+
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = (
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
+ + torch.sum(self.embedding.weight**2, dim=1)
+ - 2
+ * torch.einsum(
+ "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
+ )
+ )
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ loss_dict = {}
+ if self.log_perplexity:
+ perplexity, cluster_usage = measure_perplexity(
+ min_encoding_indices.detach(), self.n_e
+ )
+ loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})
+
+ # compute loss for embedding
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
+ (z_q - z.detach()) ** 2
+ )
+ loss_dict[self.loss_key] = loss
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ if do_reshape:
+ z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z.shape[0], -1
+ ) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
+
+ if self.sane_index_shape:
+ if do_reshape:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z_q.shape[0], z_q.shape[2], z_q.shape[3]
+ )
+ else:
+ min_encoding_indices = rearrange(
+ min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
+ )
+
+ loss_dict["min_encoding_indices"] = min_encoding_indices
+
+ return z_q, loss_dict
+
+ def get_codebook_entry(
+ self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
+ ) -> torch.Tensor:
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ assert shape is not None, "Need to give shape for remap"
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class EmbeddingEMA(nn.Module):
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+ super().__init__()
+ self.decay = decay
+ self.eps = eps
+ weight = torch.randn(num_tokens, codebook_dim)
+ self.weight = nn.Parameter(weight, requires_grad=False)
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
+ self.update = True
+
+ def forward(self, embed_id):
+ return F.embedding(embed_id, self.weight)
+
+ def cluster_size_ema_update(self, new_cluster_size):
+ self.cluster_size.data.mul_(self.decay).add_(
+ new_cluster_size, alpha=1 - self.decay
+ )
+
+ def embed_avg_ema_update(self, new_embed_avg):
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+ def weight_update(self, num_tokens):
+ n = self.cluster_size.sum()
+ smoothed_cluster_size = (
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+ )
+ # normalize embedding average with smoothed cluster size
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+ self.weight.data.copy_(embed_normalized)
+
+
+class EMAVectorQuantizer(AbstractQuantizer):
+ def __init__(
+ self,
+ n_embed: int,
+ embedding_dim: int,
+ beta: float,
+ decay: float = 0.99,
+ eps: float = 1e-5,
+ remap: Optional[str] = None,
+ unknown_index: str = "random",
+ loss_key: str = "loss/vq",
+ ):
+ super().__init__()
+ self.codebook_dim = embedding_dim
+ self.num_tokens = n_embed
+ self.beta = beta
+ self.loss_key = loss_key
+
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ else:
+ self.used = None
+ self.re_embed = n_embed
+ if unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ else:
+ assert unknown_index == "random" or isinstance(
+ unknown_index, int
+ ), "unknown index needs to be 'random', 'extra' or any integer"
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.remap is not None:
+ logpy.info(
+ f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
+ # reshape z -> (batch, height, width, channel) and flatten
+ # z, 'b c h w -> b h w c'
+ z = rearrange(z, "b c h w -> b h w c")
+ z_flattened = z.reshape(-1, self.codebook_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = (
+ z_flattened.pow(2).sum(dim=1, keepdim=True)
+ + self.embedding.weight.pow(2).sum(dim=1)
+ - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
+ ) # 'n d -> d n'
+
+ encoding_indices = torch.argmin(d, dim=1)
+
+ z_q = self.embedding(encoding_indices).view(z.shape)
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+ if self.training and self.embedding.update:
+ # EMA cluster size
+ encodings_sum = encodings.sum(0)
+ self.embedding.cluster_size_ema_update(encodings_sum)
+ # EMA embedding average
+ embed_sum = encodings.transpose(0, 1) @ z_flattened
+ self.embedding.embed_avg_ema_update(embed_sum)
+ # normalize embed_avg and update weight
+ self.embedding.weight_update(self.num_tokens)
+
+ # compute loss for embedding
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ # z_q, 'b h w c -> b c h w'
+ z_q = rearrange(z_q, "b h w c -> b c h w")
+
+ out_dict = {
+ self.loss_key: loss,
+ "encodings": encodings,
+ "encoding_indices": encoding_indices,
+ "perplexity": perplexity,
+ }
+
+ return z_q, out_dict
+
+
+class VectorQuantizerWithInputProjection(VectorQuantizer):
+ def __init__(
+ self,
+ input_dim: int,
+ n_codes: int,
+ codebook_dim: int,
+ beta: float = 1.0,
+ output_dim: Optional[int] = None,
+ **kwargs,
+ ):
+ super().__init__(n_codes, codebook_dim, beta, **kwargs)
+ self.proj_in = nn.Linear(input_dim, codebook_dim)
+ self.output_dim = output_dim
+ if output_dim is not None:
+ self.proj_out = nn.Linear(codebook_dim, output_dim)
+ else:
+ self.proj_out = nn.Identity()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
+ rearr = False
+ in_shape = z.shape
+
+ if z.ndim > 3:
+ rearr = self.output_dim is not None
+ z = rearrange(z, "b c ... -> b (...) c")
+ z = self.proj_in(z)
+ z_q, loss_dict = super().forward(z)
+
+ z_q = self.proj_out(z_q)
+ if rearr:
+ if len(in_shape) == 4:
+ z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
+ elif len(in_shape) == 5:
+ z_q = rearrange(
+ z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2]
+ )
+ else:
+ raise NotImplementedError(
+ f"rearranging not available for {len(in_shape)}-dimensional input."
+ )
+
+ return z_q, loss_dict
diff --git a/sgm/modules/autoencoding/temporal_ae.py b/sgm/modules/autoencoding/temporal_ae.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a17a91163469dbd8cfe2373d0d09554d5e25ed9
--- /dev/null
+++ b/sgm/modules/autoencoding/temporal_ae.py
@@ -0,0 +1,347 @@
+from typing import Callable, Iterable, Union
+
+import torch
+from einops import rearrange, repeat
+
+from sgm.modules.diffusionmodules.model import (XFORMERS_IS_AVAILABLE,
+ AttnBlock, Decoder,
+ MemoryEfficientAttnBlock,
+ ResnetBlock)
+from sgm.modules.diffusionmodules.openaimodel import (ResBlock,
+ timestep_embedding)
+from sgm.modules.video_attention import VideoTransformerBlock
+from sgm.util import partialclass
+
+
+class VideoResBlock(ResnetBlock):
+ def __init__(
+ self,
+ out_channels,
+ *args,
+ dropout=0.0,
+ video_kernel_size=3,
+ alpha=0.0,
+ merge_strategy="learned",
+ **kwargs,
+ ):
+ super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
+ if video_kernel_size is None:
+ video_kernel_size = [3, 1, 1]
+ self.time_stack = ResBlock(
+ channels=out_channels,
+ emb_channels=0,
+ dropout=dropout,
+ dims=3,
+ use_scale_shift_norm=False,
+ use_conv=False,
+ up=False,
+ down=False,
+ kernel_size=video_kernel_size,
+ use_checkpoint=False,
+ skip_t_emb=True,
+ )
+
+ self.merge_strategy = merge_strategy
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned":
+ self.register_parameter(
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+ )
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def get_alpha(self, bs):
+ if self.merge_strategy == "fixed":
+ return self.mix_factor
+ elif self.merge_strategy == "learned":
+ return torch.sigmoid(self.mix_factor)
+ else:
+ raise NotImplementedError()
+
+ def forward(self, x, temb, skip_video=False, timesteps=None):
+ if timesteps is None:
+ timesteps = self.timesteps
+
+ b, c, h, w = x.shape
+
+ x = super().forward(x, temb)
+
+ if not skip_video:
+ x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+
+ x = self.time_stack(x, temb)
+
+ alpha = self.get_alpha(bs=b // timesteps)
+ x = alpha * x + (1.0 - alpha) * x_mix
+
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ return x
+
+
+class AE3DConv(torch.nn.Conv2d):
+ def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ if isinstance(video_kernel_size, Iterable):
+ padding = [int(k // 2) for k in video_kernel_size]
+ else:
+ padding = int(video_kernel_size // 2)
+
+ self.time_mix_conv = torch.nn.Conv3d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=video_kernel_size,
+ padding=padding,
+ )
+
+ def forward(self, input, timesteps, skip_video=False):
+ x = super().forward(input)
+ if skip_video:
+ return x
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+ x = self.time_mix_conv(x)
+ return rearrange(x, "b c t h w -> (b t) c h w")
+
+
+class VideoBlock(AttnBlock):
+ def __init__(
+ self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
+ ):
+ super().__init__(in_channels)
+ # no context, single headed, as in base class
+ self.time_mix_block = VideoTransformerBlock(
+ dim=in_channels,
+ n_heads=1,
+ d_head=in_channels,
+ checkpoint=False,
+ ff_in=True,
+ attn_mode="softmax",
+ )
+
+ time_embed_dim = self.in_channels * 4
+ self.video_time_embed = torch.nn.Sequential(
+ torch.nn.Linear(self.in_channels, time_embed_dim),
+ torch.nn.SiLU(),
+ torch.nn.Linear(time_embed_dim, self.in_channels),
+ )
+
+ self.merge_strategy = merge_strategy
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned":
+ self.register_parameter(
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+ )
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def forward(self, x, timesteps, skip_video=False):
+ if skip_video:
+ return super().forward(x)
+
+ x_in = x
+ x = self.attention(x)
+ h, w = x.shape[2:]
+ x = rearrange(x, "b c h w -> b (h w) c")
+
+ x_mix = x
+ num_frames = torch.arange(timesteps, device=x.device)
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+ num_frames = rearrange(num_frames, "b t -> (b t)")
+ t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
+ emb = self.video_time_embed(t_emb) # b, n_channels
+ emb = emb[:, None, :]
+ x_mix = x_mix + emb
+
+ alpha = self.get_alpha()
+ x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
+ x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
+
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ x = self.proj_out(x)
+
+ return x_in + x
+
+ def get_alpha(
+ self,
+ ):
+ if self.merge_strategy == "fixed":
+ return self.mix_factor
+ elif self.merge_strategy == "learned":
+ return torch.sigmoid(self.mix_factor)
+ else:
+ raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
+
+
+class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
+ def __init__(
+ self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
+ ):
+ super().__init__(in_channels)
+ # no context, single headed, as in base class
+ self.time_mix_block = VideoTransformerBlock(
+ dim=in_channels,
+ n_heads=1,
+ d_head=in_channels,
+ checkpoint=False,
+ ff_in=True,
+ attn_mode="softmax-xformers",
+ )
+
+ time_embed_dim = self.in_channels * 4
+ self.video_time_embed = torch.nn.Sequential(
+ torch.nn.Linear(self.in_channels, time_embed_dim),
+ torch.nn.SiLU(),
+ torch.nn.Linear(time_embed_dim, self.in_channels),
+ )
+
+ self.merge_strategy = merge_strategy
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned":
+ self.register_parameter(
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+ )
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def forward(self, x, timesteps, skip_time_block=False):
+ if skip_time_block:
+ return super().forward(x)
+
+ x_in = x
+ x = self.attention(x)
+ h, w = x.shape[2:]
+ x = rearrange(x, "b c h w -> b (h w) c")
+
+ x_mix = x
+ num_frames = torch.arange(timesteps, device=x.device)
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+ num_frames = rearrange(num_frames, "b t -> (b t)")
+ t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
+ emb = self.video_time_embed(t_emb) # b, n_channels
+ emb = emb[:, None, :]
+ x_mix = x_mix + emb
+
+ alpha = self.get_alpha()
+ x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
+ x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
+
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ x = self.proj_out(x)
+
+ return x_in + x
+
+ def get_alpha(
+ self,
+ ):
+ if self.merge_strategy == "fixed":
+ return self.mix_factor
+ elif self.merge_strategy == "learned":
+ return torch.sigmoid(self.mix_factor)
+ else:
+ raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
+
+
+def make_time_attn(
+ in_channels,
+ attn_type="vanilla",
+ attn_kwargs=None,
+ alpha: float = 0,
+ merge_strategy: str = "learned",
+):
+ assert attn_type in [
+ "vanilla",
+ "vanilla-xformers",
+ ], f"attn_type {attn_type} not supported for spatio-temporal attention"
+ print(
+ f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
+ )
+ if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
+ print(
+ f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
+ f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
+ )
+ attn_type = "vanilla"
+
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return partialclass(
+ VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
+ )
+ elif attn_type == "vanilla-xformers":
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+ return partialclass(
+ MemoryEfficientVideoBlock,
+ in_channels,
+ alpha=alpha,
+ merge_strategy=merge_strategy,
+ )
+ else:
+ return NotImplementedError()
+
+
+class Conv2DWrapper(torch.nn.Conv2d):
+ def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
+ return super().forward(input)
+
+
+class VideoDecoder(Decoder):
+ available_time_modes = ["all", "conv-only", "attn-only"]
+
+ def __init__(
+ self,
+ *args,
+ video_kernel_size: Union[int, list] = 3,
+ alpha: float = 0.0,
+ merge_strategy: str = "learned",
+ time_mode: str = "conv-only",
+ **kwargs,
+ ):
+ self.video_kernel_size = video_kernel_size
+ self.alpha = alpha
+ self.merge_strategy = merge_strategy
+ self.time_mode = time_mode
+ assert (
+ self.time_mode in self.available_time_modes
+ ), f"time_mode parameter has to be in {self.available_time_modes}"
+ super().__init__(*args, **kwargs)
+
+ def get_last_layer(self, skip_time_mix=False, **kwargs):
+ if self.time_mode == "attn-only":
+ raise NotImplementedError("TODO")
+ else:
+ return (
+ self.conv_out.time_mix_conv.weight
+ if not skip_time_mix
+ else self.conv_out.weight
+ )
+
+ def _make_attn(self) -> Callable:
+ if self.time_mode not in ["conv-only", "only-last-conv"]:
+ return partialclass(
+ make_time_attn,
+ alpha=self.alpha,
+ merge_strategy=self.merge_strategy,
+ )
+ else:
+ return super()._make_attn()
+
+ def _make_conv(self) -> Callable:
+ if self.time_mode != "attn-only":
+ return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
+ else:
+ return Conv2DWrapper
+
+ def _make_resblock(self) -> Callable:
+ if self.time_mode not in ["attn-only", "only-last-conv"]:
+ return partialclass(
+ VideoResBlock,
+ video_kernel_size=self.video_kernel_size,
+ alpha=self.alpha,
+ merge_strategy=self.merge_strategy,
+ )
+ else:
+ return super()._make_resblock()
diff --git a/sgm/modules/diffusionmodules/__init__.py b/sgm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/diffusionmodules/__pycache__/__init__.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f9c6faa24461bafc99c82814359f6037d745c5b9
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/augment_pipeline.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/augment_pipeline.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24b9901b6d303ad3876ff337d897fba8f2b99540
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/augment_pipeline.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fddc868350d10c35a77548077a5983223a2db256
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/denoiser.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..09d744622ab42c2c90876f1aafe4c0260251fcbb
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/denoiser_scaling.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3007add51e1d91f254dba14f9d22f90ebcd1d75
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/discretizer.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/guiders.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/guiders.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..74518801704b470ce78e3392f840c0c526076693
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/guiders.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/loss.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/loss.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ea8c8515c369df1cb0616f94af41f9d66453669
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/loss.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/loss_weighting.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/loss_weighting.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee0acaa61312b723b6ba10504850b85394a0b606
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/loss_weighting.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/model.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9116534e4b89614c7e66da5ac358c298f52e4e6
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/model.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c0400d3ee28264a2e3753c320cb9f1e2c4ac553
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/openaimodel.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4eca6ea5c3d805263a705e1204b743bab725608
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/sampling.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0b69b8fe3f5d75412ee065e3e60bd99afd1da570
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/sampling_utils.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/sigma_sampling.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/sigma_sampling.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7210f4f2648236de92342ef3de181a157e250f2b
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/sigma_sampling.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/util.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/util.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a9fc613e5ad487235ba61eab428a61ad8c6ca742
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/util.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/video_model.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/video_model.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..068541759823b5e4324f1f64207eb63305d8231f
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/video_model.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-311.pyc b/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1235f57204966cfc97e7cea7b635d7752c049692
Binary files /dev/null and b/sgm/modules/diffusionmodules/__pycache__/wrappers.cpython-311.pyc differ
diff --git a/sgm/modules/diffusionmodules/augment_pipeline.py b/sgm/modules/diffusionmodules/augment_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ca1c8de9e7e8332e876a69cc77023990049433a
--- /dev/null
+++ b/sgm/modules/diffusionmodules/augment_pipeline.py
@@ -0,0 +1,595 @@
+"""Augmentation pipeline used in the paper
+"Elucidating the Design Space of Diffusion-Based Generative Models".
+Built around the same concepts that were originally proposed in the paper
+"Training Generative Adversarial Networks with Limited Data"."""
+
+import numpy as np
+import torch
+from einops import rearrange, repeat
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if callable(d) else d
+
+
+_constant_cache = dict()
+
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+ value = np.asarray(value)
+ if shape is not None:
+ shape = tuple(shape)
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if device is None:
+ device = torch.device("cpu")
+ if memory_format is None:
+ memory_format = torch.contiguous_format
+
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+ tensor = _constant_cache.get(key, None)
+ if tensor is None:
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+ if shape is not None:
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+ tensor = tensor.contiguous(memory_format=memory_format)
+ _constant_cache[key] = tensor
+ return tensor
+
+
+# ----------------------------------------------------------------------------
+# Coefficients of various wavelet decomposition low-pass filters.
+
+wavelets = {
+ "haar": [0.7071067811865476, 0.7071067811865476],
+ "db1": [0.7071067811865476, 0.7071067811865476],
+ "db2": [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+ "db3": [
+ 0.035226291882100656,
+ -0.08544127388224149,
+ -0.13501102001039084,
+ 0.4598775021193313,
+ 0.8068915093133388,
+ 0.3326705529509569,
+ ],
+ "db4": [
+ -0.010597401784997278,
+ 0.032883011666982945,
+ 0.030841381835986965,
+ -0.18703481171888114,
+ -0.02798376941698385,
+ 0.6308807679295904,
+ 0.7148465705525415,
+ 0.23037781330885523,
+ ],
+ "db5": [
+ 0.003335725285001549,
+ -0.012580751999015526,
+ -0.006241490213011705,
+ 0.07757149384006515,
+ -0.03224486958502952,
+ -0.24229488706619015,
+ 0.13842814590110342,
+ 0.7243085284385744,
+ 0.6038292697974729,
+ 0.160102397974125,
+ ],
+ "db6": [
+ -0.00107730108499558,
+ 0.004777257511010651,
+ 0.0005538422009938016,
+ -0.031582039318031156,
+ 0.02752286553001629,
+ 0.09750160558707936,
+ -0.12976686756709563,
+ -0.22626469396516913,
+ 0.3152503517092432,
+ 0.7511339080215775,
+ 0.4946238903983854,
+ 0.11154074335008017,
+ ],
+ "db7": [
+ 0.0003537138000010399,
+ -0.0018016407039998328,
+ 0.00042957797300470274,
+ 0.012550998556013784,
+ -0.01657454163101562,
+ -0.03802993693503463,
+ 0.0806126091510659,
+ 0.07130921926705004,
+ -0.22403618499416572,
+ -0.14390600392910627,
+ 0.4697822874053586,
+ 0.7291320908465551,
+ 0.39653931948230575,
+ 0.07785205408506236,
+ ],
+ "db8": [
+ -0.00011747678400228192,
+ 0.0006754494059985568,
+ -0.0003917403729959771,
+ -0.00487035299301066,
+ 0.008746094047015655,
+ 0.013981027917015516,
+ -0.04408825393106472,
+ -0.01736930100202211,
+ 0.128747426620186,
+ 0.00047248457399797254,
+ -0.2840155429624281,
+ -0.015829105256023893,
+ 0.5853546836548691,
+ 0.6756307362980128,
+ 0.3128715909144659,
+ 0.05441584224308161,
+ ],
+ "sym2": [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+ "sym3": [
+ 0.035226291882100656,
+ -0.08544127388224149,
+ -0.13501102001039084,
+ 0.4598775021193313,
+ 0.8068915093133388,
+ 0.3326705529509569,
+ ],
+ "sym4": [
+ -0.07576571478927333,
+ -0.02963552764599851,
+ 0.49761866763201545,
+ 0.8037387518059161,
+ 0.29785779560527736,
+ -0.09921954357684722,
+ -0.012603967262037833,
+ 0.0322231006040427,
+ ],
+ "sym5": [
+ 0.027333068345077982,
+ 0.029519490925774643,
+ -0.039134249302383094,
+ 0.1993975339773936,
+ 0.7234076904024206,
+ 0.6339789634582119,
+ 0.01660210576452232,
+ -0.17532808990845047,
+ -0.021101834024758855,
+ 0.019538882735286728,
+ ],
+ "sym6": [
+ 0.015404109327027373,
+ 0.0034907120842174702,
+ -0.11799011114819057,
+ -0.048311742585633,
+ 0.4910559419267466,
+ 0.787641141030194,
+ 0.3379294217276218,
+ -0.07263752278646252,
+ -0.021060292512300564,
+ 0.04472490177066578,
+ 0.0017677118642428036,
+ -0.007800708325034148,
+ ],
+ "sym7": [
+ 0.002681814568257878,
+ -0.0010473848886829163,
+ -0.01263630340325193,
+ 0.03051551316596357,
+ 0.0678926935013727,
+ -0.049552834937127255,
+ 0.017441255086855827,
+ 0.5361019170917628,
+ 0.767764317003164,
+ 0.2886296317515146,
+ -0.14004724044296152,
+ -0.10780823770381774,
+ 0.004010244871533663,
+ 0.010268176708511255,
+ ],
+ "sym8": [
+ -0.0033824159510061256,
+ -0.0005421323317911481,
+ 0.03169508781149298,
+ 0.007607487324917605,
+ -0.1432942383508097,
+ -0.061273359067658524,
+ 0.4813596512583722,
+ 0.7771857517005235,
+ 0.3644418948353314,
+ -0.05194583810770904,
+ -0.027219029917056003,
+ 0.049137179673607506,
+ 0.003808752013890615,
+ -0.01495225833704823,
+ -0.0003029205147213668,
+ 0.0018899503327594609,
+ ],
+}
+
+# ----------------------------------------------------------------------------
+# Helpers for constructing transformation matrices.
+
+
+def matrix(*rows, device=None):
+ assert all(len(row) == len(rows[0]) for row in rows)
+ elems = [x for row in rows for x in row]
+ ref = [x for x in elems if isinstance(x, torch.Tensor)]
+ if len(ref) == 0:
+ return constant(np.asarray(rows), device=device)
+ assert device is None or device == ref[0].device
+ elems = [
+ x if isinstance(x, torch.Tensor) else constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems
+ ]
+ return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
+
+
+def translate2d(tx, ty, **kwargs):
+ return matrix([1, 0, tx], [0, 1, ty], [0, 0, 1], **kwargs)
+
+
+def translate3d(tx, ty, tz, **kwargs):
+ return matrix([1, 0, 0, tx], [0, 1, 0, ty], [0, 0, 1, tz], [0, 0, 0, 1], **kwargs)
+
+
+def scale2d(sx, sy, **kwargs):
+ return matrix([sx, 0, 0], [0, sy, 0], [0, 0, 1], **kwargs)
+
+
+def scale3d(sx, sy, sz, **kwargs):
+ return matrix([sx, 0, 0, 0], [0, sy, 0, 0], [0, 0, sz, 0], [0, 0, 0, 1], **kwargs)
+
+
+def rotate2d(theta, **kwargs):
+ return matrix(
+ [torch.cos(theta), torch.sin(-theta), 0], [torch.sin(theta), torch.cos(theta), 0], [0, 0, 1], **kwargs
+ )
+
+
+def rotate3d(v, theta, **kwargs):
+ vx = v[..., 0]
+ vy = v[..., 1]
+ vz = v[..., 2]
+ s = torch.sin(theta)
+ c = torch.cos(theta)
+ cc = 1 - c
+ return matrix(
+ [vx * vx * cc + c, vx * vy * cc - vz * s, vx * vz * cc + vy * s, 0],
+ [vy * vx * cc + vz * s, vy * vy * cc + c, vy * vz * cc - vx * s, 0],
+ [vz * vx * cc - vy * s, vz * vy * cc + vx * s, vz * vz * cc + c, 0],
+ [0, 0, 0, 1],
+ **kwargs,
+ )
+
+
+def translate2d_inv(tx, ty, **kwargs):
+ return translate2d(-tx, -ty, **kwargs)
+
+
+def scale2d_inv(sx, sy, **kwargs):
+ return scale2d(1 / sx, 1 / sy, **kwargs)
+
+
+def rotate2d_inv(theta, **kwargs):
+ return rotate2d(-theta, **kwargs)
+
+
+# ----------------------------------------------------------------------------
+# Augmentation pipeline main class.
+# All augmentations are disabled by default; individual augmentations can
+# be enabled by setting their probability multipliers to 1.
+
+
+class AugmentPipe:
+ def __init__(
+ self,
+ p=1,
+ xflip=0,
+ yflip=0,
+ rotate_int=0,
+ translate_int=0,
+ translate_int_max=0.125,
+ scale=0,
+ rotate_frac=0,
+ aniso=0,
+ translate_frac=0,
+ scale_std=0.2,
+ rotate_frac_max=1,
+ aniso_std=0.2,
+ aniso_rotate_prob=0.5,
+ translate_frac_std=0.125,
+ brightness=0,
+ contrast=0,
+ lumaflip=0,
+ hue=0,
+ saturation=0,
+ brightness_std=0.2,
+ contrast_std=0.5,
+ hue_max=1,
+ saturation_std=1,
+ ):
+ super().__init__()
+ self.p = float(p) # Overall multiplier for augmentation probability.
+
+ # Pixel blitting.
+ self.xflip = float(xflip) # Probability multiplier for x-flip.
+ self.yflip = float(yflip) # Probability multiplier for y-flip.
+ self.rotate_int = float(rotate_int) # Probability multiplier for integer rotation.
+ self.translate_int = float(translate_int) # Probability multiplier for integer translation.
+ self.translate_int_max = float(
+ translate_int_max
+ ) # Range of integer translation, relative to image dimensions.
+
+ # Geometric transformations.
+ self.scale = float(scale) # Probability multiplier for isotropic scaling.
+ self.rotate_frac = float(rotate_frac) # Probability multiplier for fractional rotation.
+ self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
+ self.translate_frac = float(translate_frac) # Probability multiplier for fractional translation.
+ self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
+ self.rotate_frac_max = float(rotate_frac_max) # Range of fractional rotation, 1 = full circle.
+ self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
+ self.aniso_rotate_prob = float(
+ aniso_rotate_prob
+ ) # Probability of doing anisotropic scaling w.r.t. rotated coordinate frame.
+ self.translate_frac_std = float(
+ translate_frac_std
+ ) # Standard deviation of frational translation, relative to image dimensions.
+
+ # Color transformations.
+ self.brightness = float(brightness) # Probability multiplier for brightness.
+ self.contrast = float(contrast) # Probability multiplier for contrast.
+ self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
+ self.hue = float(hue) # Probability multiplier for hue rotation.
+ self.saturation = float(saturation) # Probability multiplier for saturation.
+ self.brightness_std = float(brightness_std) # Standard deviation of brightness.
+ self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
+ self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
+ self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
+
+ def __call__(self, images):
+ F = None
+ repeat_frames = False
+ if len(images.shape) == 5:
+ N, C, F, H, W = images.shape
+ images = rearrange(images, "n c f h w -> (n f) c h w")
+ repeat_frames = True
+ elif len(images.shape) == 4:
+ N, C, H, W = images.shape
+ device = images.device
+ labels = [torch.zeros([images.shape[0], 0], device=device)]
+
+ # ---------------
+ # Pixel blitting.
+ # ---------------
+
+ if self.xflip > 0:
+ w = torch.randint(2, [N, 1, 1, 1], device=device)
+ w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.xflip * self.p, w, torch.zeros_like(w))
+ if repeat_frames:
+ w = repeat(w, "n c h w -> (n f) c h w", f=F)
+ images = torch.where(w == 1, images.flip(3), images)
+ labels += [w]
+
+ if self.yflip > 0:
+ w = torch.randint(2, [N, 1, 1, 1], device=device)
+ w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.yflip * self.p, w, torch.zeros_like(w))
+ if repeat_frames:
+ w = repeat(w, "n c h w -> (n f) c h w", f=F)
+ images = torch.where(w == 1, images.flip(2), images)
+ labels += [w]
+
+ if self.rotate_int > 0:
+ w = torch.randint(4, [N, 1, 1, 1], device=device)
+ w = torch.where(torch.rand([N, 1, 1, 1], device=device) < self.rotate_int * self.p, w, torch.zeros_like(w))
+ if repeat_frames:
+ w = repeat(w, "n c h w -> (n f) c h w", f=F)
+ images = torch.where((w == 1) | (w == 2), images.flip(3), images)
+ images = torch.where((w == 2) | (w == 3), images.flip(2), images)
+ images = torch.where((w == 1) | (w == 3), images.transpose(2, 3), images)
+ labels += [(w == 1) | (w == 2), (w == 2) | (w == 3)]
+
+ if self.translate_int > 0:
+ w = torch.rand([2, N, 1, 1, 1], device=device) * 2 - 1
+ w = torch.where(
+ torch.rand([1, N, 1, 1, 1], device=device) < self.translate_int * self.p, w, torch.zeros_like(w)
+ )
+ if repeat_frames:
+ w = repeat(w, "* n c h w -> * (n f) c h w", f=F)
+ tx = w[0].mul(W * self.translate_int_max).round().to(torch.int64)
+ ty = w[1].mul(H * self.translate_int_max).round().to(torch.int64)
+ b, c, y, x = torch.meshgrid(*(torch.arange(x, device=device) for x in images.shape), indexing="ij")
+ x = W - 1 - (W - 1 - (x - tx) % (W * 2 - 2)).abs()
+ y = H - 1 - (H - 1 - (y + ty) % (H * 2 - 2)).abs()
+ images = images.flatten()[(((b * C) + c) * H + y) * W + x]
+ labels += [tx.div(W * self.translate_int_max), ty.div(H * self.translate_int_max)]
+
+ # ------------------------------------------------
+ # Select parameters for geometric transformations.
+ # ------------------------------------------------
+
+ I_3 = torch.eye(3, device=device)
+ G_inv = I_3
+
+ if self.scale > 0:
+ w = torch.randn([N], device=device)
+ w = torch.where(torch.rand([N], device=device) < self.scale * self.p, w, torch.zeros_like(w))
+ if repeat_frames:
+ w = repeat(w, "n -> (n f)", f=F)
+ s = w.mul(self.scale_std).exp2()
+ G_inv = G_inv @ scale2d_inv(s, s)
+ labels += [w]
+
+ if self.rotate_frac > 0:
+ w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.rotate_frac_max)
+ w = torch.where(torch.rand([N], device=device) < self.rotate_frac * self.p, w, torch.zeros_like(w))
+ if repeat_frames:
+ w = repeat(w, "n -> (n f)", f=F)
+ G_inv = G_inv @ rotate2d_inv(-w)
+ labels += [w.cos() - 1, w.sin()]
+
+ if self.aniso > 0:
+ w = torch.randn([N], device=device)
+ r = (torch.rand([N], device=device) * 2 - 1) * np.pi
+ w = torch.where(torch.rand([N], device=device) < self.aniso * self.p, w, torch.zeros_like(w))
+ r = torch.where(torch.rand([N], device=device) < self.aniso_rotate_prob, r, torch.zeros_like(r))
+ if repeat_frames:
+ w = repeat(w, "n -> (n f)", f=F)
+ r = repeat(r, "n -> (n f)", f=F)
+ s = w.mul(self.aniso_std).exp2()
+ G_inv = G_inv @ rotate2d_inv(r) @ scale2d_inv(s, 1 / s) @ rotate2d_inv(-r)
+ labels += [w * r.cos(), w * r.sin()]
+
+ if self.translate_frac > 0:
+ w = torch.randn([2, N], device=device)
+ w = torch.where(torch.rand([1, N], device=device) < self.translate_frac * self.p, w, torch.zeros_like(w))
+ if repeat_frames:
+ w = repeat(w, "c n -> c (n f)", f=F)
+ G_inv = G_inv @ translate2d_inv(
+ w[0].mul(W * self.translate_frac_std), w[1].mul(H * self.translate_frac_std)
+ )
+ labels += [w[0], w[1]]
+
+ # ----------------------------------
+ # Execute geometric transformations.
+ # ----------------------------------
+
+ if G_inv is not I_3:
+ cx = (W - 1) / 2
+ cy = (H - 1) / 2
+ cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
+ cp = G_inv @ cp.t() # [batch, xyz, idx]
+ Hz = np.asarray(wavelets["sym6"], dtype=np.float32)
+ Hz_pad = len(Hz) // 4
+ margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
+ margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
+ margin = margin + constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
+ margin = margin.max(constant([0, 0] * 2, device=device))
+ margin = margin.min(constant([W - 1, H - 1] * 2, device=device))
+ mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
+
+ # Pad image and adjust origin.
+ images = torch.nn.functional.pad(input=images, pad=[mx0, mx1, my0, my1], mode="reflect")
+ G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
+
+ # Upsample.
+ conv_weight = constant(Hz[None, None, ::-1], dtype=images.dtype, device=images.device).tile(
+ [images.shape[1], 1, 1]
+ )
+ conv_pad = (len(Hz) + 1) // 2
+ images = torch.stack([images, torch.zeros_like(images)], dim=4).reshape(
+ N * default(F, 1), C, images.shape[2], -1
+ )[:, :, :, :-1]
+ images = torch.nn.functional.conv2d(
+ images, conv_weight.unsqueeze(2), groups=images.shape[1], padding=[0, conv_pad]
+ )
+ images = torch.stack([images, torch.zeros_like(images)], dim=3).reshape(
+ N * default(F, 1), C, -1, images.shape[3]
+ )[:, :, :-1, :]
+ images = torch.nn.functional.conv2d(
+ images, conv_weight.unsqueeze(3), groups=images.shape[1], padding=[conv_pad, 0]
+ )
+ G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
+ G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
+
+ # Execute transformation.
+ shape = [N * default(F, 1), C, (H + Hz_pad * 2) * 2, (W + Hz_pad * 2) * 2]
+ G_inv = (
+ scale2d(2 / images.shape[3], 2 / images.shape[2], device=device)
+ @ G_inv
+ @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
+ )
+ grid = torch.nn.functional.affine_grid(theta=G_inv[:, :2, :], size=shape, align_corners=False)
+ images = torch.nn.functional.grid_sample(
+ images, grid, mode="bilinear", padding_mode="zeros", align_corners=False
+ )
+
+ # Downsample and crop.
+ conv_weight = constant(Hz[None, None, :], dtype=images.dtype, device=images.device).tile(
+ [images.shape[1], 1, 1]
+ )
+ conv_pad = (len(Hz) - 1) // 2
+ images = torch.nn.functional.conv2d(
+ images, conv_weight.unsqueeze(2), groups=images.shape[1], stride=[1, 2], padding=[0, conv_pad]
+ )[:, :, :, Hz_pad:-Hz_pad]
+ images = torch.nn.functional.conv2d(
+ images, conv_weight.unsqueeze(3), groups=images.shape[1], stride=[2, 1], padding=[conv_pad, 0]
+ )[:, :, Hz_pad:-Hz_pad, :]
+
+ # --------------------------------------------
+ # Select parameters for color transformations.
+ # --------------------------------------------
+
+ I_4 = torch.eye(4, device=device)
+ M = I_4
+ luma_axis = constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device)
+
+ if self.brightness > 0:
+ w = torch.randn([N], device=device)
+ w = torch.where(torch.rand([N], device=device) < self.brightness * self.p, w, torch.zeros_like(w))
+ if repeat_frames:
+ w = repeat(w, "n -> (n f)", f=F)
+ b = w * self.brightness_std
+ M = translate3d(b, b, b) @ M
+ labels += [w]
+
+ if self.contrast > 0:
+ w = torch.randn([N], device=device)
+ w = torch.where(torch.rand([N], device=device) < self.contrast * self.p, w, torch.zeros_like(w))
+ if repeat_frames:
+ w = repeat(w, "n -> (n f)", f=F)
+ c = w.mul(self.contrast_std).exp2()
+ M = scale3d(c, c, c) @ M
+ labels += [w]
+
+ if self.lumaflip > 0:
+ w = torch.randint(2, [N, 1, 1], device=device)
+ w = torch.where(torch.rand([N, 1, 1], device=device) < self.lumaflip * self.p, w, torch.zeros_like(w))
+ if repeat_frames:
+ w = repeat(w, "n 1 1-> (n f) 1 1", f=F)
+ M = (I_4 - 2 * luma_axis.ger(luma_axis) * w) @ M
+ labels += [w]
+
+ if self.hue > 0:
+ w = (torch.rand([N], device=device) * 2 - 1) * (np.pi * self.hue_max)
+ w = torch.where(torch.rand([N], device=device) < self.hue * self.p, w, torch.zeros_like(w))
+ if repeat_frames:
+ w = repeat(w, "n -> (n f)", f=F)
+ M = rotate3d(luma_axis, w) @ M
+ labels += [w.cos() - 1, w.sin()]
+
+ if self.saturation > 0:
+ w = torch.randn([N, 1, 1], device=device)
+ w = torch.where(torch.rand([N, 1, 1], device=device) < self.saturation * self.p, w, torch.zeros_like(w))
+ if repeat_frames:
+ w = repeat(w, "n 1 1-> (n f) 1 1", f=F)
+ M = (luma_axis.ger(luma_axis) + (I_4 - luma_axis.ger(luma_axis)) * w.mul(self.saturation_std).exp2()) @ M
+ labels += [w]
+
+ # ------------------------------
+ # Execute color transformations.
+ # ------------------------------
+
+ if M is not I_4:
+ images = images.reshape([N * default(F, 1), C, H * W])
+ if C == 3:
+ images = M[:, :3, :3] @ images + M[:, :3, 3:]
+ elif C == 1:
+ M = M[:, :3, :].mean(dim=1, keepdims=True)
+ images = images * M[:, :, :3].sum(dim=2, keepdims=True) + M[:, :, 3:]
+ else:
+ raise ValueError("Image must be RGB (3 channels) or L (1 channel)")
+ images = images.reshape([N * default(F, 1), C, H, W])
+
+ labels = torch.cat([x.to(torch.float32).reshape(N, -1) for x in labels], dim=1)
+ if F is not None:
+ images = rearrange(images, "(n f) c h w -> n c f h w", f=F)
+ labels = rearrange(labels, "n (f l) -> n f l", f=F)[:, 0]
+ # assert labels[:, 0].eq(labels[:, 1]).all() # check that all frames have the same labels
+ # labels = labels[:, 0] # its the same for all frames, so we can just take the first one
+
+ return images, labels
+
+
+# ----------------------------------------------------------------------------
diff --git a/sgm/modules/diffusionmodules/denoiser.py b/sgm/modules/diffusionmodules/denoiser.py
new file mode 100644
index 0000000000000000000000000000000000000000..02718847dd52faba4641e6f6436c6af6007fb17f
--- /dev/null
+++ b/sgm/modules/diffusionmodules/denoiser.py
@@ -0,0 +1,391 @@
+from typing import Dict
+
+import torch
+import torch.nn as nn
+from einops import repeat, rearrange
+from ...util import append_dims, instantiate_from_config
+from .denoiser_scaling import DenoiserScaling
+
+
+class DenoiserDub(nn.Module):
+ def __init__(self, scaling_config: Dict, mask_input: bool = True):
+ super().__init__()
+
+ self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
+ self.mask_input = mask_input
+
+ def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
+ return sigma
+
+ def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
+ return c_noise
+
+ def forward(
+ self,
+ network: nn.Module,
+ input: torch.Tensor,
+ sigma: torch.Tensor,
+ cond: Dict,
+ num_overlap_frames: int = 1,
+ num_frames: int = 14,
+ n_skips: int = 1,
+ chunk_size: int = None,
+ **additional_model_inputs,
+ ) -> torch.Tensor:
+ sigma = self.possibly_quantize_sigma(sigma)
+ if input.ndim == 5:
+ T = input.shape[2]
+ input = rearrange(input, "b c t h w -> (b t) c h w")
+ if sigma.shape[0] != input.shape[0]:
+ sigma = repeat(sigma, "b ... -> b t ...", t=T)
+ sigma = rearrange(sigma, "b t ... -> (b t) ...")
+ sigma_shape = sigma.shape
+ sigma = append_dims(sigma, input.ndim)
+ c_skip, c_out, c_in, c_noise = self.scaling(sigma)
+ c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
+ gt = cond.get("gt", torch.Tensor([]).type_as(input))
+ if gt.dim() == 5:
+ gt = rearrange(gt, "b c t h w -> (b t) c h w")
+ masks = cond.get("masks", None)
+ if masks.dim() == 5:
+ masks = rearrange(masks, "b c t h w -> (b t) c h w")
+
+ if self.mask_input:
+ input = input * masks + gt * (1.0 - masks)
+
+ if chunk_size is not None:
+ assert chunk_size % num_frames == 0, (
+ "Chunk size should be multiple of num_frames"
+ )
+ out = chunk_network(
+ network,
+ input,
+ c_in,
+ c_noise,
+ cond,
+ additional_model_inputs,
+ chunk_size,
+ num_frames=num_frames,
+ )
+ else:
+ out = network(input * c_in, c_noise, cond, **additional_model_inputs)
+ out = out * c_out + input * c_skip
+ out = out * masks + gt * (1.0 - masks)
+ return out
+
+
+class DenoiserTemporalMultiDiffusion(nn.Module):
+ def __init__(self, scaling_config: Dict, is_dub: bool = False):
+ super().__init__()
+
+ self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
+ self.is_dub = is_dub
+
+ def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
+ return sigma
+
+ def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
+ return c_noise
+
+ def forward(
+ self,
+ network: nn.Module,
+ input: torch.Tensor,
+ sigma: torch.Tensor,
+ cond: Dict,
+ num_overlap_frames: int,
+ num_frames: int,
+ n_skips: int,
+ chunk_size: int = None,
+ **additional_model_inputs,
+ ) -> torch.Tensor:
+ """
+ Args:
+ network: Denoising network
+ input: Noisy input
+ sigma: Noise level
+ cond: Dictionary containing additional information
+ num_overlap_frames: Number of overlapping frames
+ additional_model_inputs: Additional inputs for the denoising network
+ Returns:
+ out: Denoised output
+ This function assumes the input is of shape (B, C, T, H, W) with the B dimension being the number of segments in video.
+ The num_overlap_frames is the number of overlapping frames between the segments to be able to handle the temporal overlap.
+ """
+ sigma = self.possibly_quantize_sigma(sigma)
+ T = num_frames
+ if input.ndim == 5:
+ T = input.shape[2]
+ input = rearrange(input, "b c t h w -> (b t) c h w")
+ if sigma.shape[0] != input.shape[0]:
+ sigma = repeat(sigma, "b ... -> b t ...", t=T)
+ sigma = rearrange(sigma, "b t ... -> (b t) ...")
+ n_skips = n_skips * input.shape[0] // T
+ sigma_shape = sigma.shape
+ sigma = append_dims(sigma, input.ndim)
+ c_skip, c_out, c_in, c_noise = self.scaling(sigma)
+ c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
+ if self.is_dub:
+ gt = cond.get("gt", torch.Tensor([]).type_as(input))
+ if gt.dim() == 5:
+ gt = rearrange(gt, "b c t h w -> (b t) c h w")
+ masks = cond.get("masks", None)
+ if masks.dim() == 5:
+ masks = rearrange(masks, "b c t h w -> (b t) c h w")
+ input = input * masks + gt * (1.0 - masks)
+
+ # Now we want to find the overlapping frames and average them
+ input = rearrange(input, "(b t) c h w -> b c t h w", t=T)
+ # Overlapping frames are at begining and end of each segment and given by num_overlap_frames
+ for i in range(input.shape[0] - n_skips):
+ average_frame = torch.stack(
+ [
+ input[i, :, -num_overlap_frames:],
+ input[i + 1, :, :num_overlap_frames],
+ ]
+ ).mean(0)
+ input[i, :, -num_overlap_frames:] = average_frame
+ input[i + n_skips, :, :num_overlap_frames] = average_frame
+
+ input = rearrange(input, "b c t h w -> (b t) c h w")
+
+ if chunk_size is not None:
+ assert chunk_size % num_frames == 0, (
+ "Chunk size should be multiple of num_frames"
+ )
+ out = chunk_network(
+ network,
+ input,
+ c_in,
+ c_noise,
+ cond,
+ additional_model_inputs,
+ chunk_size,
+ num_frames=num_frames,
+ )
+ else:
+ out = network(input * c_in, c_noise, cond, **additional_model_inputs)
+
+ out = out * c_out + input * c_skip
+
+ if self.is_dub:
+ out = out * masks + gt * (1.0 - masks)
+ return out
+
+
+def chunk_network(
+ network,
+ input,
+ c_in,
+ c_noise,
+ cond,
+ additional_model_inputs,
+ chunk_size,
+ num_frames=1,
+):
+ out = []
+
+ for i in range(0, input.shape[0], chunk_size):
+ start_idx = i
+ end_idx = i + chunk_size
+
+ input_chunk = input[start_idx:end_idx]
+ c_in_chunk = (
+ c_in[start_idx:end_idx]
+ if c_in.shape[0] == input.shape[0]
+ else c_in[start_idx // num_frames : end_idx // num_frames]
+ )
+ c_noise_chunk = (
+ c_noise[start_idx:end_idx]
+ if c_noise.shape[0] == input.shape[0]
+ else c_noise[start_idx // num_frames : end_idx // num_frames]
+ )
+
+ cond_chunk = {}
+ for k, v in cond.items():
+ if isinstance(v, torch.Tensor) and v.shape[0] == input.shape[0]:
+ cond_chunk[k] = v[start_idx:end_idx]
+ elif isinstance(v, torch.Tensor):
+ cond_chunk[k] = v[start_idx // num_frames : end_idx // num_frames]
+ else:
+ cond_chunk[k] = v
+
+ additional_model_inputs_chunk = {}
+ for k, v in additional_model_inputs.items():
+ if isinstance(v, torch.Tensor):
+ or_size = v.shape[0]
+ additional_model_inputs_chunk[k] = repeat(
+ v,
+ "b c -> (b t) c",
+ t=(input_chunk.shape[0] // num_frames // or_size) + 1,
+ )[: cond_chunk["concat"].shape[0]]
+ else:
+ additional_model_inputs_chunk[k] = v
+
+ out.append(
+ network(
+ input_chunk * c_in_chunk,
+ c_noise_chunk,
+ cond_chunk,
+ **additional_model_inputs_chunk,
+ )
+ )
+
+ return torch.cat(out, dim=0)
+
+
+class KarrasTemporalMultiDiffusion(DenoiserTemporalMultiDiffusion):
+ def __init__(self, scaling_config: Dict):
+ super().__init__(scaling_config)
+ self.bad_network = None
+
+ def set_bad_network(self, bad_network: nn.Module):
+ self.bad_network = bad_network
+
+ def split_inputs(
+ self, input: torch.Tensor, cond: Dict, additional_model_inputs
+ ) -> torch.Tensor:
+ half_input = input.shape[0] // 2
+ first_cond_half = {}
+ second_cond_half = {}
+ for k, v in cond.items():
+ if isinstance(v, torch.Tensor):
+ half_cond = v.shape[0] // 2
+ first_cond_half[k] = v[:half_cond]
+ second_cond_half[k] = v[half_cond:]
+ elif isinstance(v, list):
+ half_add = v[0].shape[0] // 2
+ first_cond_half[k] = [v[i][:half_add] for i in range(len(v))]
+ second_cond_half[k] = [v[i][half_add:] for i in range(len(v))]
+ else:
+ first_cond_half[k] = v
+ second_cond_half[k] = v
+
+ add_good = {}
+ add_bad = {}
+ for k, v in additional_model_inputs.items():
+ if isinstance(v, torch.Tensor):
+ half_add = v.shape[0] // 2
+ add_good[k] = v[:half_add]
+ add_bad[k] = v[half_add:]
+ elif isinstance(v, list):
+ half_add = v[0].shape[0] // 2
+ add_good[k] = [v[i][:half_add] for i in range(len(v))]
+ add_bad[k] = [v[i][half_add:] for i in range(len(v))]
+ else:
+ add_good[k] = v
+ add_bad[k] = v
+
+ return (
+ input[:half_input],
+ input[half_input:],
+ first_cond_half,
+ second_cond_half,
+ add_good,
+ add_bad,
+ )
+
+ def forward(
+ self,
+ network: nn.Module,
+ input: torch.Tensor,
+ sigma: torch.Tensor,
+ cond: Dict,
+ num_overlap_frames: int,
+ num_frames: int,
+ n_skips: int,
+ chunk_size: int = None,
+ **additional_model_inputs,
+ ) -> torch.Tensor:
+ """
+ Args:
+ network: Denoising network
+ input: Noisy input
+ sigma: Noise level
+ cond: Dictionary containing additional information
+ num_overlap_frames: Number of overlapping frames
+ additional_model_inputs: Additional inputs for the denoising network
+ Returns:
+ out: Denoised output
+ This function assumes the input is of shape (B, C, T, H, W) with the B dimension being the number of segments in video.
+ The num_overlap_frames is the number of overlapping frames between the segments to be able to handle the temporal overlap.
+ """
+ sigma = self.possibly_quantize_sigma(sigma)
+ T = num_frames
+ if input.ndim == 5:
+ T = input.shape[2]
+ input = rearrange(input, "b c t h w -> (b t) c h w")
+ if sigma.shape[0] != input.shape[0]:
+ sigma = repeat(sigma, "b ... -> b t ...", t=T)
+ sigma = rearrange(sigma, "b t ... -> (b t) ...")
+ n_skips = n_skips * input.shape[0] // T
+ sigma_shape = sigma.shape
+ sigma = append_dims(sigma, input.ndim)
+ c_skip, c_out, c_in, c_noise = self.scaling(sigma)
+ c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
+ if self.is_dub:
+ gt = cond.get("gt", torch.Tensor([]).type_as(input))
+ if gt.dim() == 5:
+ gt = rearrange(gt, "b c t h w -> (b t) c h w")
+ masks = cond.get("masks", None)
+ if masks.dim() == 5:
+ masks = rearrange(masks, "b c t h w -> (b t) c h w")
+ input = input * masks + gt * (1.0 - masks)
+
+ # Now we want to find the overlapping frames and average them
+ input = rearrange(input, "(b t) c h w -> b c t h w", t=T)
+ # Overlapping frames are at begining and end of each segment and given by num_overlap_frames
+ for i in range(input.shape[0] - n_skips):
+ average_frame = torch.stack(
+ [
+ input[i, :, -num_overlap_frames:],
+ input[i + 1, :, :num_overlap_frames],
+ ]
+ ).mean(0)
+ input[i, :, -num_overlap_frames:] = average_frame
+ input[i + n_skips, :, :num_overlap_frames] = average_frame
+
+ input = rearrange(input, "b c t h w -> (b t) c h w")
+
+ half = c_in.shape[0] // 2
+ in_bad, in_good, cond_bad, cond_good, add_inputs_good, add_inputs_bad = (
+ self.split_inputs(input, cond, additional_model_inputs)
+ )
+ if chunk_size is not None:
+ assert chunk_size % num_frames == 0, (
+ "Chunk size should be multiple of num_frames"
+ )
+ out = chunk_network(
+ network,
+ in_good,
+ c_in[half:],
+ c_noise[half:],
+ cond_good,
+ add_inputs_good,
+ chunk_size,
+ num_frames=num_frames,
+ )
+ bad_out = chunk_network(
+ self.bad_network,
+ in_bad,
+ c_in[:half],
+ c_noise[:half],
+ cond_bad,
+ add_inputs_bad,
+ chunk_size,
+ num_frames=num_frames,
+ )
+ else:
+ out = network(
+ in_good * c_in[half:], c_noise[half:], cond_good, **add_inputs_good
+ )
+ bad_out = self.bad_network(
+ in_bad * c_in[:half], c_noise[:half], cond_bad, **add_inputs_bad
+ )
+ out = torch.cat([bad_out, out], dim=0)
+
+ out = out * c_out + input * c_skip
+
+ if self.is_dub:
+ out = out * masks + gt * (1.0 - masks)
+ return out
diff --git a/sgm/modules/diffusionmodules/denoiser_scaling.py b/sgm/modules/diffusionmodules/denoiser_scaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ef050ae81e8b0dbb984f42be3b1071d020abab6
--- /dev/null
+++ b/sgm/modules/diffusionmodules/denoiser_scaling.py
@@ -0,0 +1,58 @@
+from abc import ABC, abstractmethod
+from typing import Tuple
+
+import torch
+
+
+class DenoiserScaling(ABC):
+ @abstractmethod
+ def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ pass
+
+
+class IdentityScaling(DenoiserScaling):
+ def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = torch.ones_like(sigma, device=sigma.device)
+ c_out = torch.ones_like(sigma, device=sigma.device)
+ c_in = torch.ones_like(sigma, device=sigma.device)
+ c_noise = torch.zeros_like(sigma, device=sigma.device)
+ return c_skip, c_out, c_in, c_noise
+
+
+class EDMScaling:
+ def __init__(self, sigma_data: float = 0.5):
+ self.sigma_data = sigma_data
+
+ def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
+ c_noise = 0.25 * sigma.log()
+ return c_skip, c_out, c_in, c_noise
+
+
+class EpsScaling:
+ def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = torch.ones_like(sigma, device=sigma.device)
+ c_out = -sigma
+ c_in = 1 / (sigma**2 + 1.0) ** 0.5
+ c_noise = sigma.clone()
+ return c_skip, c_out, c_in, c_noise
+
+
+class VScaling:
+ def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = 1.0 / (sigma**2 + 1.0)
+ c_out = -sigma / (sigma**2 + 1.0) ** 0.5
+ c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
+ c_noise = sigma.clone()
+ return c_skip, c_out, c_in, c_noise
+
+
+class VScalingWithEDMcNoise(DenoiserScaling):
+ def __call__(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = 1.0 / (sigma**2 + 1.0)
+ c_out = -sigma / (sigma**2 + 1.0) ** 0.5
+ c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
+ c_noise = 0.25 * sigma.log()
+ return c_skip, c_out, c_in, c_noise
diff --git a/sgm/modules/diffusionmodules/denoiser_weighting.py b/sgm/modules/diffusionmodules/denoiser_weighting.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8b03ca58f17ea3d7374f4bbb7bf1d2994755e00
--- /dev/null
+++ b/sgm/modules/diffusionmodules/denoiser_weighting.py
@@ -0,0 +1,24 @@
+import torch
+
+
+class UnitWeighting:
+ def __call__(self, sigma):
+ return torch.ones_like(sigma, device=sigma.device)
+
+
+class EDMWeighting:
+ def __init__(self, sigma_data=0.5):
+ self.sigma_data = sigma_data
+
+ def __call__(self, sigma):
+ return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
+
+
+class VWeighting(EDMWeighting):
+ def __init__(self):
+ super().__init__(sigma_data=1.0)
+
+
+class EpsWeighting:
+ def __call__(self, sigma):
+ return sigma**-2.0
diff --git a/sgm/modules/diffusionmodules/diffuser_unet.py b/sgm/modules/diffusionmodules/diffuser_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdb965f1841876136fbfb44e8da30df201c30df3
--- /dev/null
+++ b/sgm/modules/diffusionmodules/diffuser_unet.py
@@ -0,0 +1,1430 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+import torch.nn.functional as F
+
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ GLIGENTextBoundingboxProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unet_2d_blocks import (
+ get_down_block,
+ get_mid_block,
+ get_up_block,
+)
+from diffusers.models.attention_processor import IPAdapterAttnProcessor, AttnProcessor2_0
+from einops import rearrange
+from contextlib import nullcontext
+
+from ...modules.diffusionmodules.attention_processors import IPAdapterAttnProcessor2_0
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet2DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers is skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ dropout: float = 0.0,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ attention_type: str = "default",
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads: int = 64,
+ audio_cond_method: str = None,
+ audio_emb_dim: int = 1280,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ self._check_config(
+ down_block_types=down_block_types,
+ up_block_types=up_block_types,
+ only_cross_attention=only_cross_attention,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
+ attention_head_dim=attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ )
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
+ time_embedding_type,
+ block_out_channels=block_out_channels,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ time_embedding_dim=time_embedding_dim,
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ self.audio_cond_method = audio_cond_method
+ if audio_cond_method == "to_time_emb":
+ self.audio_emb_proj = nn.Sequential(
+ nn.Linear(audio_emb_dim, time_embed_dim),
+ get_activation(act_fn),
+ nn.Linear(time_embed_dim, time_embed_dim),
+ )
+
+ self._set_encoder_hid_proj(
+ encoder_hid_dim_type,
+ cross_attention_dim=cross_attention_dim,
+ encoder_hid_dim=encoder_hid_dim,
+ )
+
+ # class embedding
+ self._set_class_embedding(
+ class_embed_type,
+ act_fn=act_fn,
+ num_class_embeds=num_class_embeds,
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+ time_embed_dim=time_embed_dim,
+ timestep_input_dim=timestep_input_dim,
+ )
+
+ self._set_add_embedding(
+ addition_embed_type,
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
+ addition_time_embed_dim=addition_time_embed_dim,
+ cross_attention_dim=cross_attention_dim,
+ encoder_hid_dim=encoder_hid_dim,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+ time_embed_dim=time_embed_dim,
+ )
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = get_mid_block(
+ mid_block_type,
+ temb_channels=blocks_time_embed_dim,
+ in_channels=block_out_channels[-1],
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ output_scale_factor=mid_block_scale_factor,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ num_attention_heads=num_attention_heads[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[-1],
+ dropout=dropout,
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = (
+ list(reversed(transformer_layers_per_block))
+ if reverse_transformer_layers_per_block is None
+ else reverse_transformer_layers_per_block
+ )
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resolution_idx=i,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+ self.conv_act = get_activation(act_fn)
+
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
+
+ def _check_config(
+ self,
+ down_block_types: Tuple[str],
+ up_block_types: Tuple[str],
+ only_cross_attention: Union[bool, Tuple[bool]],
+ block_out_channels: Tuple[int],
+ layers_per_block: Union[int, Tuple[int]],
+ cross_attention_dim: Union[int, Tuple[int]],
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
+ reverse_transformer_layers_per_block: bool,
+ attention_head_dim: int,
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
+ ):
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+ for layer_number_per_block in transformer_layers_per_block:
+ if isinstance(layer_number_per_block, list):
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
+ def _set_time_proj(
+ self,
+ time_embedding_type: str,
+ block_out_channels: int,
+ flip_sin_to_cos: bool,
+ freq_shift: float,
+ time_embedding_dim: int,
+ ) -> Tuple[int, int]:
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ return time_embed_dim, timestep_input_dim
+
+ def _set_encoder_hid_proj(
+ self,
+ encoder_hid_dim_type: Optional[str],
+ cross_attention_dim: Union[int, Tuple[int]],
+ encoder_hid_dim: Optional[int],
+ ):
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ def _set_class_embedding(
+ self,
+ class_embed_type: Optional[str],
+ act_fn: str,
+ num_class_embeds: Optional[int],
+ projection_class_embeddings_input_dim: Optional[int],
+ time_embed_dim: int,
+ timestep_input_dim: int,
+ ):
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ def _set_add_embedding(
+ self,
+ addition_embed_type: str,
+ addition_embed_type_num_heads: int,
+ addition_time_embed_dim: Optional[int],
+ flip_sin_to_cos: bool,
+ freq_shift: float,
+ cross_attention_dim: Optional[int],
+ encoder_hid_dim: Optional[int],
+ projection_class_embeddings_input_dim: Optional[int],
+ time_embed_dim: int,
+ ):
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
+ if attention_type in ["gated", "gated-text-image"]:
+ positive_len = 768
+ if isinstance(cross_attention_dim, int):
+ positive_len = cross_attention_dim
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+ positive_len = cross_attention_dim[0]
+
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
+ self.position_net = GLIGENTextBoundingboxProjection(
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def unload_lora(self):
+ """Unloads LoRA weights."""
+ deprecate(
+ "unload_lora",
+ "0.28.0",
+ "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
+ )
+ for module in self.modules():
+ if hasattr(module, "set_lora_layer"):
+ module.set_lora_layer(None)
+
+ def get_time_embed(
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
+ ) -> Optional[torch.Tensor]:
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+ return t_emb
+
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ class_emb = None
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+ return class_emb
+
+ def get_aug_embed(
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+ ) -> Optional[torch.Tensor]:
+ aug_emb = None
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb = self.add_embedding(image_embs, hint)
+ return aug_emb
+
+ def process_encoder_hidden_states(
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+ ) -> torch.Tensor:
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ image_embeds = self.encoder_hid_proj(image_embeds)
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
+ return encoder_hidden_states
+
+ def set_ip_adapter_scale(self, scale):
+ """
+ Sets the conditioning scale between text and image.
+
+ Example:
+
+ ```py
+ pipeline.set_ip_adapter_scale(0.5)
+ ```
+ """
+ # unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+ for attn_processor in self.attn_processors.values():
+ if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+ if not isinstance(scale, list):
+ scale = [scale] * len(attn_processor.scale)
+ if len(attn_processor.scale) != len(scale):
+ raise ValueError(
+ f"`scale` should be a list of same length as the number if ip-adapters "
+ f"Expected {len(attn_processor.scale)} but got {len(scale)}."
+ )
+ attn_processor.scale = scale
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ audio_emb: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
+ A tensor that if specified is added to the residual of the middle unet block.
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ for dim in sample.shape[-2:]:
+ if dim % default_overall_up_factor != 0:
+ # Forward upsample size to force interpolation output size.
+ forward_upsample_size = True
+ break
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.audio_cond_method == "to_time_emb":
+ audio_emb = rearrange(audio_emb, "b t c -> (b t) c")
+ emb = emb + self.audio_emb_proj(audio_emb)
+ elif self.audio_cond_method == "cross_attention":
+ if audio_emb.ndim == 4:
+ audio_emb = rearrange(audio_emb, "b t d c -> b (t d) c")
+ # print("audio_emb", audio_emb.shape)
+ # print("encoder_hidden_states", encoder_hidden_states.shape)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, audio_emb], dim=1)
+ elif self.audio_cond_method == "ip_adapter":
+ if isinstance(added_cond_kwargs["image_embeds"], list):
+ added_cond_kwargs["image_embeds"].append(audio_emb)
+ else:
+ added_cond_kwargs["image_embeds"] = [added_cond_kwargs["image_embeds"], audio_emb]
+
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
+ if class_emb is not None:
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ aug_emb = self.get_aug_embed(
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+ if self.config.addition_embed_type == "image_hint":
+ aug_emb, hint = aug_emb
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ encoder_hidden_states = self.process_encoder_hidden_states(
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 2.5 GLIGEN position net
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ gligen_args = cross_attention_kwargs.pop("gligen")
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+ # 3. down
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
+ if cross_attention_kwargs is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+ is_adapter = down_intrablock_additional_residuals is not None
+ # maintain backward compatibility for legacy usage, where
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+ # but can only use one or the other
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+ deprecate(
+ "T2I should not use down_block_additional_residuals",
+ "1.3.0",
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
+ standard_warn=False,
+ )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # To support T2I-Adapter-XL
+ if (
+ is_adapter
+ and len(down_intrablock_additional_residuals) > 0
+ and sample.shape == down_intrablock_additional_residuals[0].shape
+ ):
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
+
+ def convert_ip_adapter_attn_to_diffusers_and_load(self, state_dicts, **attn_kwargs):
+ # from ..models.attention_processor import (
+ # AttnProcessor,
+ # AttnProcessor2_0,
+ # IPAdapterAttnProcessor,
+ # IPAdapterAttnProcessor2_0,
+ # )
+
+ # set ip-adapter cross-attention processors & load state_dict
+ attn_procs = {}
+ key_id = 1
+ init_context = nullcontext
+ for name in self.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = self.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(self.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = self.config.block_out_channels[block_id]
+
+ if cross_attention_dim is None or "motion_modules" in name:
+ attn_processor_class = (
+ AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+ )
+ attn_procs[name] = attn_processor_class()
+ else:
+ attn_processor_class = (
+ IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
+ )
+ num_image_text_embeds = []
+ for state_dict in state_dicts:
+ if "proj.weight" in state_dict["image_proj"]:
+ # IP-Adapter
+ num_image_text_embeds += [4]
+ elif "proj.3.weight" in state_dict["image_proj"]:
+ # IP-Adapter Full Face
+ num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
+ else:
+ # IP-Adapter Plus
+ num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
+
+ with init_context():
+ attn_procs[name] = attn_processor_class(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ scale=1.0,
+ num_tokens=num_image_text_embeds,
+ **attn_kwargs,
+ )
+
+ value_dict = {}
+ for i, state_dict in enumerate(state_dicts):
+ value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+ value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
+
+ attn_procs[name].load_state_dict(value_dict)
+
+ key_id += 2
+
+ self.set_attn_processor(attn_procs)
diff --git a/sgm/modules/diffusionmodules/discretizer.py b/sgm/modules/diffusionmodules/discretizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..19b5617cf5d1736c7243e6987bd74150ab670a89
--- /dev/null
+++ b/sgm/modules/diffusionmodules/discretizer.py
@@ -0,0 +1,94 @@
+from abc import abstractmethod
+from functools import partial
+
+import numpy as np
+import torch
+
+from ...modules.diffusionmodules.util import make_beta_schedule
+from ...util import append_zero
+
+
+def generate_roughly_equally_spaced_steps(num_substeps: int, max_step: int) -> np.ndarray:
+ return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
+
+
+class Discretization:
+ def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
+ sigmas = self.get_sigmas(n, device=device)
+ sigmas = append_zero(sigmas) if do_append_zero else sigmas
+ return sigmas if not flip else torch.flip(sigmas, (0,))
+
+ @abstractmethod
+ def get_sigmas(self, n, device):
+ pass
+
+
+class EDMDiscretization(Discretization):
+ def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
+ self.sigma_min = sigma_min
+ self.sigma_max = sigma_max
+ self.rho = rho
+
+ def get_sigmas(self, n, device="cpu"):
+ ramp = torch.linspace(0, 1, n, device=device)
+ min_inv_rho = self.sigma_min ** (1 / self.rho)
+ max_inv_rho = self.sigma_max ** (1 / self.rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
+ print(sigmas)
+ return sigmas
+
+
+class AYSDiscretization(Discretization):
+ def __init__(self):
+ self.sigma_min = 0.002
+ self.sigma_max = 700.0
+ self.base_sigmas = np.array([700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002])
+
+ def loglinear_interp(self, t_steps, num_steps):
+ """
+ Performs log-linear interpolation of a given array of decreasing numbers.
+ """
+ xs = np.linspace(0, 1, len(t_steps))
+ ys = np.log(t_steps[::-1])
+
+ new_xs = np.linspace(0, 1, num_steps)
+ new_ys = np.interp(new_xs, xs, ys)
+
+ interped_ys = np.exp(new_ys)[::-1].copy()
+ return interped_ys
+
+ def get_sigmas(self, n, device="cpu"):
+ assert n >= 10, "Number of timesteps must be greater than 10 for AYS discretization."
+ if n > 10:
+ sigmas = self.loglinear_interp(self.base_sigmas, n)
+ else:
+ sigmas = self.base_sigmas
+ return torch.from_numpy(sigmas).to(device)
+
+
+class LegacyDDPMDiscretization(Discretization):
+ def __init__(
+ self,
+ linear_start=0.00085,
+ linear_end=0.0120,
+ num_timesteps=1000,
+ ):
+ super().__init__()
+ self.num_timesteps = num_timesteps
+ betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ def get_sigmas(self, n, device="cpu"):
+ if n < self.num_timesteps:
+ timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
+ alphas_cumprod = self.alphas_cumprod[timesteps]
+ elif n == self.num_timesteps:
+ alphas_cumprod = self.alphas_cumprod
+ else:
+ raise ValueError
+
+ to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
+ sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ return torch.flip(sigmas, (0,))
diff --git a/sgm/modules/diffusionmodules/guiders.py b/sgm/modules/diffusionmodules/guiders.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9be20ccebb0d3e9387c32a1a314fd38eaaf30bf
--- /dev/null
+++ b/sgm/modules/diffusionmodules/guiders.py
@@ -0,0 +1,538 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+from ...util import append_dims, default
+
+logpy = logging.getLogger(__name__)
+
+
+class Guider(ABC):
+ @abstractmethod
+ def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
+ pass
+
+ def prepare_inputs(
+ self, x: torch.Tensor, s: float, c: Dict, uc: Dict
+ ) -> Tuple[torch.Tensor, float, Dict]:
+ pass
+
+
+class VanillaCFG(Guider):
+ def __init__(
+ self, scale: float, low_sigma: float = 0.0, high_sigma: float = float("inf")
+ ):
+ self.scale = scale
+ self.low_sigma = low_sigma
+ self.high_sigma = high_sigma
+
+ def set_scale(self, scale: float):
+ self.scale = scale
+
+ def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ x_u, x_c = x.chunk(2)
+
+ x_pred = x_u + self.scale * (x_c - x_u)
+ return x_pred
+
+ def prepare_inputs(self, x, s, c, uc):
+ c_out = dict()
+
+ for k in c:
+ if k in [
+ "vector",
+ "crossattn",
+ "concat",
+ "audio_emb",
+ "image_embeds",
+ "landmarks",
+ "masks",
+ "gt",
+ "valence",
+ "arousal",
+ ]:
+ c_out[k] = torch.cat((uc[k], c[k]), 0)
+ elif k == "reference":
+ c_out["reference"] = []
+ for i in range(len(c[k])):
+ c_out["reference"].append(torch.cat((uc[k][i], c[k][i]), 0))
+ else:
+ assert c[k] == uc[k]
+ c_out[k] = c[k]
+ return torch.cat([x] * 2), torch.cat([s] * 2), c_out
+
+
+class VanillaSTG(Guider):
+ def __init__(
+ self,
+ scale_spatial: float,
+ scale_temporal: float,
+ low_sigma: float = 0.0,
+ high_sigma: float = float("inf"),
+ layer_skip: int = 8,
+ ):
+ self.scale_spatial = scale_spatial
+ self.scale_temporal = scale_temporal
+ self.low_sigma = low_sigma
+ self.high_sigma = high_sigma
+ self.layer_skip = layer_skip
+
+ def set_scale(self, scale_spatial: float, scale_temporal: float):
+ self.scale_spatial = scale_spatial
+ self.scale_temporal = scale_temporal
+
+ def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ x_c, x_spatial, x_temporal = x.chunk(3)
+ x_pred = (
+ x_c
+ + self.scale_spatial * (x_c - x_spatial)
+ + self.scale_temporal * (x_c - x_temporal)
+ )
+ return x_pred
+
+ def prepare_inputs(self, x, s, c, uc):
+ c_out = dict()
+
+ for k in c:
+ if k in [
+ "vector",
+ "crossattn",
+ "concat",
+ "audio_emb",
+ "image_embeds",
+ "landmarks",
+ "masks",
+ "gt",
+ "valence",
+ "arousal",
+ ]:
+ c_out[k] = torch.cat((c[k], c[k], c[k]), 0)
+ elif k == "reference":
+ c_out["reference"] = []
+ for i in range(len(c[k])):
+ c_out["reference"].append(torch.cat((c[k][i], c[k][i], c[k][i]), 0))
+ else:
+ assert c[k] == uc[k]
+ c_out[k] = c[k]
+
+ c_out["skip_spatial_attention_at"] = [None, self.layer_skip, None]
+ c_out["skip_temporal_attention_at"] = [None, None, self.layer_skip]
+
+ return torch.cat([x] * 3), torch.cat([s] * 3), c_out
+
+
+class MomentumBuffer:
+ def __init__(self, momentum: float):
+ self.momentum = momentum
+ self.running_average = 0
+
+ def update(self, update_value: torch.Tensor):
+ new_average = self.momentum * self.running_average
+ self.running_average = update_value + new_average
+
+
+def project(v0: torch.Tensor, v1: torch.Tensor):
+ dtype = v0.dtype
+ v0, v1 = v0.double(), v1.double()
+ v1 = F.normalize(v1, dim=[-1, -2, -3])
+ v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
+ v0_orthogonal = v0 - v0_parallel
+ return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
+
+
+def normalized_guidance(
+ pred_cond: torch.Tensor,
+ pred_uncond: torch.Tensor,
+ guidance_scale: float,
+ momentum_buffer: MomentumBuffer = None,
+ eta: float = 1.0,
+ norm_threshold: float = 0.0,
+):
+ diff = pred_cond - pred_uncond
+ if momentum_buffer is not None:
+ momentum_buffer.update(diff)
+ diff = momentum_buffer.running_average
+ if norm_threshold > 0:
+ ones = torch.ones_like(diff)
+ diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
+ scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+ diff = diff * scale_factor
+ diff_parallel, diff_orthogonal = project(diff, pred_cond)
+ normalized_update = diff_orthogonal + eta * diff_parallel
+ pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
+ return pred_guided
+
+
+class APGGuider(VanillaCFG):
+ def __init__(
+ self,
+ scale: float,
+ momentum: float = -0.75,
+ eta: float = 0.0,
+ norm_threshold: float = 2.5,
+ ):
+ super().__init__(scale)
+ self.momentum_buffer = MomentumBuffer(momentum)
+ self.eta = eta
+ self.norm_threshold = norm_threshold
+
+ def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ x_u, x_c = x.chunk(2)
+ return normalized_guidance(
+ x_c, x_u, self.scale, self.momentum_buffer, self.eta, self.norm_threshold
+ )
+
+
+class VanillaCFGplusplus(VanillaCFG):
+ def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ x_u, x_c = x.chunk(2)
+
+ x_pred = x_u + self.scale * (x_c - x_u)
+ return x_pred, x_u
+
+
+class KarrasGuider(VanillaCFG):
+ def prepare_inputs(self, x, s, c, uc):
+ c_out = dict()
+
+ for k in c:
+ if k in [
+ "vector",
+ "crossattn",
+ "concat",
+ "audio_emb",
+ "image_embeds",
+ "landmarks",
+ "valence",
+ "arousal",
+ ]:
+ c_out[k] = torch.cat((c[k], c[k]), 0)
+ elif k == "reference":
+ c_out["reference"] = []
+ for i in range(len(c[k])):
+ c_out["reference"].append(torch.cat((c[k][i], c[k][i]), 0))
+ else:
+ assert c[k] == uc[k]
+ c_out[k] = c[k]
+ return torch.cat([x] * 2), torch.cat([s] * 2), c_out
+
+
+class MultipleCondVanilla(Guider):
+ def __init__(self, scales, condition_names) -> None:
+ assert len(scales) == len(condition_names)
+ self.scales = scales
+ # self.condition_names = condition_names
+ self.n_conditions = len(scales)
+ self.map_cond_name = {
+ "audio_emb": "audio_emb",
+ "cond_frames_without_noise": "crossattn",
+ "cond_frames": "concat",
+ }
+ self.condition_names = [
+ self.map_cond_name.get(cond_name, cond_name)
+ for cond_name in condition_names
+ ]
+ print("Condition names: ", self.condition_names)
+
+ def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ outs = x.chunk(self.n_conditions + 1)
+ x_full_cond = outs[0]
+ x_pred = (1 + sum(self.scales)) * x_full_cond
+ for i, scale in enumerate(self.scales):
+ x_pred -= scale * outs[i + 1]
+ return x_pred
+
+ def prepare_inputs(self, x, s, c, uc):
+ c_out = dict()
+
+ # The first element is the full condition
+ for k in c:
+ if k in [
+ "vector",
+ "crossattn",
+ "concat",
+ "audio_emb",
+ "image_embeds",
+ "landmarks",
+ "masks",
+ "gt",
+ ]:
+ c_out[k] = c[k]
+ else:
+ assert c[k] == uc[k]
+ c_out[k] = c[k]
+
+ # The rest are the conditions removed from the full condition
+ for cond_name in self.condition_names:
+ if not isinstance(cond_name, list):
+ cond_name = [cond_name]
+ for k in c:
+ if k in [
+ "vector",
+ "crossattn",
+ "concat",
+ "audio_emb",
+ "image_embeds",
+ "landmarks",
+ "masks",
+ "gt",
+ ]:
+ c_out[k] = torch.cat(
+ (c_out[k], uc[k] if k in cond_name else c[k]), 0
+ )
+
+ return (
+ torch.cat([x] * (self.n_conditions + 1)),
+ torch.cat([s] * (self.n_conditions + 1)),
+ c_out,
+ )
+
+
+class AudioRefMultiCondGuider(MultipleCondVanilla):
+ def __init__(
+ self,
+ audio_ratio: float = 5.0,
+ ref_ratio: float = 3.0,
+ use_normalized: bool = False,
+ momentum: float = -0.75,
+ eta: float = 0.0,
+ norm_threshold: float = 2.5,
+ ):
+ super().__init__(
+ scales=[audio_ratio, ref_ratio], condition_names=["audio_emb", "concat"]
+ )
+ self.audio_ratio = audio_ratio
+ self.ref_ratio = ref_ratio
+ self.use_normalized = use_normalized
+ print(f"Use normalized: {self.use_normalized}")
+ self.momentum_buffer = MomentumBuffer(momentum)
+ self.eta = eta
+ self.norm_threshold = norm_threshold
+ self.momentum_buffer_audio = MomentumBuffer(momentum)
+ self.momentum_buffer_ref = MomentumBuffer(momentum)
+
+ def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ e_uc, e_ref, c_audio_ref = x.chunk(3)
+
+ if self.use_normalized:
+ # Normalized guidance version
+ # Compute diff for audio guidance
+ diff_audio = c_audio_ref - e_uc
+ if self.momentum_buffer_audio is not None:
+ self.momentum_buffer_audio.update(diff_audio)
+ diff_audio = self.momentum_buffer_audio.running_average
+ if self.norm_threshold > 0:
+ ones = torch.ones_like(diff_audio)
+ diff_norm = diff_audio.norm(p=2, dim=[-1, -2, -3], keepdim=True)
+ scale_factor = torch.minimum(ones, self.norm_threshold / diff_norm)
+ diff_audio = diff_audio * scale_factor
+ diff_audio_parallel, diff_audio_orthogonal = project(
+ diff_audio, c_audio_ref
+ )
+ normalized_update_audio = (
+ diff_audio_orthogonal + self.eta * diff_audio_parallel
+ )
+ guidance_audio = (self.audio_ratio - 1) * normalized_update_audio
+
+ # Compute diff for ref guidance
+ diff_ref = e_ref - e_uc
+ if self.momentum_buffer_ref is not None:
+ self.momentum_buffer_ref.update(diff_ref)
+ diff_ref = self.momentum_buffer_ref.running_average
+ if self.norm_threshold > 0:
+ ones = torch.ones_like(diff_ref)
+ diff_norm = diff_ref.norm(p=2, dim=[-1, -2, -3], keepdim=True)
+ scale_factor = torch.minimum(ones, self.norm_threshold / diff_norm)
+ diff_ref = diff_ref * scale_factor
+ diff_ref_parallel, diff_ref_orthogonal = project(diff_ref, e_ref)
+ normalized_update_ref = diff_ref_orthogonal + self.eta * diff_ref_parallel
+ guidance_ref = (self.ref_ratio - 1) * normalized_update_ref
+
+ e_final = e_uc + guidance_audio + guidance_ref
+ else:
+ # Original version
+ e_final = (
+ self.audio_ratio * (c_audio_ref - e_ref)
+ + self.ref_ratio * (e_ref - e_uc)
+ + e_uc
+ )
+
+ return e_final
+
+ def set_scale(self, scale: torch.Tensor):
+ self.audio_ratio = float(scale[0])
+ self.ref_ratio = float(scale[1])
+ print(f"Audio ratio: {self.audio_ratio}, Ref ratio: {self.ref_ratio}")
+
+ def prepare_inputs(self, x, s, c, uc):
+ c_out = dict()
+
+ # Prepare inputs for e_base (no audio, no ref concat)
+ c_base = {k: v for k, v in c.items()}
+ c_base["crossattn"] = uc["crossattn"]
+ c_base["concat"] = uc["concat"] # Remove ref concat
+
+ # Prepare inputs for e_ref (no audio, with ref concat)
+ c_audio_ref = {k: v for k, v in c.items()}
+ # c_ref["concat"] = uc["concat"] # Remove ref concat
+
+ # Prepare inputs for e_audio (all conditions)
+ c_ref = {k: v for k, v in c.items()}
+ c_ref["crossattn"] = uc["crossattn"]
+
+ # Combine all conditions
+ for k in c:
+ if k in [
+ "vector",
+ "crossattn",
+ "concat",
+ "audio_emb",
+ "image_embeds",
+ "landmarks",
+ "masks",
+ "gt",
+ ]:
+ c_out[k] = torch.cat((c_base[k], c_ref[k], c_audio_ref[k]), 0)
+ else:
+ c_out[k] = c[k]
+
+ return torch.cat([x] * 3), torch.cat([s] * 3), c_out
+
+
+class IdentityGuider(Guider):
+ def __init__(self, *args, **kwargs):
+ # self.num_frames = num_frames
+ pass
+
+ def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
+ return x
+
+ def prepare_inputs(
+ self, x: torch.Tensor, s: float, c: Dict, uc: Dict
+ ) -> Tuple[torch.Tensor, float, Dict]:
+ c_out = dict()
+
+ for k in c:
+ c_out[k] = c[k]
+
+ return x, s, c_out
+
+
+class LinearPredictionGuider(Guider):
+ def __init__(
+ self,
+ max_scale: float,
+ num_frames: int,
+ min_scale: float = 1.0,
+ additional_cond_keys: Optional[Union[List[str], str]] = None,
+ only_first=False,
+ ):
+ self.min_scale = min_scale
+ self.max_scale = max_scale
+ self.num_frames = num_frames
+ self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
+
+ self.only_first = only_first
+ if only_first:
+ self.scale = torch.ones_like(self.scale) * max_scale
+ self.scale[:, 0] = min_scale
+
+ additional_cond_keys = default(additional_cond_keys, [])
+ if isinstance(additional_cond_keys, str):
+ additional_cond_keys = [additional_cond_keys]
+ self.additional_cond_keys = additional_cond_keys
+
+ def set_scale(self, scale: torch.Tensor):
+ self.min_scale = scale
+ self.scale = torch.linspace(
+ self.min_scale, self.max_scale, self.num_frames
+ ).unsqueeze(0)
+
+ if self.only_first:
+ self.scale = torch.ones_like(self.scale) * self.max_scale
+ self.scale[:, 0] = self.min_scale
+
+ print(self.scale)
+
+ def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ x_u, x_c = x.chunk(2)
+
+ x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
+ x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
+ scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
+ scale = append_dims(scale, x_u.ndim).to(x_u.device)
+ return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
+
+ def prepare_inputs(
+ self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+ c_out = dict()
+
+ for k in c:
+ if (
+ k
+ in ["vector", "crossattn", "concat", "audio_emb", "masks", "gt"]
+ + self.additional_cond_keys
+ ):
+ c_out[k] = torch.cat((uc[k], c[k]), 0)
+ else:
+ assert c[k] == uc[k]
+ c_out[k] = c[k]
+
+ return torch.cat([x] * 2), torch.cat([s] * 2), c_out
+
+
+class LinearPredictionGuiderPlus(LinearPredictionGuider):
+ def __init__(
+ self,
+ max_scale: float,
+ num_frames: int,
+ min_scale: float = 1.0,
+ additional_cond_keys: Optional[Union[List[str], str]] = None,
+ ):
+ super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
+
+ def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ x_u, x_c = x.chunk(2)
+
+ x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
+ x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
+ scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
+ scale = append_dims(scale, x_u.ndim).to(x_u.device)
+ return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ..."), x_u
+
+
+class TrianglePredictionGuider(LinearPredictionGuider):
+ def __init__(
+ self,
+ max_scale: float,
+ num_frames: int,
+ min_scale: float = 1.0,
+ period: float | List[float] = 1.0,
+ period_fusing: Literal["mean", "multiply", "max"] = "max",
+ additional_cond_keys: Optional[Union[List[str], str]] = None,
+ ):
+ super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
+ values = torch.linspace(0, 1, num_frames)
+ # Constructs a triangle wave
+ if isinstance(period, float):
+ period = [period]
+
+ scales = []
+ for p in period:
+ scales.append(self.triangle_wave(values, p))
+
+ if period_fusing == "mean":
+ scale = sum(scales) / len(period)
+ elif period_fusing == "multiply":
+ scale = torch.prod(torch.stack(scales), dim=0)
+ elif period_fusing == "max":
+ scale = torch.max(torch.stack(scales), dim=0).values
+ self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)
+
+ def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
+ return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
diff --git a/sgm/modules/diffusionmodules/loss.py b/sgm/modules/diffusionmodules/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e5b99ca1d08fa2e3cc03aed266e32087d416f81
--- /dev/null
+++ b/sgm/modules/diffusionmodules/loss.py
@@ -0,0 +1,408 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+import math
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+import lpips
+from facenet_pytorch import InceptionResnetV1
+
+from ...modules.autoencoding.lpips.loss.lpips import LPIPS
+from ...modules.encoders.modules import GeneralConditioner, ConcatTimestepEmbedderND
+from ...util import append_dims, instantiate_from_config, default
+from ...modules.autoencoding.temporal_ae import VideoDecoder
+from ...data.data_utils import extract_face
+
+
+def logit_normal_sampler(m, s=1, beta_m=15, sample_num=1000000):
+ y_samples = torch.randn(sample_num) * s + m
+ x_samples = beta_m * (torch.exp(y_samples) / (1 + torch.exp(y_samples)))
+ return x_samples
+
+
+def mu_t(t, a=5, mu_max=1):
+ t = t.to("cpu")
+ return 2 * mu_max * t**a - mu_max
+
+
+def get_sigma_s(t, a, beta_m):
+ mu = mu_t(t, a=a)
+ sigma_s = logit_normal_sampler(m=mu, sample_num=t.shape[0], beta_m=beta_m)
+ return sigma_s
+
+
+class StandardDiffusionLoss(nn.Module):
+ def __init__(
+ self,
+ sigma_sampler_config: dict,
+ loss_weighting_config: dict,
+ loss_type: str = "l2",
+ offset_noise_level: float = 0.0,
+ batch2model_keys: Optional[Union[str, List[str]]] = None,
+ lambda_lower: float = 1.0,
+ lambda_upper: float = 1.0,
+ fix_image_leak: bool = False,
+ add_lpips: bool = False,
+ weight_pixel: float = 0.0,
+ n_frames_pixel: Optional[int] = 1,
+ what_pixel_losses: Optional[List[str]] = [],
+ disable_first_stage_autocast: bool = True,
+ ):
+ super().__init__()
+
+ assert loss_type in ["l2", "l1", "lpips"]
+
+ self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
+ self.loss_weighting = instantiate_from_config(loss_weighting_config)
+
+ self.loss_type = loss_type
+ self.offset_noise_level = offset_noise_level
+ self.lambda_lower = lambda_lower
+ self.lambda_upper = lambda_upper
+ self.add_lpips = add_lpips
+ self.weight_pixel = weight_pixel
+ self.n_frames_pixel = n_frames_pixel
+ self.what_pixel_losses = what_pixel_losses
+
+ self.en_and_decode_n_samples_a_time = 1
+ self.disable_first_stage_autocast = disable_first_stage_autocast
+
+ if loss_type == "lpips":
+ self.lpips = LPIPS().eval()
+
+ if add_lpips or "lpips" in what_pixel_losses:
+ self.lpips = lpips.LPIPS(net="vgg").eval()
+
+ if "id" in what_pixel_losses or "id_mse" in what_pixel_losses:
+ self.id_model = InceptionResnetV1(pretrained="vggface2").eval().cuda()
+ for param in self.id_model.parameters():
+ param.requires_grad = False
+
+ if not batch2model_keys:
+ batch2model_keys = []
+
+ if isinstance(batch2model_keys, str):
+ batch2model_keys = [batch2model_keys]
+
+ self.batch2model_keys = set(batch2model_keys)
+
+ self.fix_image_leak = fix_image_leak
+ if fix_image_leak:
+ self.beta_m = 15
+ self.a = 5
+ self.noise_encoder = ConcatTimestepEmbedderND(256)
+
+ def get_noised_input(
+ self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
+ ) -> torch.Tensor:
+ noised_input = input + noise * sigmas_bc
+ return noised_input
+
+ def decode_first_stage(self, z, first_stage_model):
+ if len(z.shape) == 5:
+ z = rearrange(z, "b c t h w -> (b t) c h w")
+
+ z = 1.0 / 0.18215 * z
+ n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
+
+ n_rounds = math.ceil(z.shape[0] / n_samples)
+ all_out = []
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ for n in range(n_rounds):
+ if isinstance(first_stage_model.decoder, VideoDecoder):
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
+ else:
+ kwargs = {}
+ out = first_stage_model.decode(
+ z[n * n_samples : (n + 1) * n_samples], **kwargs
+ )
+ all_out.append(out)
+ out = torch.cat(all_out, dim=0)
+ # out = rearrange(out, "b c h w -> b h w c")
+ torch.cuda.empty_cache()
+ return out.clip(-1, 1)
+
+ def forward(
+ self,
+ network: nn.Module,
+ denoiser: nn.Module,
+ conditioner: GeneralConditioner,
+ input: torch.Tensor,
+ batch: Dict,
+ first_stage_model: nn.Module = None,
+ ) -> torch.Tensor:
+ cond = conditioner(batch)
+ return self._forward(network, denoiser, cond, input, batch, first_stage_model)
+
+ def _forward(
+ self,
+ network: nn.Module,
+ denoiser: nn.Module,
+ cond: Dict,
+ input: torch.Tensor,
+ batch: Dict,
+ first_stage_model: nn.Module = None,
+ ) -> Tuple[torch.Tensor, Dict]:
+ additional_model_inputs = {
+ key: batch[key] for key in self.batch2model_keys.intersection(batch)
+ }
+ sigmas = self.sigma_sampler(input.shape[0]).to(input)
+
+ noise = torch.randn_like(input)
+ if self.offset_noise_level > 0.0:
+ offset_shape = (
+ (input.shape[0], 1, input.shape[2])
+ if self.n_frames is not None
+ else (input.shape[0], input.shape[1])
+ )
+ noise = noise + self.offset_noise_level * append_dims(
+ torch.randn(offset_shape, device=input.device),
+ input.ndim,
+ )
+ sigmas_bc = append_dims(sigmas, input.ndim)
+ noised_input = self.get_noised_input(sigmas_bc, noise, input)
+
+ if self.fix_image_leak:
+ noise_aug_strength = get_sigma_s(sigmas / 700, self.a, self.beta_m)
+ noise_aug = append_dims(noise_aug_strength, 4).to(input.device)
+ noise = torch.randn_like(noise_aug)
+ cond["concat"] = self.get_noised_input(noise_aug, noise, cond["concat"])
+ noise_emb = self.noise_encoder(noise_aug_strength).to(input.device)
+ # cond["vector"] = noise_emb if "vector" not in cond else torch.cat([cond["vector"], noise_emb], dim=1)
+ cond["vector"] = noise_emb
+ # print(cond["concat"].shape, cond["vector"].shape, noise.shape, noise_aug.shape, noise_emb.shape)
+
+ model_output = denoiser(
+ network, noised_input, sigmas, cond, **additional_model_inputs
+ )
+ mask = cond.get("masks", None)
+ w = append_dims(self.loss_weighting(sigmas), input.ndim)
+ return self.get_loss(
+ model_output,
+ input,
+ w,
+ sigmas,
+ mask,
+ first_stage_model,
+ batch.get("original_frames", None),
+ batch.get("landmarks", None),
+ )
+
+ def get_loss(
+ self,
+ model_output,
+ target,
+ w,
+ sigmas,
+ mask=None,
+ first_stage_model=None,
+ original_frames=None,
+ landmarks=None,
+ ):
+ scaling_w = w[:, 0, 0, 0]
+
+ T = 1
+ if target.ndim == 5:
+ target = rearrange(target, "b c t h w -> (b t) c h w")
+ B = w.shape[0]
+ T = target.shape[0] // B
+ if w.shape[2] != T:
+ w = repeat(w, "b () () () () -> (b t) () () ()", t=T)
+ else:
+ w = rearrange(w, "b c t h w -> (b t) c h w")
+
+ or_w = w.clone()
+
+ if self.lambda_lower != 1.0:
+ weight_lower = torch.ones_like(model_output, device=w.device)
+ weight_lower[:, :, model_output.shape[2] // 2 :] *= self.lambda_lower
+ w = weight_lower * w
+
+ if self.lambda_upper != 1.0:
+ weight_upper = torch.ones_like(model_output, device=w.device)
+ weight_upper[:, :, : model_output.shape[2] // 2] *= self.lambda_upper
+ w = weight_upper * w
+ loss_dict = {}
+
+ if self.loss_type == "l2":
+ loss = torch.mean(
+ (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
+ )
+ elif self.loss_type == "l1":
+ loss = torch.mean(
+ (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
+ )
+ elif self.loss_type == "lpips":
+ loss = self.lpips(model_output, target).reshape(-1)
+ else:
+ raise NotImplementedError(f"Unknown loss type {self.loss_type}")
+
+ loss_dict[self.loss_type] = loss.clone()
+ loss_dict["loss"] = loss
+
+ if self.add_lpips:
+ loss_dict["lpips"] = w[:, 0, 0, 0] * self.lpips(
+ (model_output[:, :3] * 0.18215).clip(-1, 1),
+ (target[:, :3] * 0.18215).clip(-1, 1),
+ ).reshape(-1)
+ loss_dict["loss"] += loss_dict["lpips"].mean()
+
+ if self.weight_pixel > 0.0:
+ assert original_frames is not None
+ # Randomly select n_frames_pixel frames
+ selected_frames = torch.randperm(T)[: self.n_frames_pixel]
+ selected_model_output = rearrange(
+ model_output, "(b t) ... -> b t ...", t=T
+ )[:, selected_frames]
+ selected_model_output = rearrange(
+ selected_model_output, "b t ... -> (b t) ..."
+ )
+ selected_original_frames = original_frames[:, :, selected_frames]
+ selected_original_frames = rearrange(
+ selected_original_frames, "b c t ... -> (b t) c ..."
+ )
+ selected_w = rearrange(or_w, "(b t) ... -> b t ...", t=T)[
+ :, selected_frames
+ ]
+ selected_w = rearrange(selected_w, "b t ... -> (b t) ...")
+ if selected_w.shape[-1] != selected_original_frames.shape[-1]:
+ # Interpolate the weights to match the number of frames
+ selected_w = torch.nn.functional.interpolate(
+ selected_w, size=selected_original_frames.shape[-1], mode="nearest"
+ )
+ decoded_frames = self.decode_first_stage(
+ selected_model_output, first_stage_model
+ )
+ # print(decoded_frames.shape, selected_original_frames.shape, selected_w.shape)
+
+ for loss_name in self.what_pixel_losses:
+ if loss_name == "l2":
+ # print(selected_w.shape, decoded_frames.shape, selected_original_frames.shape)
+ loss_pixel = torch.mean(
+ (
+ selected_w
+ * (decoded_frames - selected_original_frames) ** 2
+ ).reshape(selected_original_frames.shape[0], -1),
+ 1,
+ )
+ loss_dict["pixel_l2"] = self.weight_pixel * loss_pixel.mean()
+ loss += self.weight_pixel * loss_pixel.mean()
+ elif loss_name == "lpips":
+ loss_pixel = (
+ self.lpips(decoded_frames, selected_original_frames).reshape(-1)
+ * scaling_w
+ )
+ loss_dict["pixel_lpips"] = loss_pixel.mean()
+ loss += self.weight_pixel * loss_pixel.mean()
+ elif loss_name == "l1":
+ loss_pixel = torch.mean(
+ (
+ selected_w
+ * (decoded_frames - selected_original_frames).abs()
+ ).reshape(selected_original_frames.shape[0], -1),
+ 1,
+ )
+ loss_dict["pixel_l1"] = self.weight_pixel * loss_pixel.mean()
+ loss += self.weight_pixel * loss_pixel.mean()
+ elif loss_name == "id":
+ landmarks = landmarks[:, selected_frames]
+ cat_id_input = (
+ (
+ torch.cat([decoded_frames, selected_original_frames], dim=0)
+ + 1
+ )
+ / 2
+ ) * 255
+ cat_id_landmarks = torch.cat([landmarks, landmarks], dim=0)
+ cat_id_landmarks = (
+ rearrange(cat_id_landmarks, "b t ... -> (b t) ...")
+ .cpu()
+ .numpy()
+ )
+ try:
+ cropped_decoded_frames = extract_face(
+ rearrange(cat_id_input, "b c h w -> b h w c"),
+ cat_id_landmarks,
+ margin=30,
+ postprocess=True,
+ )
+ # Save first frame to debug
+ n = cat_id_input.shape[0] // 2
+
+ id_embeddings = self.id_model(
+ rearrange(cropped_decoded_frames, "b h w c -> b c h w")
+ )
+ pred_embeddings, target_embeddings = (
+ id_embeddings[:n],
+ id_embeddings[n:],
+ )
+ # Cosine similarity loss (1 - cos_sim to make it a loss that should be minimized)
+ id_w = scaling_w
+ loss_pixel = (
+ id_w
+ * (
+ 1
+ - torch.nn.functional.cosine_similarity(
+ pred_embeddings, target_embeddings
+ )
+ )
+ ).mean()
+ loss_dict["pixel_id"] = self.weight_pixel * loss_pixel
+ loss += self.weight_pixel * loss_pixel
+ except RuntimeError as e:
+ if "adaptive_avg_pool2d()" in str(e):
+ print(
+ "Warning: Invalid face crop dimensions, skipping ID loss for this batch"
+ )
+ loss_dict["pixel_id"] = torch.tensor(
+ 0.0, device=cat_id_input.device
+ )
+ continue
+ else:
+ raise # Re-raise other RuntimeErrors
+ elif loss_name == "id_mse":
+ landmarks = landmarks[:, selected_frames]
+ cat_id_input = (
+ (
+ torch.cat([decoded_frames, selected_original_frames], dim=0)
+ + 1
+ )
+ / 2
+ ) * 255
+ cat_id_landmarks = torch.cat([landmarks, landmarks], dim=0)
+ cat_id_landmarks = (
+ rearrange(cat_id_landmarks, "b t ... -> (b t) ...")
+ .cpu()
+ .numpy()
+ )
+ cropped_decoded_frames = extract_face(
+ rearrange(cat_id_input, "b c h w -> b h w c"),
+ cat_id_landmarks,
+ margin=30,
+ postprocess=True,
+ )
+ # Save first frame to debug
+ n = cat_id_input.shape[0] // 2
+
+ id_embeddings = self.id_model(
+ rearrange(cropped_decoded_frames, "b h w c -> b c h w")
+ )
+
+ pred_embeddings, target_embeddings = (
+ id_embeddings[:n],
+ id_embeddings[n:],
+ )
+ # Cosine similarity loss (1 - cos_sim to make it a loss that should be minimized)
+ id_w = append_dims(
+ self.loss_weighting(sigmas), pred_embeddings.ndim
+ )
+ loss_pixel = (
+ id_w * ((pred_embeddings - target_embeddings) ** 2)
+ ).mean()
+ loss_dict["pixel_id_mse"] = self.weight_pixel * loss_pixel
+ loss += self.weight_pixel * loss_pixel
+
+ else:
+ raise NotImplementedError(f"Unknown pixel loss type {loss_name}")
+
+ return loss_dict
diff --git a/sgm/modules/diffusionmodules/loss_weighting.py b/sgm/modules/diffusionmodules/loss_weighting.py
new file mode 100644
index 0000000000000000000000000000000000000000..e12c0a76635435babd1af33969e82fa284525af8
--- /dev/null
+++ b/sgm/modules/diffusionmodules/loss_weighting.py
@@ -0,0 +1,32 @@
+from abc import ABC, abstractmethod
+
+import torch
+
+
+class DiffusionLossWeighting(ABC):
+ @abstractmethod
+ def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
+ pass
+
+
+class UnitWeighting(DiffusionLossWeighting):
+ def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
+ return torch.ones_like(sigma, device=sigma.device)
+
+
+class EDMWeighting(DiffusionLossWeighting):
+ def __init__(self, sigma_data: float = 0.5):
+ self.sigma_data = sigma_data
+
+ def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
+ return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
+
+
+class VWeighting(EDMWeighting):
+ def __init__(self):
+ super().__init__(sigma_data=1.0)
+
+
+class EpsWeighting(DiffusionLossWeighting):
+ def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
+ return sigma**-2.0
diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..20409fa64134f3aadaa54d3054852d1431473a65
--- /dev/null
+++ b/sgm/modules/diffusionmodules/model.py
@@ -0,0 +1,746 @@
+# pytorch_diffusion + derived encoder decoder
+import logging
+import math
+from typing import Any, Callable, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from packaging import version
+
+logpy = logging.getLogger(__name__)
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILABLE = True
+except:
+ XFORMERS_IS_AVAILABLE = False
+ logpy.warning("no module 'xformers'. Processing without...")
+
+from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ b, c, h, w = q.shape
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
+ h_ = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
+ # compute attention
+
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+
+ def forward(self, x, **kwargs):
+ h_ = x
+ h_ = self.attention(h_)
+ h_ = self.proj_out(h_)
+ return x + h_
+
+
+class MemoryEfficientAttnBlock(nn.Module):
+ """
+ Uses xformers efficient implementation,
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ Note: this is a single-head self-attention operation
+ """
+
+ #
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.attention_op: Optional[Any] = None
+
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ B, C, H, W = q.shape
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
+
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(B, t.shape[1], 1, C)
+ .permute(0, 2, 1, 3)
+ .reshape(B * 1, t.shape[1], C)
+ .contiguous(),
+ (q, k, v),
+ )
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ out = out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C)
+ return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
+
+ def forward(self, x, **kwargs):
+ h_ = x
+ h_ = self.attention(h_)
+ h_ = self.proj_out(h_)
+ return x + h_
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+ def forward(self, x, context=None, mask=None, **unused_kwargs):
+ b, c, h, w = x.shape
+ x = rearrange(x, "b c h w -> b (h w) c")
+ out = super().forward(x, context=context, mask=mask)
+ out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
+ return x + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+ assert attn_type in [
+ "vanilla",
+ "vanilla-xformers",
+ "memory-efficient-cross-attn",
+ "linear",
+ "none",
+ ], f"attn_type {attn_type} unknown"
+ if version.parse(torch.__version__) < version.parse("2.0.0") and attn_type != "none":
+ assert XFORMERS_IS_AVAILABLE, (
+ f"We do not support vanilla attention in {torch.__version__} anymore, "
+ f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
+ )
+ attn_type = "vanilla-xformers"
+ logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return AttnBlock(in_channels)
+ elif attn_type == "vanilla-xformers":
+ logpy.info(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+ return MemoryEfficientAttnBlock(in_channels)
+ elif type == "memory-efficient-cross-attn":
+ attn_kwargs["query_dim"] = in_channels
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ use_timestep=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch * 4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList(
+ [
+ torch.nn.Linear(self.ch, self.temb_ch),
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
+ ]
+ )
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ skip_in = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch * in_ch_mult[i_level]
+ block.append(
+ ResnetBlock(
+ in_channels=block_in + skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x, t=None, context=None):
+ # assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignore_kwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ tanh_out=False,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignorekwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ logpy.info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+
+ make_attn_cls = self._make_attn()
+ make_resblock_cls = self._make_resblock()
+ make_conv_cls = self._make_conv()
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
+ self.mid.block_2 = make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn_cls(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = make_conv_cls(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+ def _make_attn(self) -> Callable:
+ return make_attn
+
+ def _make_resblock(self) -> Callable:
+ return ResnetBlock
+
+ def _make_conv(self) -> Callable:
+ return torch.nn.Conv2d
+
+ def get_last_layer(self, **kwargs):
+ return self.conv_out.weight
+
+ def forward(self, z, **kwargs):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb, **kwargs)
+ h = self.mid.attn_1(h, **kwargs)
+ h = self.mid.block_2(h, temb, **kwargs)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, **kwargs)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h, **kwargs)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+def zero_module(module):
+ # Zero out the parameters of a module and return it.
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+class InflatedConv3d(nn.Conv2d):
+ def forward(self, x):
+ video_length = x.shape[2]
+
+ x = rearrange(x, "b c f h w -> (b f) c h w")
+ x = super().forward(x)
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
+
+ return x
+
+
+class FaceLocator(torch.nn.Module):
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 1,
+ block_out_channels: tuple = (16, 32, 64, 128),
+ ):
+ super().__init__()
+ self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ InflatedConv3d(
+ block_out_channels[-1],
+ conditioning_embedding_channels,
+ kernel_size=3,
+ padding=1,
+ )
+ )
+
+ def forward(self, conditioning):
+ # print(conditioning.shape)
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..03221777ca9cfb8792407d453b0d7b911b6eddf0
--- /dev/null
+++ b/sgm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,880 @@
+import logging
+import math
+from abc import abstractmethod
+from typing import Iterable, List, Optional, Tuple, Union
+
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.utils.checkpoint import checkpoint
+
+from ...modules.attention import SpatialTransformer
+from ...modules.diffusionmodules.util import (
+ avg_pool_nd,
+ conv_nd,
+ linear,
+ normalization,
+ timestep_embedding,
+ zero_module,
+)
+from ...modules.video_attention import SpatialVideoTransformer
+from ...util import exists
+
+logpy = logging.getLogger(__name__)
+
+
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x: th.Tensor) -> th.Tensor:
+ b, c, _ = x.shape
+ x = x.reshape(b, c, -1)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
+ x = x + self.positional_embedding[None, :, :]
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x: th.Tensor, emb: th.Tensor):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(
+ self,
+ x: th.Tensor,
+ emb: th.Tensor,
+ context: Optional[th.Tensor] = None,
+ reference_context: Optional[th.Tensor] = None,
+ audio_context: Optional[th.Tensor] = None,
+ image_only_indicator: Optional[th.Tensor] = None,
+ time_context: Optional[int] = None,
+ num_video_frames: Optional[int] = None,
+ skip_spatial_attention: bool = False,
+ skip_temporal_attention: bool = False,
+ ):
+ from ...modules.diffusionmodules.video_model import VideoResBlock
+
+ is_attention = False
+ for layer in self:
+ module = layer
+
+ if isinstance(module, TimestepBlock) and not isinstance(module, VideoResBlock):
+ x = layer(x, emb)
+ elif isinstance(module, VideoResBlock):
+ x = layer(x, emb, num_video_frames, image_only_indicator)
+ elif isinstance(module, SpatialVideoTransformer):
+ is_attention = True
+ x = layer(
+ x,
+ context,
+ reference_context,
+ time_context,
+ audio_context,
+ num_video_frames,
+ image_only_indicator,
+ skip_spatial_attention=skip_spatial_attention,
+ skip_temporal_attention=skip_temporal_attention,
+ )
+ elif isinstance(module, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+
+ return x, is_attention
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool,
+ dims: int = 2,
+ out_channels: Optional[int] = None,
+ padding: int = 1,
+ third_up: bool = False,
+ kernel_size: int = 3,
+ scale_factor: int = 2,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ self.third_up = third_up
+ self.scale_factor = scale_factor
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, kernel_size, padding=padding)
+
+ def forward(self, x: th.Tensor) -> th.Tensor:
+ assert x.shape[1] == self.channels
+
+ if self.dims == 3:
+ t_factor = 1 if not self.third_up else self.scale_factor
+ x = F.interpolate(
+ x,
+ (
+ t_factor * x.shape[2],
+ x.shape[3] * self.scale_factor,
+ x.shape[4] * self.scale_factor,
+ ),
+ mode="nearest",
+ )
+ else:
+ # prev_dtype = x.dtype
+ x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
+ # x = x.to(prev_dtype)
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool,
+ dims: int = 2,
+ out_channels: Optional[int] = None,
+ padding: int = 1,
+ third_down: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
+ if use_conv:
+ logpy.info(f"Building a Downsample layer with {dims} dims.")
+ logpy.info(
+ f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
+ f"kernel-size: 3, stride: {stride}, padding: {padding}"
+ )
+ if dims == 3:
+ logpy.info(f" --> Downsampling third axis (time): {third_down}")
+ self.op = conv_nd(
+ dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding,
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x: th.Tensor) -> th.Tensor:
+ assert x.shape[1] == self.channels
+
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ emb_channels: int,
+ dropout: float,
+ out_channels: Optional[int] = None,
+ use_conv: bool = False,
+ use_scale_shift_norm: bool = False,
+ dims: int = 2,
+ use_checkpoint: bool = False,
+ up: bool = False,
+ down: bool = False,
+ kernel_size: int = 3,
+ exchange_temb_dims: bool = False,
+ skip_t_emb: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.exchange_temb_dims = exchange_temb_dims
+
+ if isinstance(kernel_size, Iterable):
+ padding = [k // 2 for k in kernel_size]
+ else:
+ padding = kernel_size // 2
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.skip_t_emb = skip_t_emb
+ self.emb_out_channels = 2 * self.out_channels if use_scale_shift_norm else self.out_channels
+ if self.skip_t_emb:
+ logpy.info(f"Skipping timestep embedding in {self.__class__.__name__}")
+ assert not self.use_scale_shift_norm
+ self.emb_layers = None
+ self.exchange_temb_dims = False
+ else:
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ self.emb_out_channels,
+ ),
+ )
+
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(
+ dims,
+ self.out_channels,
+ self.out_channels,
+ kernel_size,
+ padding=padding,
+ )
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding)
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ if self.use_checkpoint:
+ return checkpoint(self._forward, x, emb)
+ else:
+ return self._forward(x, emb)
+
+ def _forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+
+ if self.skip_t_emb:
+ emb_out = th.zeros_like(h)
+ else:
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ if self.exchange_temb_dims:
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int = 1,
+ num_head_channels: int = -1,
+ use_checkpoint: bool = False,
+ use_new_attention_order: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x: th.Tensor, **kwargs) -> th.Tensor:
+ return checkpoint(self._forward, x)
+
+ def _forward(self, x: th.Tensor) -> th.Tensor:
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
+ """
+
+ def __init__(self, n_heads: int):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv: th.Tensor) -> th.Tensor:
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads: int):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv: th.Tensor) -> th.Tensor:
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+
+class Timestep(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, t: th.Tensor) -> th.Tensor:
+ return timestep_embedding(t, self.dim)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ model_channels: int,
+ out_channels: int,
+ num_res_blocks: int,
+ attention_resolutions: int,
+ dropout: float = 0.0,
+ channel_mult: Union[List, Tuple] = (1, 2, 4, 8),
+ conv_resample: bool = True,
+ dims: int = 2,
+ num_classes: Optional[Union[int, str]] = None,
+ use_checkpoint: bool = False,
+ num_heads: int = -1,
+ num_head_channels: int = -1,
+ num_heads_upsample: int = -1,
+ use_scale_shift_norm: bool = False,
+ resblock_updown: bool = False,
+ transformer_depth: int = 1,
+ context_dim: Optional[int] = None,
+ disable_self_attentions: Optional[List[bool]] = None,
+ num_attention_blocks: Optional[List[int]] = None,
+ disable_middle_self_attn: bool = False,
+ disable_middle_transformer: bool = False,
+ use_linear_in_transformer: bool = False,
+ spatial_transformer_attn_type: str = "softmax",
+ adm_in_channels: Optional[int] = None,
+ unfreeze_blocks: Optional[List[str]] = None,
+ fine_tuning_method: str = None,
+ audio_cond_method: str = None,
+ **kwargs,
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set"
+
+ if num_head_channels == -1:
+ assert num_heads != -1, "Either num_heads or num_head_channels has to be set"
+
+ self.adapter = None
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.audio_cond_method = audio_cond_method
+ if isinstance(transformer_depth, int):
+ transformer_depth = len(channel_mult) * [transformer_depth]
+ transformer_depth_middle = transformer_depth[-1]
+
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError(
+ "provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult"
+ )
+ self.num_res_blocks = num_res_blocks
+
+ if disable_self_attentions is not None:
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(
+ map(
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
+ range(len(num_attention_blocks)),
+ )
+ )
+ logpy.info(
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set."
+ )
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ logpy.info("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "timestep":
+ self.label_emb = nn.Sequential(
+ Timestep(model_channels),
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ ),
+ )
+ elif self.num_classes == "sequential":
+ assert adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ linear(adm_in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+ else:
+ raise ValueError
+
+ self.input_blocks = nn.ModuleList(
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ if context_dim is not None and exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ (
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth_middle,
+ context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ if not disable_middle_transformer
+ else th.nn.Identity()
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+ layers.append(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+
+ if fine_tuning_method is not None:
+ # Freeze everything except the adapter
+ for param in self.parameters():
+ param.requires_grad = False
+ if self.adapter is not None:
+ for param in self.adapter.parameters():
+ param.requires_grad = True
+ if unfreeze_blocks:
+ if "input" in unfreeze_blocks:
+ for param in self.input_blocks[0].parameters():
+ param.requires_grad = True
+ # break # only unfreeze the first input block
+ if "label_emb" in unfreeze_blocks:
+ for param in self.label_emb.parameters():
+ param.requires_grad = True
+ if "time_embed" in unfreeze_blocks:
+ for param in self.time_embed.parameters():
+ param.requires_grad = True
+
+ def forward(
+ self,
+ x: th.Tensor,
+ timesteps: Optional[th.Tensor] = None,
+ encoder_hidden_states: Optional[th.Tensor] = None,
+ y: Optional[th.Tensor] = None,
+ audio_emb: Optional[th.Tensor] = None,
+ **kwargs,
+ ) -> th.Tensor:
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ context = encoder_hidden_states
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ if self.audio_cond_method == "cross_attention":
+ assert audio_emb is not None
+ # print(f"{context.shape=}")
+ if audio_emb.ndim == 4:
+ audio_emb = rearrange(audio_emb, "b t d c -> b (t d) c")
+ context = th.cat([context, audio_emb], dim=1)
+
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+
+ return self.out(h)
diff --git a/sgm/modules/diffusionmodules/sampling.py b/sgm/modules/diffusionmodules/sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca557922d562d620c65eafef90b36389c55c228e
--- /dev/null
+++ b/sgm/modules/diffusionmodules/sampling.py
@@ -0,0 +1,614 @@
+"""
+Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
+"""
+
+from collections import defaultdict
+from typing import Dict, Union
+
+import torch
+from omegaconf import ListConfig, OmegaConf
+from tqdm import tqdm
+from einops import rearrange
+
+from ...modules.diffusionmodules.sampling_utils import (
+ get_ancestral_step,
+ linear_multistep_coeff,
+ to_d,
+ to_neg_log_sigma,
+ to_sigma,
+ chunk_inputs,
+)
+from ...util import append_dims, default, instantiate_from_config
+
+DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
+
+
+class BaseDiffusionSampler:
+ def __init__(
+ self,
+ discretization_config: Union[Dict, ListConfig, OmegaConf],
+ num_steps: Union[int, None] = None,
+ guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
+ verbose: bool = True,
+ device: str = "cuda",
+ ):
+ self.num_steps = num_steps
+ self.discretization = instantiate_from_config(discretization_config)
+ self.guider = instantiate_from_config(
+ default(
+ guider_config,
+ DEFAULT_GUIDER,
+ )
+ )
+ self.verbose = verbose
+ self.device = device
+
+ def set_num_steps(self, num_steps: int):
+ self.num_steps = num_steps
+
+ def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None, strength=1.0):
+ print("Num steps: ", self.num_steps if num_steps is None else num_steps)
+ sigmas = self.discretization(self.num_steps if num_steps is None else num_steps, device=self.device)
+ if strength != 1.0:
+ init_timestep = min(int(len(sigmas) * strength), len(sigmas))
+ t_start = max(len(sigmas) - init_timestep, 0)
+ # sigmas[:t_start] = torch.ones_like(sigmas[:t_start]) * sigmas[t_start]
+ sigmas = sigmas[t_start:]
+ uc = default(uc, cond)
+
+ x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
+ num_sigmas = len(sigmas)
+
+ s_in = x.new_ones([x.shape[0]])
+
+ return x, s_in, sigmas, num_sigmas, cond, uc
+
+ def denoise(self, x, denoiser, sigma, cond, uc):
+ denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
+ denoised = self.guider(denoised, sigma)
+ return denoised
+
+ def get_sigma_gen(self, num_sigmas):
+ sigma_generator = range(num_sigmas - 1)
+ if self.verbose:
+ print("#" * 30, " Sampling setting ", "#" * 30)
+ print(f"Sampler: {self.__class__.__name__}")
+ print(f"Discretization: {self.discretization.__class__.__name__}")
+ print(f"Guider: {self.guider.__class__.__name__}")
+ sigma_generator = tqdm(
+ sigma_generator,
+ total=num_sigmas,
+ desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
+ )
+ return sigma_generator
+
+
+class FIFODiffusionSampler(BaseDiffusionSampler):
+ def __init__(self, lookahead=False, num_frames=14, num_partitions=4, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.num_frames = num_frames
+ self.lookahead = lookahead
+ self.num_partitions = num_partitions
+ self.num_steps = self.num_frames * self.num_partitions
+ self.fifo = []
+
+ def get_sigma_gen(self, num_sigmas, total_n_frames):
+ total = total_n_frames + num_sigmas - self.num_frames
+ sigma_generator = range(total_n_frames + num_sigmas - self.num_frames - 1)
+ if self.verbose:
+ print("#" * 30, " Sampling setting ", "#" * 30)
+ print(f"Sampler: {self.__class__.__name__}")
+ print(f"Discretization: {self.discretization.__class__.__name__}")
+ print(f"Guider: {self.guider.__class__.__name__}")
+ sigma_generator = tqdm(
+ sigma_generator,
+ total=total,
+ desc=f"Sampling with {self.__class__.__name__} for {total} steps",
+ )
+ return sigma_generator
+
+ def prepare_sampling_loop(self, x, cond, uc=None):
+ sigmas = self.discretization(self.num_steps, device=self.device)
+
+ uc = default(uc, cond)
+
+ x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
+ num_sigmas = len(sigmas)
+
+ s_in = x.new_ones([x.shape[0]])
+
+ return x, s_in, sigmas, num_sigmas, cond, uc
+
+
+class SingleStepDiffusionSampler(BaseDiffusionSampler):
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
+ raise NotImplementedError
+
+ def euler_step(self, x, d, dt):
+ return x + dt * d
+
+
+class EDMSampler(SingleStepDiffusionSampler):
+ def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.s_churn = s_churn
+ self.s_tmin = s_tmin
+ self.s_tmax = s_tmax
+ self.s_noise = s_noise
+
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
+ sigma_hat = sigma * (gamma + 1.0)
+ if gamma > 0:
+ eps = torch.randn_like(x) * self.s_noise
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
+
+ denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
+ if x.ndim == 5:
+ denoised = rearrange(denoised, "(b t) c h w -> b c t h w", b=x.shape[0])
+
+ d = to_d(x, sigma_hat, denoised)
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
+
+ euler_step = self.euler_step(x, d, dt)
+ x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
+ return x
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, strength=1.0):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps, strength=strength)
+
+ for i in self.get_sigma_gen(num_sigmas):
+ gamma = (
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
+ )
+ x = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ gamma,
+ )
+
+ return x
+
+
+class EDMSampleCFGplusplus(SingleStepDiffusionSampler):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
+ sigma_hat = sigma
+
+ denoised, x_u = self.denoise(x, denoiser, sigma_hat, cond, uc)
+ if x.ndim == 5:
+ denoised = rearrange(denoised, "(b t) c h w -> b c t h w", b=x.shape[0])
+ x_u = rearrange(x_u, "(b t) c h w -> b c t h w", b=x.shape[0])
+
+ d = to_d(x, sigma_hat, x_u)
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
+ next_sigma = append_dims(next_sigma, x.ndim)
+
+ euler_step = self.euler_step(denoised, d, next_sigma)
+ x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
+ return x
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, strength=1.0):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps, strength=strength)
+
+ for i in self.get_sigma_gen(num_sigmas):
+ s_in = x.new_ones([x.shape[0]])
+ x = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ None,
+ )
+
+ return x
+
+
+def shift_latents(latents):
+ # shift latents
+ latents[:, :, :-1] = latents[:, :, 1:].clone()
+
+ # add new noise to the last frame
+ latents[:, :, -1] = torch.randn_like(latents[:, :, -1])
+
+ return latents
+
+
+class FIFOEDMSampler(FIFODiffusionSampler):
+ """
+ The problem is that the original implementation doesn't take into consideration the condition.
+ So we need to check if this can work with the condition. Don't have time to check this now.
+ """
+
+ def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.s_churn = s_churn
+ self.s_tmin = s_tmin
+ self.s_tmax = s_tmax
+ self.s_noise = s_noise
+
+ def euler_step(self, x, d, dt):
+ return x + dt * d
+
+ def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
+ return euler_step
+
+ def concatenate_list_dict(self, dict1):
+ for k, v in dict1.items():
+ if isinstance(v, list):
+ dict1[k] = torch.cat(v, dim=0)
+ else:
+ dict1[k] = v
+ return dict1
+
+ def prepare_latents(self, x, c, uc, sigmas, num_sigmas):
+ latents_list = []
+ sigma_hat_list = []
+ sigma_next_list = []
+ c_list = defaultdict(list)
+ uc_list = defaultdict(list)
+
+ video = torch.load("/data/home/antoni/code/generative-models-dub/samples_z.pt")
+ video = rearrange(video, "t c h w -> () c t h w")
+
+ for k, v in c.items():
+ if not isinstance(v, torch.Tensor):
+ c_list[k] = v
+ uc_list[k] = uc[k]
+
+ if self.lookahead:
+ for i in range(self.num_frames // 2):
+ gamma = (
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
+ else 0.0
+ )
+ sigma = sigmas[i]
+ sigma_hat = sigma * (gamma + 1.0)
+ if gamma > 0:
+ eps = torch.randn_like(video[:, :, [0]]) * self.s_noise
+ latents = video[:, :, [0]] + eps * append_dims(sigma_hat**2 - sigma**2, video.ndim) ** 0.5
+ else:
+ latents = video[:, :, [0]]
+
+ for k, v in c.items():
+ if isinstance(v, torch.Tensor):
+ c_list[k].append(v[[0]])
+ for k, v in uc.items():
+ if isinstance(v, torch.Tensor):
+ uc_list[k].append(v[[0]])
+
+ latents_list.append(latents)
+ sigma_hat_list.append(sigma_hat)
+ sigma_next_list.append(sigmas[i + 1])
+
+ for i in range(num_sigmas - 1):
+ gamma = (
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0
+ )
+ sigma = sigmas[i]
+ sigma_hat = sigma * (gamma + 1.0)
+ frame_idx = max(0, i - (num_sigmas - self.num_frames))
+ print(frame_idx)
+ if gamma > 0:
+ eps = torch.randn_like(video[:, :, [frame_idx]]) * self.s_noise
+ latents = video[:, :, [frame_idx]] + eps * append_dims(sigma_hat**2 - sigma**2, video.ndim) ** 0.5
+ else:
+ latents = video[:, :, [frame_idx]]
+
+ for k, v in c.items():
+ if isinstance(v, torch.Tensor):
+ c_list[k].append(
+ v[[frame_idx]] if v.shape[0] == video.shape[2] else v[[frame_idx // self.num_frames]]
+ )
+ for k, v in uc.items():
+ if isinstance(v, torch.Tensor):
+ uc_list[k].append(
+ v[[frame_idx]] if v.shape[0] == video.shape[2] else v[[frame_idx // self.num_frames]]
+ )
+
+ latents_list.append(latents)
+ sigma_hat_list.append(sigma_hat)
+ sigma_next_list.append(sigmas[i + 1])
+
+ latents = torch.cat(latents_list, dim=2)
+ sigma_hat = torch.stack(sigma_hat_list, dim=0)
+ sigma_next = torch.stack(sigma_next_list, dim=0)
+
+ c_list = self.concatenate_list_dict(c_list)
+ uc_list = self.concatenate_list_dict(uc_list)
+
+ return latents, sigma_hat, sigma_next, c_list, uc_list
+
+ def sampler_step(self, sigma_hat, next_sigma, denoiser, x, cond, uc=None):
+ denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
+ if x.ndim == 5:
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+
+ d = to_d(x, sigma_hat, denoised)
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
+
+ euler_step = self.euler_step(x, d, dt)
+ x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc)
+ return x
+
+ def merge_cond_dict(self, cond, total_n_frames):
+ for k, v in cond.items():
+ if not isinstance(v, torch.Tensor):
+ cond[k] = v
+ else:
+ if v.dim() == 5:
+ cond[k] = rearrange(v, "b c t h w -> (b t) c h w")
+ elif v.dim() == 3 and v.shape[0] != total_n_frames:
+ cond[k] = rearrange(v, "b t c -> (b t) () c")
+ else:
+ cond[k] = v
+ return cond
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, strength=1.0):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc)
+
+ x = rearrange(x, "b c h w -> () c b h w")
+ cond = self.merge_cond_dict(cond, x.shape[2])
+ uc = self.merge_cond_dict(uc, x.shape[2])
+ total_n_frames = x.shape[2]
+ latents, sigma_hat, sigma_next, cond, uc = self.prepare_latents(x, cond, uc, sigmas, num_sigmas)
+
+ fifo_video_frames = []
+
+ for i in self.get_sigma_gen(num_sigmas, total_n_frames):
+ for rank in reversed(range(2 * self.num_partitions if self.lookahead else self.num_partitions)):
+ start_idx = rank * (self.num_frames // 2) if self.lookahead else rank * self.num_frames
+ midpoint_idx = start_idx + self.num_frames // 2
+ end_idx = start_idx + self.num_frames
+
+ chunk_x, sigma_hat_chunk, sigma_next_chunk, cond_chunk, uc_chunk = chunk_inputs(
+ latents, cond, uc, sigma_hat, sigma_next, start_idx, end_idx, self.num_frames
+ )
+
+ s_in = chunk_x.new_ones([chunk_x.shape[0]])
+
+ out = self.sampler_step(
+ s_in * sigma_hat_chunk,
+ s_in * sigma_next_chunk,
+ denoiser,
+ chunk_x,
+ cond_chunk,
+ uc=uc_chunk,
+ )
+ if self.lookahead:
+ latents[:, :, midpoint_idx:end_idx] = rearrange(
+ out[-(self.num_frames // 2) :], "b c h w -> () c b h w"
+ )
+ else:
+ latents[:, :, start_idx:end_idx] = rearrange(out, "b c h w -> () c b h w")
+ del out
+
+ first_frame_idx = self.num_frames // 2 if self.lookahead else 0
+ latents = shift_latents(latents)
+ fifo_video_frames.append(latents[:, :, [first_frame_idx]])
+
+ return rearrange(torch.cat(fifo_video_frames, dim=2), "() c b h w -> b c h w")[-total_n_frames:]
+
+
+class AncestralSampler(SingleStepDiffusionSampler):
+ def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.eta = eta
+ self.s_noise = s_noise
+ self.noise_sampler = lambda x: torch.randn_like(x)
+
+ def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
+ d = to_d(x, sigma, denoised)
+ dt = append_dims(sigma_down - sigma, x.ndim)
+
+ return self.euler_step(x, d, dt)
+
+ def ancestral_step(self, x, sigma, next_sigma, sigma_up):
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0,
+ x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
+ x,
+ )
+ return x
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+
+ for i in self.get_sigma_gen(num_sigmas):
+ x = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ )
+
+ return x
+
+
+class LinearMultistepSampler(BaseDiffusionSampler):
+ def __init__(
+ self,
+ order=4,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.order = order
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+
+ ds = []
+ sigmas_cpu = sigmas.detach().cpu().numpy()
+ for i in self.get_sigma_gen(num_sigmas):
+ sigma = s_in * sigmas[i]
+ denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs)
+ denoised = self.guider(denoised, sigma)
+ d = to_d(x, sigma, denoised)
+ ds.append(d)
+ if len(ds) > self.order:
+ ds.pop(0)
+ cur_order = min(i + 1, self.order)
+ coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
+
+ return x
+
+
+class EulerEDMSampler(EDMSampler):
+ def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
+ return euler_step
+
+
+class EulerEDMSamplerPlusPlus(EDMSampleCFGplusplus):
+ def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
+ return euler_step
+
+
+class HeunEDMSampler(EDMSampler):
+ def possible_correction_step(self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc):
+ if torch.sum(next_sigma) < 1e-14:
+ # Save a network evaluation if all noise levels are 0
+ return euler_step
+ else:
+ denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
+ d_new = to_d(euler_step, next_sigma, denoised)
+ d_prime = (d + d_new) / 2.0
+
+ # apply correction if noise level is not 0
+ x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step)
+ return x
+
+
+class EulerAncestralSampler(AncestralSampler):
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+ x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
+
+ return x
+
+
+class DPMPP2SAncestralSampler(AncestralSampler):
+ def get_variables(self, sigma, sigma_down):
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
+ h = t_next - t
+ s = t + 0.5 * h
+ return h, s, t, t_next
+
+ def get_mult(self, h, s, t, t_next):
+ mult1 = to_sigma(s) / to_sigma(t)
+ mult2 = (-0.5 * h).expm1()
+ mult3 = to_sigma(t_next) / to_sigma(t)
+ mult4 = (-h).expm1()
+
+ return mult1, mult2, mult3, mult4
+
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+ x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
+
+ if torch.sum(sigma_down) < 1e-14:
+ # Save a network evaluation if all noise levels are 0
+ x = x_euler
+ else:
+ h, s, t, t_next = self.get_variables(sigma, sigma_down)
+ mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)]
+
+ x2 = mult[0] * x - mult[1] * denoised
+ denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
+ x_dpmpp2s = mult[2] * x - mult[3] * denoised2
+
+ # apply correction if noise level is not 0
+ x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
+
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
+ return x
+
+
+class DPMPP2MSampler(BaseDiffusionSampler):
+ def get_variables(self, sigma, next_sigma, previous_sigma=None):
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
+ h = t_next - t
+
+ if previous_sigma is not None:
+ h_last = t - to_neg_log_sigma(previous_sigma)
+ r = h_last / h
+ return h, r, t, t_next
+ else:
+ return h, None, t, t_next
+
+ def get_mult(self, h, r, t, t_next, previous_sigma):
+ mult1 = to_sigma(t_next) / to_sigma(t)
+ mult2 = (-h).expm1()
+
+ if previous_sigma is not None:
+ mult3 = 1 + 1 / (2 * r)
+ mult4 = 1 / (2 * r)
+ return mult1, mult2, mult3, mult4
+ else:
+ return mult1, mult2
+
+ def sampler_step(
+ self,
+ old_denoised,
+ previous_sigma,
+ sigma,
+ next_sigma,
+ denoiser,
+ x,
+ cond,
+ uc=None,
+ ):
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+
+ h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
+ mult = [append_dims(mult, x.ndim) for mult in self.get_mult(h, r, t, t_next, previous_sigma)]
+
+ x_standard = mult[0] * x - mult[1] * denoised
+ if old_denoised is None or torch.sum(next_sigma) < 1e-14:
+ # Save a network evaluation if all noise levels are 0 or on the first step
+ return x_standard, denoised
+ else:
+ denoised_d = mult[2] * denoised - mult[3] * old_denoised
+ x_advanced = mult[0] * x - mult[1] * denoised_d
+
+ # apply correction if noise level is not 0 and not first step
+ x = torch.where(append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard)
+
+ return x, denoised
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps)
+
+ old_denoised = None
+ for i in self.get_sigma_gen(num_sigmas):
+ x, old_denoised = self.sampler_step(
+ old_denoised,
+ None if i == 0 else s_in * sigmas[i - 1],
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc=uc,
+ )
+
+ return x
diff --git a/sgm/modules/diffusionmodules/sampling_utils.py b/sgm/modules/diffusionmodules/sampling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e60cd5de97e4641c17c5dbdbb6aa5063eda090f
--- /dev/null
+++ b/sgm/modules/diffusionmodules/sampling_utils.py
@@ -0,0 +1,80 @@
+import torch
+from scipy import integrate
+from einops import repeat, rearrange
+from ...util import append_dims
+
+
+def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
+ if order - 1 > i:
+ raise ValueError(f"Order {order} too high for step {i}")
+
+ def fn(tau):
+ prod = 1.0
+ for k in range(order):
+ if j == k:
+ continue
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
+ return prod
+
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
+
+
+def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
+ if not eta:
+ return sigma_to, 0.0
+ sigma_up = torch.minimum(
+ sigma_to,
+ eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
+ )
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+ return sigma_down, sigma_up
+
+
+def to_d(x, sigma, denoised):
+ return (x - denoised) / append_dims(sigma, x.ndim)
+
+
+def to_neg_log_sigma(sigma):
+ return sigma.log().neg()
+
+
+def to_sigma(neg_log_sigma):
+ return neg_log_sigma.neg().exp()
+
+
+def chunk_inputs(
+ input,
+ cond,
+ additional_model_inputs,
+ sigma,
+ sigma_next,
+ start_idx,
+ end_idx,
+ num_frames=14,
+):
+ input_chunk = input[:, :, start_idx:end_idx].to(torch.float32).clone()
+
+ sigma_chunk = sigma[start_idx:end_idx].to(torch.float32)
+ sigma_next_chunk = sigma_next[start_idx:end_idx].to(torch.float32)
+
+ cond_chunk = {}
+ for k, v in cond.items():
+ if isinstance(v, torch.Tensor):
+ cond_chunk[k] = v[start_idx:end_idx]
+ else:
+ cond_chunk[k] = v
+
+ additional_model_inputs_chunk = {}
+ for k, v in additional_model_inputs.items():
+ if isinstance(v, torch.Tensor):
+ cond_chunk[k] = v[start_idx:end_idx]
+ else:
+ additional_model_inputs_chunk[k] = v
+
+ return (
+ input_chunk,
+ sigma_chunk,
+ sigma_next_chunk,
+ cond_chunk,
+ additional_model_inputs_chunk,
+ )
diff --git a/sgm/modules/diffusionmodules/sigma_sampling.py b/sgm/modules/diffusionmodules/sigma_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba98396c16157faf6ebd8f8508058c660d6d4750
--- /dev/null
+++ b/sgm/modules/diffusionmodules/sigma_sampling.py
@@ -0,0 +1,29 @@
+import torch
+
+from ...util import default, instantiate_from_config
+
+
+class EDMSampling:
+ def __init__(self, p_mean=-1.2, p_std=1.2):
+ self.p_mean = p_mean
+ self.p_std = p_std
+
+ def __call__(self, n_samples, rand=None):
+ log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
+ return log_sigma.exp()
+
+
+class DiscreteSampling:
+ def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):
+ self.num_idx = num_idx
+ self.sigmas = instantiate_from_config(discretization_config)(num_idx, do_append_zero=do_append_zero, flip=flip)
+
+ def idx_to_sigma(self, idx):
+ return self.sigmas[idx]
+
+ def __call__(self, n_samples, rand=None):
+ idx = default(
+ rand,
+ torch.randint(0, self.num_idx, (n_samples,)),
+ )
+ return self.idx_to_sigma(idx)
diff --git a/sgm/modules/diffusionmodules/util.py b/sgm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..edd39ced552765177a2b28be0fbc3ac6c7fa7f87
--- /dev/null
+++ b/sgm/modules/diffusionmodules/util.py
@@ -0,0 +1,328 @@
+"""
+partially adopted from
+https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+and
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+and
+https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+
+thanks!
+"""
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+
+def make_beta_schedule(
+ schedule,
+ n_timestep,
+ linear_start=1e-4,
+ linear_end=2e-2,
+):
+ if schedule == "linear":
+ betas = torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64) ** 2
+ return betas.numpy()
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def mixed_checkpoint(func, inputs: dict, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
+ borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
+ it also works with non-tensor inputs
+ :param func: the function to evaluate.
+ :param inputs: the argument dictionary to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
+ tensor_inputs = [inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)]
+ non_tensor_keys = [key for key in inputs if not isinstance(inputs[key], torch.Tensor)]
+ non_tensor_inputs = [inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)]
+ args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
+ return MixedCheckpointFunction.apply(
+ func,
+ len(tensor_inputs),
+ len(non_tensor_inputs),
+ tensor_keys,
+ non_tensor_keys,
+ *args,
+ )
+ else:
+ return func(**inputs)
+
+
+class MixedCheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ run_function,
+ length_tensors,
+ length_non_tensors,
+ tensor_keys,
+ non_tensor_keys,
+ *args,
+ ):
+ ctx.end_tensors = length_tensors
+ ctx.end_non_tensors = length_tensors + length_non_tensors
+ ctx.gpu_autocast_kwargs = {
+ "enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ assert len(tensor_keys) == length_tensors and len(non_tensor_keys) == length_non_tensors
+
+ ctx.input_tensors = {key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))}
+ ctx.input_non_tensors = {
+ key: val for (key, val) in zip(non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]))
+ }
+ ctx.run_function = run_function
+ ctx.input_params = list(args[ctx.end_non_tensors :])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(**ctx.input_tensors, **ctx.input_non_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
+ ctx.input_tensors = {key: ctx.input_tensors[key].detach().requires_grad_(True) for key in ctx.input_tensors}
+
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = {key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) for key in ctx.input_tensors}
+ # shallow_copies.update(additional_args)
+ output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ list(ctx.input_tensors.values()) + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (
+ (None, None, None, None, None)
+ + input_grads[: ctx.end_tensors]
+ + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
+ + input_grads[ctx.end_tensors :]
+ )
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ ctx.gpu_autocast_kwargs = {
+ "enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+ device=timesteps.device
+ )
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, "b -> b d", d=dim)
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class AlphaBlender(nn.Module):
+ strategies = ["learned", "fixed", "learned_with_images"]
+
+ def __init__(
+ self,
+ alpha: float,
+ merge_strategy: str = "learned_with_images",
+ rearrange_pattern: str = "b t -> (b t) 1 1",
+ ):
+ super().__init__()
+ self.merge_strategy = merge_strategy
+ self.rearrange_pattern = rearrange_pattern
+
+ assert merge_strategy in self.strategies, f"merge_strategy needs to be in {self.strategies}"
+
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
+ self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
+ if self.merge_strategy == "fixed":
+ alpha = self.mix_factor
+ elif self.merge_strategy == "learned":
+ alpha = torch.sigmoid(self.mix_factor)
+ elif self.merge_strategy == "learned_with_images":
+ assert image_only_indicator is not None, "need image_only_indicator ..."
+ alpha = torch.where(
+ image_only_indicator.bool(),
+ torch.ones(1, 1, device=image_only_indicator.device),
+ rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
+ )
+ alpha = rearrange(alpha, self.rearrange_pattern)
+ else:
+ raise NotImplementedError
+ return alpha
+
+ def forward(
+ self,
+ x_spatial: torch.Tensor,
+ x_temporal: torch.Tensor,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ alpha = self.get_alpha(image_only_indicator)
+ x = alpha * x_spatial + (1.0 - alpha) * x_temporal
+ return x
diff --git a/sgm/modules/diffusionmodules/video_model.py b/sgm/modules/diffusionmodules/video_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f42555f872b6ad662983e61f3548bf607192f536
--- /dev/null
+++ b/sgm/modules/diffusionmodules/video_model.py
@@ -0,0 +1,754 @@
+from functools import partial
+from typing import List, Optional, Union
+
+from einops import rearrange, repeat
+import copy
+
+from ...modules.diffusionmodules.openaimodel import *
+from ...modules.video_attention import SpatialVideoTransformer
+from ...modules.diffusionmodules.model import FaceLocator
+from ...util import default
+from .util import AlphaBlender
+
+
+class VideoResBlock(ResBlock):
+ def __init__(
+ self,
+ channels: int,
+ emb_channels: int,
+ dropout: float,
+ video_kernel_size: Union[int, List[int]] = 3,
+ merge_strategy: str = "fixed",
+ merge_factor: float = 0.5,
+ out_channels: Optional[int] = None,
+ use_conv: bool = False,
+ use_scale_shift_norm: bool = False,
+ dims: int = 2,
+ use_checkpoint: bool = False,
+ up: bool = False,
+ down: bool = False,
+ skip_time: bool = False,
+ ):
+ super().__init__(
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=out_channels,
+ use_conv=use_conv,
+ use_scale_shift_norm=use_scale_shift_norm,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ up=up,
+ down=down,
+ )
+
+ self.time_stack = ResBlock(
+ default(out_channels, channels),
+ emb_channels,
+ dropout=dropout,
+ dims=3,
+ out_channels=default(out_channels, channels),
+ use_scale_shift_norm=False,
+ use_conv=False,
+ up=False,
+ down=False,
+ kernel_size=video_kernel_size,
+ use_checkpoint=use_checkpoint,
+ exchange_temb_dims=True,
+ )
+ self.time_mixer = AlphaBlender(
+ alpha=merge_factor,
+ merge_strategy=merge_strategy,
+ rearrange_pattern="b t -> b 1 t 1 1",
+ )
+ self.skip_time = skip_time
+
+ def forward(
+ self,
+ x: th.Tensor,
+ emb: th.Tensor,
+ num_video_frames: int,
+ image_only_indicator: Optional[th.Tensor] = None,
+ ) -> th.Tensor:
+ x = super().forward(x, emb)
+
+ if self.skip_time:
+ return x
+
+ x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
+
+ x = self.time_stack(
+ x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
+ )
+ x = self.time_mixer(
+ x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
+ )
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ return x
+
+
+class VideoUNet(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ model_channels: int,
+ out_channels: int,
+ num_res_blocks: int,
+ attention_resolutions: int,
+ dropout: float = 0.0,
+ channel_mult: List[int] = (1, 2, 4, 8),
+ conv_resample: bool = True,
+ dims: int = 2,
+ num_classes: Optional[int] = None,
+ use_checkpoint: bool = False,
+ num_heads: int = -1,
+ num_head_channels: int = -1,
+ num_heads_upsample: int = -1,
+ use_scale_shift_norm: bool = False,
+ resblock_updown: bool = False,
+ transformer_depth: Union[List[int], int] = 1,
+ transformer_depth_middle: Optional[int] = None,
+ context_dim: Optional[int] = None,
+ time_downup: bool = False,
+ time_context_dim: Optional[int] = None,
+ extra_ff_mix_layer: bool = False,
+ use_spatial_context: bool = False,
+ merge_strategy: str = "fixed",
+ merge_factor: float = 0.5,
+ spatial_transformer_attn_type: str = "softmax",
+ video_kernel_size: Union[int, List[int]] = 3,
+ use_linear_in_transformer: bool = False,
+ adm_in_channels: Optional[int] = None,
+ disable_temporal_crossattention: bool = False,
+ max_ddpm_temb_period: int = 10000,
+ fine_tuning_method: str = None,
+ unfreeze_blocks: Optional[List[str]] = None,
+ adapter_kwargs: Optional[dict] = {},
+ audio_cond_method: str = None,
+ audio_dim: Optional[int] = 0,
+ additional_audio_frames: Optional[int] = 0,
+ skip_time: bool = False,
+ use_ada_aug: bool = False,
+ encode_landmarks: bool = False,
+ reference_to: str = None,
+ ):
+ super().__init__()
+ assert context_dim is not None
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1
+
+ if num_head_channels == -1:
+ assert num_heads != -1
+
+ self.additional_audio_frames = additional_audio_frames
+ audio_multiplier = additional_audio_frames * 2 + 1
+ audio_dim = audio_dim * audio_multiplier
+
+ self.audio_is_context = "both" in audio_cond_method
+
+ if "both" == audio_cond_method:
+ audio_cond_method = "to_time_emb_image"
+ elif "both_keyframes" == audio_cond_method:
+ audio_cond_method = "to_time_emb"
+
+ if "to_time_emb" in audio_cond_method:
+ adm_in_channels += audio_dim
+
+ print(adm_in_channels, audio_dim, audio_cond_method)
+
+ self.adapter = None
+ self.audio_cond_method = audio_cond_method
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(transformer_depth, int):
+ transformer_depth = len(channel_mult) * [transformer_depth]
+ transformer_depth_middle = default(
+ transformer_depth_middle, transformer_depth[-1]
+ )
+
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.use_ada_aug = use_ada_aug
+ if use_ada_aug:
+ self.map_aug = linear(9, time_embed_dim)
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "timestep":
+ self.label_emb = nn.Sequential(
+ Timestep(model_channels),
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ ),
+ )
+
+ elif self.num_classes == "sequential":
+ if adm_in_channels > 0:
+ assert adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ linear(adm_in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+ else:
+ # Disabling the label embedding
+ self.num_classes = None
+ else:
+ raise ValueError()
+
+ self.encode_landmarks = encode_landmarks
+ if encode_landmarks:
+ self.face_locator = FaceLocator(
+ 320, conditioning_channels=3, block_out_channels=(16, 32, 96, 256)
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+
+ def get_attention_layer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=1,
+ context_dim=None,
+ use_checkpoint=False,
+ disabled_sa=False,
+ audio_context_dim=None,
+ ):
+ return SpatialVideoTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=depth,
+ context_dim=context_dim,
+ audio_context_dim=audio_context_dim,
+ time_context_dim=time_context_dim,
+ dropout=dropout,
+ ff_in=extra_ff_mix_layer,
+ use_spatial_context=use_spatial_context,
+ merge_strategy=merge_strategy,
+ merge_factor=merge_factor,
+ checkpoint=use_checkpoint,
+ use_linear=use_linear_in_transformer,
+ attn_mode=spatial_transformer_attn_type,
+ disable_self_attn=disabled_sa,
+ disable_temporal_crossattention=disable_temporal_crossattention,
+ max_time_embed_period=max_ddpm_temb_period,
+ skip_time=skip_time,
+ reference_to=reference_to,
+ )
+
+ def get_resblock(
+ merge_factor,
+ merge_strategy,
+ video_kernel_size,
+ ch,
+ time_embed_dim,
+ dropout,
+ out_ch,
+ dims,
+ use_checkpoint,
+ use_scale_shift_norm,
+ down=False,
+ up=False,
+ ):
+ return VideoResBlock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ channels=ch,
+ emb_channels=time_embed_dim,
+ dropout=dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=down,
+ up=up,
+ skip_time=skip_time,
+ )
+
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ get_resblock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ ch=ch,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ out_ch=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ layers.append(
+ get_attention_layer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ audio_context_dim=audio_dim
+ if "cross_attention" in audio_cond_method
+ else None,
+ use_checkpoint=use_checkpoint,
+ disabled_sa=False,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ ds *= 2
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ get_resblock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ ch=ch,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ out_ch=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch,
+ conv_resample,
+ dims=dims,
+ out_channels=out_ch,
+ third_down=time_downup,
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ self.middle_block = TimestepEmbedSequential(
+ get_resblock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ ch=ch,
+ time_embed_dim=time_embed_dim,
+ out_ch=None,
+ dropout=dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ get_attention_layer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth_middle,
+ context_dim=context_dim,
+ audio_context_dim=audio_dim
+ if "new_cross_attention" in audio_cond_method
+ else None,
+ use_checkpoint=use_checkpoint,
+ ),
+ get_resblock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ ch=ch,
+ out_ch=None,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+
+ layers = [
+ get_resblock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ ch=ch + ich,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ out_ch=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ layers.append(
+ get_attention_layer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ audio_context_dim=audio_dim
+ if "new_cross_attention" == audio_cond_method
+ else None,
+ use_checkpoint=use_checkpoint,
+ disabled_sa=False,
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ ds //= 2
+ layers.append(
+ get_resblock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ ch=ch,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ out_ch=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(
+ ch,
+ conv_resample,
+ dims=dims,
+ out_channels=out_ch,
+ third_up=time_downup,
+ )
+ )
+
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+
+ if fine_tuning_method is not None:
+ # Freeze everything except the adapter
+ for param in self.parameters():
+ param.requires_grad = False
+ if self.adapter is not None:
+ for param in self.adapter.parameters():
+ param.requires_grad = True
+ if len(unfreeze_blocks):
+ if "input" in unfreeze_blocks:
+ for param in self.input_blocks[0].parameters():
+ param.requires_grad = True
+ # break # only unfreeze the first input block
+ if "label_emb" in unfreeze_blocks:
+ for param in self.label_emb.parameters():
+ param.requires_grad = True
+
+ def get_skip_attention_at(
+ self,
+ skip_attention_at: List[int],
+ curr_layer: int,
+ batch_size: int,
+ num_video_frames: int,
+ ):
+ if skip_attention_at is None:
+ return None
+
+ skip_attention = th.zeros(len(skip_attention_at), 1, dtype=th.bool)
+
+ for i, layer in enumerate(skip_attention_at):
+ skip_attention[i] = layer == curr_layer
+ skip_attention = repeat(
+ skip_attention, "b ... -> (b t) ...", t=num_video_frames
+ )
+ assert skip_attention.shape[0] == batch_size, (
+ f"{skip_attention.shape[0]} != {batch_size}"
+ )
+ return skip_attention
+
+ def forward(
+ self,
+ x: th.Tensor,
+ timesteps: th.Tensor,
+ context: Optional[th.Tensor] = None,
+ reference_context: Optional[th.Tensor] = None,
+ y: Optional[th.Tensor] = None,
+ audio_emb: Optional[th.Tensor] = None,
+ landmarks: Optional[th.Tensor] = None,
+ aug_labels: Optional[th.Tensor] = None,
+ time_context: Optional[th.Tensor] = None,
+ num_video_frames: Optional[int] = 1,
+ image_only_indicator: Optional[th.Tensor] = None,
+ skip_spatial_attention_at: Optional[List[int]] = None,
+ skip_temporal_attention_at: Optional[List[int]] = None,
+ ):
+ if self.audio_is_context:
+ assert audio_emb is None
+ audio_emb = context.clone()
+
+ curr_context_idx = 0
+ num_video_frames = (
+ num_video_frames
+ if isinstance(num_video_frames, int)
+ else num_video_frames[0]
+ )
+ if reference_context is not None:
+ copy_context = copy.deepcopy(reference_context)
+ mid = copy_context.pop(-1)
+ copy_context.insert((len(copy_context) // 2) - 1, mid)
+ reference_context = copy_context
+ curr_context_idx = 0
+ if num_video_frames > 1:
+ reference_context = [
+ repeat(ref_context, "b h w -> (b t) h w", t=num_video_frames)
+ for ref_context in reference_context
+ ]
+
+ or_batch_size = x.shape[0] // num_video_frames
+ if (
+ image_only_indicator is not None
+ and image_only_indicator.shape[0] != or_batch_size
+ ):
+ # TODO: fix this
+ image_only_indicator = repeat(
+ image_only_indicator, "b ... -> (b t) ...", t=2
+ )
+
+ if context is not None and x.shape[0] != context.shape[0]:
+ context = repeat(context, "b ... -> b t ...", t=num_video_frames)
+ context = rearrange(context, "b t ... -> (b t) ...", t=num_video_frames)
+
+ if "cross_attention" in self.audio_cond_method:
+ assert audio_emb is not None
+ if audio_emb.ndim == 4:
+ audio_emb = rearrange(audio_emb, "b t d c -> b (t d) c")
+
+ # context = th.cat([context, audio_emb], dim=1)
+
+ if self.audio_cond_method == "cross_time":
+ assert audio_emb is not None
+ time_context = audio_emb
+
+ if y is not None and y.shape[0] != x.shape[0]:
+ y = repeat(y, "b ... -> b t ...", t=num_video_frames)
+ y = rearrange(y, "b t ... -> (b t) ...", t=num_video_frames)
+
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y is not None or "to_time_emb" in self.audio_cond_method
+
+ if self.audio_cond_method == "to_time_emb":
+ assert audio_emb is not None
+ audio_emb = rearrange(audio_emb, "b t c -> (b t) c")
+ if y is not None:
+ y = th.cat([y, audio_emb], dim=1)
+ else:
+ y = audio_emb
+ elif self.audio_cond_method == "to_time_emb_image":
+ assert audio_emb is not None
+
+ audio_emb = rearrange(audio_emb, "b t c -> b (t c)")
+ if y is not None:
+ y = th.cat([y, audio_emb], dim=1)
+ else:
+ y = audio_emb
+ assert y.shape[0] == x.shape[0], (
+ f"{y.shape} != {x.shape} and audio_emb.shape: {audio_emb.shape}"
+ )
+
+ emb = emb + self.label_emb(y)
+
+ if self.use_ada_aug:
+ assert aug_labels is not None, (
+ "must provide aug_labels if use_ada_aug is True"
+ )
+ emb = emb + self.map_aug(aug_labels)
+
+ h = x
+
+ if self.encode_landmarks:
+ landmarks_emb = self.face_locator(landmarks)
+ landmarks_emb = rearrange(landmarks_emb, "b c t h w -> (b t) c h w")
+ # print("landmarks_emb:", landmarks_emb.shape)
+ for i, module in enumerate(self.input_blocks):
+ # print(image_only_indicator.shape, num_video_frames, h.shape)
+ if i == 1 and self.encode_landmarks:
+ h = h + landmarks_emb
+ # print("h.shape:", h.shape, i)
+ skip_spatial_attention = self.get_skip_attention_at(
+ skip_spatial_attention_at,
+ curr_context_idx,
+ x.shape[0],
+ num_video_frames,
+ )
+ skip_temporal_attention = self.get_skip_attention_at(
+ skip_temporal_attention_at,
+ curr_context_idx,
+ x.shape[0],
+ num_video_frames,
+ )
+ h, is_attention = module(
+ h,
+ emb,
+ context=context,
+ reference_context=reference_context[curr_context_idx]
+ if reference_context is not None
+ else None,
+ audio_context=audio_emb
+ if "cross_attention" in self.audio_cond_method
+ else None,
+ image_only_indicator=image_only_indicator,
+ time_context=time_context,
+ num_video_frames=num_video_frames,
+ skip_spatial_attention=skip_spatial_attention,
+ skip_temporal_attention=skip_temporal_attention,
+ )
+ if is_attention:
+ curr_context_idx = (
+ None if curr_context_idx is None else curr_context_idx + 1
+ )
+ hs.append(h)
+ skip_spatial_attention = self.get_skip_attention_at(
+ skip_spatial_attention_at, curr_context_idx, x.shape[0], num_video_frames
+ )
+ skip_temporal_attention = self.get_skip_attention_at(
+ skip_temporal_attention_at, curr_context_idx, x.shape[0], num_video_frames
+ )
+ h, is_attention = self.middle_block(
+ h,
+ emb,
+ context=context,
+ reference_context=reference_context[curr_context_idx]
+ if reference_context is not None
+ else None,
+ audio_context=audio_emb
+ if "cross_attention" in self.audio_cond_method
+ else None,
+ image_only_indicator=image_only_indicator,
+ time_context=time_context,
+ num_video_frames=num_video_frames,
+ skip_spatial_attention=skip_spatial_attention,
+ skip_temporal_attention=skip_temporal_attention,
+ )
+ curr_context_idx = None if curr_context_idx is None else curr_context_idx + 1
+ for i, module in enumerate(self.output_blocks):
+ skip_x = hs.pop()
+ if self.adapter is not None:
+ skip_x = self.adapter[i](
+ skip_x, n_frames=num_video_frames, condition=audio_emb
+ )
+ h = th.cat([h, skip_x], dim=1)
+ skip_spatial_attention = self.get_skip_attention_at(
+ skip_spatial_attention_at,
+ curr_context_idx,
+ x.shape[0],
+ num_video_frames,
+ )
+ skip_temporal_attention = self.get_skip_attention_at(
+ skip_temporal_attention_at,
+ curr_context_idx,
+ x.shape[0],
+ num_video_frames,
+ )
+ h, is_attention = module(
+ h,
+ emb,
+ context=context,
+ reference_context=reference_context[curr_context_idx]
+ if reference_context is not None
+ else None,
+ audio_context=audio_emb
+ if "cross_attention" in self.audio_cond_method
+ else None,
+ image_only_indicator=image_only_indicator,
+ time_context=time_context,
+ num_video_frames=num_video_frames,
+ skip_spatial_attention=skip_spatial_attention,
+ skip_temporal_attention=skip_temporal_attention,
+ )
+ if is_attention:
+ curr_context_idx = (
+ None if curr_context_idx is None else curr_context_idx + 1
+ )
+ # h = h.type(x.dtype)
+ return self.out(h)
diff --git a/sgm/modules/diffusionmodules/wrappers.py b/sgm/modules/diffusionmodules/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..89c67ab243bd2ed7a9ae03475adfd1756a9035f6
--- /dev/null
+++ b/sgm/modules/diffusionmodules/wrappers.py
@@ -0,0 +1,337 @@
+import torch
+import torch.nn as nn
+from packaging import version
+from einops import repeat, rearrange
+from diffusers.utils import _get_model_file
+from diffusers.models.modeling_utils import load_state_dict
+from ...modules.diffusionmodules.augment_pipeline import AugmentPipe
+from ...modules.encoders.modules import ConcatTimestepEmbedderND
+from ...util import append_dims
+
+
+OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
+
+
+class IdentityWrapper(nn.Module):
+ def __init__(self, diffusion_model, compile_model: bool = False):
+ super().__init__()
+ compile = (
+ torch.compile
+ if (version.parse(torch.__version__) >= version.parse("2.0.0"))
+ and compile_model
+ else lambda x: x
+ )
+ self.diffusion_model = compile(diffusion_model)
+
+ def forward(self, *args, **kwargs):
+ return self.diffusion_model(*args, **kwargs)
+
+
+class OpenAIWrapper(IdentityWrapper):
+ def __init__(
+ self,
+ diffusion_model,
+ compile_model: bool = False,
+ ada_aug_percent=0.0,
+ fix_image_leak=False,
+ add_embeddings=False,
+ im_size=[64, 64],
+ n_channels=4,
+ ):
+ super().__init__(diffusion_model, compile_model)
+ self.fix_image_leak = fix_image_leak
+ if fix_image_leak:
+ self.beta_m = 15
+ self.a = 5
+ self.noise_encoder = ConcatTimestepEmbedderND(256)
+
+ self.augment_pipe = None
+ if ada_aug_percent > 0.0:
+ augment_kwargs = dict(
+ xflip=1e8, yflip=1, scale=1, rotate_frac=1, aniso=1, translate_frac=1
+ )
+ self.augment_pipe = AugmentPipe(ada_aug_percent, **augment_kwargs)
+
+ self.add_embeddings = add_embeddings
+ if add_embeddings:
+ self.learned_mask = nn.Parameter(
+ torch.zeros(n_channels, im_size[0], im_size[1])
+ )
+
+ def get_noised_input(
+ self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
+ ) -> torch.Tensor:
+ noised_input = input + noise * sigmas_bc
+ return noised_input
+
+ def forward(
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
+ ) -> torch.Tensor:
+ cond_cat = c.get("concat", torch.Tensor([]).type_as(x))
+
+ if len(cond_cat.shape) and cond_cat.shape[0]:
+ T = x.shape[0] // cond_cat.shape[0]
+ if self.fix_image_leak:
+ noise_aug_strength = get_sigma_s(
+ rearrange(t, "(b t) ... -> b t ...", b=T)[: cond_cat.shape[0], 0] / 700,
+ self.a,
+ self.beta_m,
+ )
+ noise_aug = append_dims(noise_aug_strength, 4).to(x.device)
+ noise = torch.randn_like(noise_aug)
+ cond_cat = self.get_noised_input(noise_aug, noise, cond_cat)
+ noise_emb = self.noise_encoder(noise_aug_strength).to(x.device)
+ c["vector"] = (
+ noise_emb
+ if "vector" not in c
+ else torch.cat([c["vector"], noise_emb], dim=1)
+ )
+
+ if (
+ len(cond_cat.shape)
+ and cond_cat.shape[0]
+ and x.shape[0] != cond_cat.shape[0]
+ ):
+ cond_cat = repeat(cond_cat, "b c h w -> b c t h w", t=T)
+ cond_cat = rearrange(cond_cat, "b c t h w -> (b t) c h w")
+ x = torch.cat((x, cond_cat), dim=1)
+
+ if self.add_embeddings:
+ learned_mask = repeat(
+ self.learned_mask.to(x.device), "c h w -> b c h w", b=cond_cat.shape[0]
+ )
+ x = torch.cat((x, learned_mask), dim=1)
+
+ if self.augment_pipe is not None:
+ x, labels = self.augment_pipe(x)
+ else:
+ labels = torch.zeros(x.shape[0], 9, device=x.device)
+
+ return self.diffusion_model(
+ x,
+ timesteps=t,
+ context=c.get("crossattn", None),
+ reference_context=c.get("reference", None),
+ y=c.get("vector", None),
+ audio_emb=c.get("audio_emb", None),
+ landmarks=c.get("landmarks", None),
+ aug_labels=labels,
+ **kwargs,
+ )
+
+
+class DubbingWrapper(IdentityWrapper):
+ def __init__(self, diffusion_model, compile_model: bool = False, mask_input=False):
+ super().__init__(diffusion_model, compile_model)
+ self.mask_input = mask_input
+
+ def forward(
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
+ ) -> torch.Tensor:
+ cond_cat = c.get("concat", torch.Tensor([]).type_as(x))
+ if len(cond_cat.shape):
+ T = x.shape[0] // cond_cat.shape[0]
+ if cond_cat.shape[1] == 4:
+ cond_cat = repeat(cond_cat, "b c h w -> b (t c) h w", t=T)
+ cond_cat = rearrange(cond_cat, "b (t c) h w -> (b t) c h w", t=T)
+
+ x = torch.cat((x, cond_cat), dim=1)
+ out = self.diffusion_model(
+ x,
+ timesteps=t,
+ context=c.get("crossattn", None),
+ y=c.get("vector", None),
+ audio_emb=c.get("audio_emb", None),
+ skip_spatial_attention_at=c.get("skip_spatial_attention_at", None),
+ skip_temporal_attention_at=c.get("skip_temporal_attention_at", None),
+ **kwargs,
+ )
+
+ return out
+
+
+class StabilityWrapper(IdentityWrapper):
+ def __init__(
+ self,
+ diffusion_model,
+ compile_model: bool = False,
+ use_ipadapter: bool = False,
+ ipadapter_model: str = "ip-adapter_sd15.bin",
+ adapter_scale: float = 1.0,
+ n_adapters: int = 1,
+ skip_text_emb: bool = False,
+ # pass_image_emb_to_hidden_states: bool = False,
+ ):
+ super().__init__(diffusion_model, compile_model)
+ self.use_ipadapter = use_ipadapter
+ # self.pass_image_emb_to_hidden_states = pass_image_emb_to_hidden_states
+
+ if use_ipadapter:
+ model_file = _get_model_file(
+ "h94/IP-Adapter",
+ weights_name=ipadapter_model, # ip-adapter_sd15.bin
+ # cache_dir="/vol/paramonos2/projects/antoni/.cache",
+ subfolder="models",
+ )
+ state_dict = load_state_dict(model_file)
+ state_dict = [load_state_dict(model_file)] * n_adapters
+ print(f"Loading IP-Adapter weights from {model_file}")
+
+ diffusion_model.set_ip_adapter_scale(adapter_scale)
+
+ def forward(
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
+ ) -> torch.Tensor:
+ added_cond_kwargs = None
+ if self.use_ipadapter:
+ added_cond_kwargs = {"image_embeds": c.get("image_embeds", None)}
+ landmarks = c.get("landmarks", None)
+ if landmarks is not None:
+ added_cond_kwargs["image_embeds"] = [
+ added_cond_kwargs["image_embeds"],
+ landmarks,
+ ]
+
+ cond_cat = c.get("concat", torch.Tensor([]).type_as(x))
+ if len(cond_cat.shape) and cond_cat.shape[0]:
+ cond_cat = repeat(
+ cond_cat, "b c h w -> b c t h w", t=x.shape[0] // cond_cat.shape[0]
+ )
+ cond_cat = rearrange(cond_cat, "b c t h w -> (b t) c h w")
+ x = torch.cat((x, cond_cat), dim=1)
+
+ return self.diffusion_model(
+ x,
+ t,
+ encoder_hidden_states=c.get("crossattn", None),
+ added_cond_kwargs=added_cond_kwargs,
+ audio_emb=c.get("audio_emb", None),
+ **kwargs,
+ )[0]
+
+
+def logit_normal_sampler(m, s=1, beta_m=15, sample_num=1000000):
+ y_samples = torch.randn(sample_num) * s + m
+ x_samples = beta_m * (torch.exp(y_samples) / (1 + torch.exp(y_samples)))
+ return x_samples
+
+
+def mu_t(t, a=5, mu_max=1):
+ t = t.to("cpu")
+ return 2 * mu_max * t**a - mu_max
+
+
+def get_sigma_s(t, a, beta_m):
+ mu = mu_t(t, a=a)
+ sigma_s = logit_normal_sampler(m=mu, sample_num=t.shape[0], beta_m=beta_m)
+ return sigma_s
+
+
+class InterpolationWrapper(IdentityWrapper):
+ def __init__(
+ self,
+ diffusion_model,
+ compile_model: bool = False,
+ im_size=[512, 512],
+ n_channels=4,
+ starting_mask_method="zeros",
+ add_mask=True,
+ fix_image_leak=False,
+ ):
+ super().__init__(diffusion_model, compile_model)
+ im_size = [
+ x // 8 for x in im_size
+ ] # 8 is the default downscaling factor in the vae model
+ if starting_mask_method == "zeros":
+ self.learned_mask = nn.Parameter(
+ torch.zeros(n_channels, im_size[0], im_size[1])
+ )
+ elif starting_mask_method == "ones":
+ self.learned_mask = nn.Parameter(
+ torch.ones(n_channels, im_size[0], im_size[1])
+ )
+ elif starting_mask_method == "random":
+ self.learned_mask = nn.Parameter(
+ torch.randn(n_channels, im_size[0], im_size[1])
+ )
+ elif starting_mask_method == "none":
+ self.learned_mask = None
+ elif starting_mask_method == "fixed_ones":
+ self.learned_mask = torch.ones(n_channels, im_size[0], im_size[1])
+ elif starting_mask_method == "fixed_zeros":
+ self.learned_mask = torch.zeros(n_channels, im_size[0], im_size[1])
+ else:
+ raise NotImplementedError(
+ f"Unknown stating_mask_method: {starting_mask_method}"
+ )
+
+ self.add_mask = add_mask
+ self.fix_image_leak = fix_image_leak
+ if fix_image_leak:
+ self.beta_m = 15
+ self.a = 5
+ self.noise_encoder = ConcatTimestepEmbedderND(256)
+
+ def get_noised_input(
+ self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
+ ) -> torch.Tensor:
+ noised_input = input + noise * sigmas_bc
+ return noised_input
+
+ def forward(
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
+ ) -> torch.Tensor:
+ cond_cat = c.get("concat", torch.Tensor([]).type_as(x))
+ T = x.shape[0] // cond_cat.shape[0]
+
+ if self.fix_image_leak:
+ noise_aug_strength = get_sigma_s(
+ rearrange(t, "(b t) ... -> b t ...", b=T)[: cond_cat.shape[0], 0] / 700,
+ self.a,
+ self.beta_m,
+ )
+ noise_aug = append_dims(noise_aug_strength, 4).to(x.device)
+ noise = torch.randn_like(noise_aug)
+ cond_cat = self.get_noised_input(noise_aug, noise, cond_cat)
+ noise_emb = self.noise_encoder(noise_aug_strength).to(x.device)
+ # cond["vector"] = noise_emb if "vector" not in cond else torch.cat([cond["vector"], noise_emb], dim=1)
+ c["vector"] = noise_emb
+
+ cond_cat = rearrange(cond_cat, "b (t c) h w -> b c t h w", t=2)
+
+ start, end = cond_cat.chunk(2, dim=2)
+ if self.learned_mask is None:
+ learned_mask = torch.stack(
+ [start.squeeze(2)] * (T // 2 - 1) + [end.squeeze(2)] * (T // 2 - 1),
+ dim=2,
+ )
+ else:
+ learned_mask = repeat(
+ self.learned_mask.to(x.device), "c h w -> b c h w", b=cond_cat.shape[0]
+ )
+ ones_mask = torch.ones_like(learned_mask)[:, 0].unsqueeze(1)
+ zeros_mask = torch.zeros_like(learned_mask)[:, 0].unsqueeze(1)
+ if self.learned_mask is None:
+ cond_seq = torch.cat([start] + [learned_mask] + [end], dim=2)
+ else:
+ cond_seq = torch.stack(
+ [start.squeeze(2)] + [learned_mask] * (T - 2) + [end.squeeze(2)], dim=2
+ )
+ cond_seq = rearrange(cond_seq, "b c t h w -> (b t) c h w")
+
+ x = torch.cat((x, cond_seq), dim=1)
+ if self.add_mask:
+ mask_seq = torch.stack(
+ [ones_mask] + [zeros_mask] * (T - 2) + [ones_mask], dim=2
+ )
+ mask_seq = rearrange(mask_seq, "b c t h w -> (b t) c h w")
+ x = torch.cat((x, mask_seq), dim=1)
+
+ return self.diffusion_model(
+ x,
+ timesteps=t,
+ context=c.get("crossattn", None),
+ y=c.get("vector", None),
+ audio_emb=c.get("audio_emb", None),
+ **kwargs,
+ )
diff --git a/sgm/modules/distributions/__init__.py b/sgm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/distributions/__pycache__/__init__.cpython-311.pyc b/sgm/modules/distributions/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..278ca03b012f99a398c76dbdba617e1e3bf9793e
Binary files /dev/null and b/sgm/modules/distributions/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/distributions/__pycache__/distributions.cpython-311.pyc b/sgm/modules/distributions/__pycache__/distributions.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2209752482c2cb6b781d34620c3a1c145cb377d7
Binary files /dev/null and b/sgm/modules/distributions/__pycache__/distributions.cpython-311.pyc differ
diff --git a/sgm/modules/distributions/distributions.py b/sgm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..016be35523187ea366db9ade391fe8ee276db60b
--- /dev/null
+++ b/sgm/modules/distributions/distributions.py
@@ -0,0 +1,102 @@
+import numpy as np
+import torch
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(
+ device=self.parameters.device
+ )
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
+ device=self.parameters.device
+ )
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/sgm/modules/ema.py b/sgm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..97b5ae2b230f89b4dba57e44c4f851478ad86f68
--- /dev/null
+++ b/sgm/modules/ema.py
@@ -0,0 +1,86 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+
+ self.m_name2s_name = {}
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer(
+ "num_updates",
+ torch.tensor(0, dtype=torch.int)
+ if use_num_upates
+ else torch.tensor(-1, dtype=torch.int),
+ )
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ # remove as '.'-character is not allowed in buffers
+ s_name = name.replace(".", "")
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def reset_num_updates(self):
+ del self.num_updates
+ self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(
+ one_minus_decay * (shadow_params[sname] - m_param[key])
+ )
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/sgm/modules/encoders/__init__.py b/sgm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/encoders/__pycache__/__init__.cpython-311.pyc b/sgm/modules/encoders/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a99d3f8a6820b2ce7bc0f112ff269e4a44f711a
Binary files /dev/null and b/sgm/modules/encoders/__pycache__/__init__.cpython-311.pyc differ
diff --git a/sgm/modules/encoders/__pycache__/modules.cpython-311.pyc b/sgm/modules/encoders/__pycache__/modules.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8a53ae689a1a7cc597950597d6c5dce0c56c7ca5
Binary files /dev/null and b/sgm/modules/encoders/__pycache__/modules.cpython-311.pyc differ
diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdc238ddf87457e7b9d7039f5cae7d079b78abe4
--- /dev/null
+++ b/sgm/modules/encoders/modules.py
@@ -0,0 +1,1474 @@
+import math
+from contextlib import nullcontext
+from functools import partial
+from typing import Dict, List, Optional, Tuple, Union
+import os
+
+import kornia
+import numpy as np
+import open_clip
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from omegaconf import ListConfig
+from torch.utils.checkpoint import checkpoint
+from transformers import (
+ ByT5Tokenizer,
+ CLIPTextModel,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5Tokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
+from ...modules.diffusionmodules.model import Encoder
+from ...modules.diffusionmodules.openaimodel import Timestep
+from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ...modules.distributions.distributions import DiagonalGaussianDistribution
+from ...util import (
+ append_dims,
+ autocast,
+ count_params,
+ default,
+ disabled_train,
+ expand_dims_like,
+ instantiate_from_config,
+)
+
+from facenet_pytorch import MTCNN, InceptionResnetV1
+from insightface.app import FaceAnalysis
+
+
+class AbstractEmbModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self._is_trainable = None
+ self._ucg_rate = None
+ self._input_key = None
+
+ @property
+ def is_trainable(self) -> bool:
+ return self._is_trainable
+
+ @property
+ def ucg_rate(self) -> Union[float, torch.Tensor]:
+ return self._ucg_rate
+
+ @property
+ def input_key(self) -> str:
+ return self._input_key
+
+ @is_trainable.setter
+ def is_trainable(self, value: bool):
+ self._is_trainable = value
+
+ @ucg_rate.setter
+ def ucg_rate(self, value: Union[float, torch.Tensor]):
+ self._ucg_rate = value
+
+ @input_key.setter
+ def input_key(self, value: str):
+ self._input_key = value
+
+ @is_trainable.deleter
+ def is_trainable(self):
+ del self._is_trainable
+
+ @ucg_rate.deleter
+ def ucg_rate(self):
+ del self._ucg_rate
+
+ @input_key.deleter
+ def input_key(self):
+ del self._input_key
+
+
+class GeneralConditioner(nn.Module):
+ OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
+ KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
+
+ def __init__(self, emb_models: Union[List, ListConfig]):
+ super().__init__()
+ embedders = []
+ for n, embconfig in enumerate(emb_models):
+ embedder = instantiate_from_config(embconfig)
+ assert isinstance(embedder, AbstractEmbModel), (
+ f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
+ )
+ embedder.is_trainable = embconfig.get("is_trainable", False)
+ embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
+ if not embedder.is_trainable:
+ embedder.train = disabled_train
+ for param in embedder.parameters():
+ param.requires_grad = False
+ embedder.eval()
+ print(
+ f"Initialized embedder #{n}: {embedder.__class__.__name__} "
+ f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
+ )
+
+ if "input_key" in embconfig:
+ embedder.input_key = embconfig["input_key"]
+ elif "input_keys" in embconfig:
+ embedder.input_keys = embconfig["input_keys"]
+ else:
+ raise KeyError(
+ f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
+ )
+
+ embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
+ if embedder.legacy_ucg_val is not None:
+ embedder.ucg_prng = np.random.RandomState()
+
+ embedders.append(embedder)
+ self.embedders = nn.ModuleList(embedders)
+
+ def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
+ assert embedder.legacy_ucg_val is not None
+ p = embedder.ucg_rate
+ val = embedder.legacy_ucg_val
+ for i in range(len(batch[embedder.input_key])):
+ if embedder.ucg_prng.choice(2, p=[1 - p, p]):
+ batch[embedder.input_key][i] = val
+ return batch
+
+ def forward(
+ self, batch: Dict, force_zero_embeddings: Optional[List] = None
+ ) -> Dict:
+ output = dict()
+ if force_zero_embeddings is None:
+ force_zero_embeddings = []
+ for embedder in self.embedders:
+ embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
+ with embedding_context():
+ if hasattr(embedder, "input_key") and (embedder.input_key is not None):
+ if embedder.legacy_ucg_val is not None:
+ batch = self.possibly_get_ucg_val(embedder, batch)
+ emb_out = embedder(batch[embedder.input_key])
+ elif hasattr(embedder, "input_keys"):
+ emb_out = embedder(*[batch[k] for k in embedder.input_keys])
+ assert isinstance(emb_out, (torch.Tensor, list, tuple)), (
+ f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
+ )
+ if not isinstance(emb_out, (list, tuple)):
+ emb_out = [emb_out]
+ # TODO: In future cond_type is probably better than OUTPUT_DIM2KEYS
+ has_cond_type = hasattr(embedder, "cond_type")
+ if has_cond_type and embedder.cond_type == "reference":
+ reference_list = []
+ if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
+ probal_null = expand_dims_like(
+ torch.bernoulli(
+ (1.0 - embedder.ucg_rate)
+ * torch.ones(emb_out[0].shape[0], device=emb_out[0].device)
+ ),
+ emb_out[0],
+ )
+ for emb in emb_out:
+ if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
+ emb = probal_null * emb
+ if (
+ hasattr(embedder, "input_key")
+ and embedder.input_key in force_zero_embeddings
+ ):
+ emb = torch.zeros_like(emb)
+ reference_list.append(emb)
+ output["reference"] = reference_list
+ else:
+ for emb in emb_out:
+ if has_cond_type:
+ out_key = embedder.cond_type
+ else:
+ out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
+ if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
+ emb = (
+ expand_dims_like(
+ torch.bernoulli(
+ (1.0 - embedder.ucg_rate)
+ * torch.ones(emb.shape[0], device=emb.device)
+ ),
+ emb,
+ )
+ * emb
+ )
+ if (
+ hasattr(embedder, "input_key")
+ and embedder.input_key in force_zero_embeddings
+ ):
+ emb = torch.zeros_like(emb)
+
+ if out_key in output:
+ output[out_key] = torch.cat(
+ (output[out_key], emb), self.KEY2CATDIM[out_key]
+ )
+ else:
+ output[out_key] = emb
+ return output
+
+ def get_unconditional_conditioning(
+ self,
+ batch_c: Dict,
+ batch_uc: Optional[Dict] = None,
+ force_uc_zero_embeddings: Optional[List[str]] = None,
+ force_cond_zero_embeddings: Optional[List[str]] = None,
+ ):
+ if force_uc_zero_embeddings is None:
+ force_uc_zero_embeddings = []
+ ucg_rates = list()
+ for embedder in self.embedders:
+ ucg_rates.append(embedder.ucg_rate)
+ embedder.ucg_rate = 0.0
+ c = self(batch_c, force_cond_zero_embeddings)
+ uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
+
+ for embedder, rate in zip(self.embedders, ucg_rates):
+ embedder.ucg_rate = rate
+ return c, uc
+
+
+class InceptionV3(nn.Module):
+ """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
+ port with an additional squeeze at the end"""
+
+ def __init__(self, normalize_input=False, **kwargs):
+ super().__init__()
+ from pytorch_fid import inception
+
+ kwargs["resize_input"] = True
+ self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)
+
+ def forward(self, inp):
+ outp = self.model(inp)
+
+ if len(outp) == 1:
+ return outp[0].squeeze()
+
+ return outp
+
+
+class IdentityEncoder(AbstractEmbModel):
+ def __init__(self, cond_type=None):
+ super().__init__()
+ if cond_type is not None:
+ setattr(self, "cond_type", cond_type)
+
+ def encode(self, x):
+ return x
+
+ def forward(self, x):
+ return x
+
+
+class EmotionLabelEmbedder(AbstractEmbModel):
+ def __init__(self, num_emotions=8, embedding_dim=256):
+ super().__init__()
+ self.embedding = nn.Embedding(num_emotions, embedding_dim)
+ self.embedding_dim = embedding_dim
+
+ def forward(self, x):
+ # x should be a tensor of emotion label indices
+ # Shape: (batch_size, num_frames
+ return rearrange(self.embedding(x), "b t d -> (b t) d")
+
+
+class WhisperAudioEmbedder(AbstractEmbModel):
+ def __init__(
+ self, merge_method="mean", linear_dim=None, cond_type=None, audio_dim=None
+ ):
+ super().__init__()
+ if cond_type is not None:
+ setattr(self, "cond_type", cond_type)
+ else:
+ self.cond_type = "audio_emb"
+ self.merge_method = merge_method
+ self.linear = None
+ if audio_dim is not None:
+ self.audio_dim = audio_dim * 2 if merge_method == "concat" else audio_dim
+ else:
+ self.audio_dim = 768 * 2 if merge_method == "concat" else 768
+ if linear_dim is not None:
+ self.linear = nn.Linear(self.audio_dim, linear_dim)
+
+ def forward(self, x):
+ # x shape: (batch_size, n_frames, 2, 1280)
+ # print(f"Audio input shape: {x.shape}")
+ if self.merge_method == "mean":
+ x = x.mean(dim=2)
+ elif self.merge_method == "concat":
+ x = rearrange(x, "b n c d -> b n (c d)")
+ elif self.merge_method == "add":
+ x = x.sum(dim=2)
+ elif self.merge_method == "none" or self.merge_method is None:
+ pass
+ else:
+ raise NotImplementedError(f"Unknown merge method: {self.merge_method}")
+
+ if self.linear is not None:
+ x = self.linear(x)
+
+ return x
+
+
+class FaceEmbeddings(AbstractEmbModel):
+ def __init__(
+ self,
+ linear_dim=None,
+ id_embeddings_dim=512,
+ n_cond_frames=1,
+ n_copies=1,
+ face_type="insightface",
+ ):
+ super().__init__()
+ self.proj = torch.nn.Sequential(
+ torch.nn.Linear(id_embeddings_dim, linear_dim),
+ torch.nn.GELU(),
+ torch.nn.Linear(linear_dim, linear_dim),
+ )
+ self.norm = torch.nn.LayerNorm(linear_dim)
+
+ self.face_type = face_type
+ if face_type == "insightface":
+ self.app = FaceAnalysis(
+ name="buffalo_l",
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+ )
+ self.app.prepare(ctx_id=0, det_size=(320, 320))
+ else:
+ self.mtcnn = MTCNN(
+ image_size=160,
+ margin=0,
+ min_face_size=20,
+ post_process=True,
+ device="cuda",
+ ).eval() # Keep everything as default
+ self.resnet = InceptionResnetV1(pretrained="vggface2").eval()
+ for param in self.resnet.parameters():
+ param.requires_grad = False
+ for param in self.mtcnn.parameters():
+ param.requires_grad = False
+
+ self.n_cond_frames = n_cond_frames
+ self.n_copies = n_copies
+
+ @torch.no_grad()
+ def get_insightface_embeddings(self, x):
+ x = (((x + 1.0) / 2.0) * 255.0).clip(0, 255)
+ face_embeddings = torch.empty((len(x), 512), device=x.device)
+ for i in range(len(x)):
+ image = x[i].cpu().numpy()
+ face = self.app.get(image)[0]
+ face_embeddings[i] = torch.as_tensor(face.normed_embedding, device=x.device)
+ return face_embeddings
+
+ @torch.no_grad()
+ def get_facenet_embeddings(self, x):
+ x = (((x + 1.0) / 2.0) * 255.0).clip(0, 255)
+ img_crops = self.mtcnn(x, device=x.device)
+ img_crops = rearrange(torch.stack(img_crops), "b h w c -> b c h w")
+ return self.resnet(img_crops)
+
+ def get_embeddings(self, x):
+ if self.face_type == "insightface":
+ x = self.get_insightface_embeddings(x)
+ else:
+ x = self.get_facenet_embeddings(x)
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+ def forward(self, vid):
+ if vid.ndim == 5:
+ vid = rearrange(vid, "b c t h w -> (b t) h w c")
+ else:
+ vid = rearrange(vid, "b c h w -> b h w c")
+ vid = self.get_embeddings(vid)
+ vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames)
+ vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies)
+
+ return vid
+
+
+class ClassEmbedder(AbstractEmbModel):
+ def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
+ super().__init__()
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+ self.n_classes = n_classes
+ self.add_sequence_dim = add_sequence_dim
+
+ def forward(self, c):
+ c = self.embedding(c)
+ if self.add_sequence_dim:
+ c = c[:, None, :]
+ return c
+
+ def get_unconditional_conditioning(self, bs, device="cuda"):
+ uc_class = (
+ self.n_classes - 1
+ ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+ uc = torch.ones((bs,), device=device) * uc_class
+ uc = {self.key: uc.long()}
+ return uc
+
+
+class ClassEmbedderForMultiCond(ClassEmbedder):
+ def forward(self, batch, key=None, disable_dropout=False):
+ out = batch
+ key = default(key, self.key)
+ islist = isinstance(batch[key], list)
+ if islist:
+ batch[key] = batch[key][0]
+ c_out = super().forward(batch, key, disable_dropout)
+ out[key] = [c_out] if islist else c_out
+ return out
+
+
+class FrozenT5Embedder(AbstractEmbModel):
+ """Uses the T5 transformer encoder for text"""
+
+ def __init__(
+ self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ with torch.autocast("cuda", enabled=False):
+ outputs = self.transformer(input_ids=tokens)
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenByT5Embedder(AbstractEmbModel):
+ """
+ Uses the ByT5 transformer encoder for text. Is character-aware.
+ """
+
+ def __init__(
+ self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = ByT5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ with torch.autocast("cuda", enabled=False):
+ outputs = self.transformer(input_ids=tokens)
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEmbModel):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+
+ LAYERS = ["last", "pooled", "hidden", "zero"]
+
+ def __init__(
+ self,
+ version="openai/clip-vit-large-patch14",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ layer_idx=None,
+ always_return_pooled=False,
+ null_text_path=None,
+ ): # clip-vit-base-patch32
+ super().__init__()
+ assert layer in self.LAYERS
+
+ self.null_text = None
+
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ if null_text_path is not None:
+ self.null_text = torch.load(null_text_path)
+ else:
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ self.layer_idx = layer_idx
+ self.return_pooled = always_return_pooled
+
+ if layer == "hidden":
+ assert layer_idx is not None
+ assert 0 <= abs(layer_idx) <= 12
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, text):
+ if self.null_text is not None:
+ if self.transformer is not None:
+ self.null_text = self.null_text.to(self.transformer.device)
+ self.transformer = None
+ torch.cuda.empty_cache()
+ return repeat(self.null_text, "b n c -> (b t) n c", t=len(text))
+
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(
+ input_ids=tokens, output_hidden_states=self.layer == "hidden"
+ )
+ if self.layer == "last":
+ z = outputs.last_hidden_state
+ elif self.layer == "zero":
+ z = torch.zeros_like(outputs.last_hidden_state)
+ elif self.layer == "pooled":
+ z = outputs.pooler_output[:, None, :]
+ else:
+ z = outputs.hidden_states[self.layer_idx]
+ if self.return_pooled:
+ return z, outputs.pooler_output
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPImageEmbedder(AbstractEmbModel):
+ """
+ Uses the CLIP vision transformer encoder for images
+ """
+
+ def __init__(
+ self,
+ version="h94/IP-Adapter",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ antialias=True,
+ ucg_rate=0.0,
+ unsqueeze_dim=False,
+ repeat_to_max_len=False,
+ num_image_crops=0,
+ output_tokens=False,
+ init_device=None,
+ subfolder="models/image_encoder",
+ get_hidden_states=False,
+ ):
+ super().__init__()
+ model = CLIPVisionModelWithProjection.from_pretrained(
+ version, subfolder=subfolder
+ )
+ # del model.transformer
+ self.model = model
+ self.max_crops = num_image_crops
+ self.pad_to_max_len = self.max_crops > 0
+ self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
+ self.device = device
+ self.max_length = max_length
+ self.get_hidden_states = get_hidden_states
+
+ if freeze:
+ self.freeze()
+
+ self.antialias = antialias
+
+ self.register_buffer(
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
+ )
+ self.register_buffer(
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
+ )
+ self.ucg_rate = ucg_rate
+ self.unsqueeze_dim = unsqueeze_dim
+ self.stored_batch = None
+ # self.model.visual.output_tokens = output_tokens
+ self.output_tokens = output_tokens
+
+ def preprocess(self, x):
+ x = kornia.geometry.resize(
+ x,
+ (224, 224),
+ interpolation="bicubic",
+ align_corners=True,
+ antialias=True,
+ )
+ x = (x + 1.0) / 2.0
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, image, no_dropout=False):
+ z = self.encode_with_vision_transformer(image)
+ tokens = None
+ if self.output_tokens:
+ z, tokens = z[0], z[1]
+ z = z.to(image.dtype)
+ if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
+ z = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
+ )[:, None]
+ * z
+ )
+ if tokens is not None:
+ tokens = (
+ expand_dims_like(
+ torch.bernoulli(
+ (1.0 - self.ucg_rate)
+ * torch.ones(tokens.shape[0], device=tokens.device)
+ ),
+ tokens,
+ )
+ * tokens
+ )
+
+ if self.unsqueeze_dim:
+ z = z[:, None, :]
+ if self.output_tokens:
+ assert not self.repeat_to_max_len
+ assert not self.pad_to_max_len
+ return tokens, z
+ if self.repeat_to_max_len:
+ if z.dim() == 2:
+ z_ = z[:, None, :]
+ else:
+ z_ = z
+ return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
+ elif self.pad_to_max_len:
+ assert z.dim() == 3
+ z_pad = torch.cat(
+ (
+ z,
+ torch.zeros(
+ z.shape[0],
+ self.max_length - z.shape[1],
+ z.shape[2],
+ device=z.device,
+ ),
+ ),
+ 1,
+ )
+ return z_pad, z_pad[:, 0, ...]
+ return z
+
+ def encode_with_vision_transformer(self, img):
+ # if self.max_crops > 0:
+ # img = self.preprocess_by_cropping(img)
+ if img.dim() == 5:
+ assert self.max_crops == img.shape[1]
+ img = rearrange(img, "b n c h w -> (b n) c h w")
+ img = self.preprocess(img)
+ if not self.output_tokens:
+ # assert not self.model.visual.output_tokens
+ if self.get_hidden_states:
+ x = self.model(img, output_hidden_states=True).hidden_states[-2]
+ else:
+ x = self.model(img).image_embeds
+ tokens = None
+ else:
+ # assert self.model.visual.output_tokens
+ x, tokens = self.model(img).image_embeds
+ if self.max_crops > 0:
+ x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
+ # drop out between 0 and all along the sequence axis
+ x = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate)
+ * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
+ )
+ * x
+ )
+ if tokens is not None:
+ tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
+ print(
+ f"You are running very experimental token-concat in {self.__class__.__name__}. "
+ f"Check what you are doing, and then remove this message."
+ )
+ if self.output_tokens:
+ return x, tokens
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+
+ LAYERS = ["pooled", "last", "penultimate"]
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ always_return_pooled=False,
+ legacy=True,
+ ):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch,
+ device=torch.device("cpu"),
+ pretrained=version,
+ )
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ self.return_pooled = always_return_pooled
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+ self.legacy = legacy
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ if not self.return_pooled and self.legacy:
+ return z
+ if self.return_pooled:
+ assert not self.legacy
+ return z[self.layer], z["pooled"]
+
+ return z[self.layer]
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ if self.legacy:
+ x = x[self.layer]
+ x = self.model.ln_final(x)
+ return x
+ else:
+ # x is a dict and will stay a dict
+ o = x["last"]
+ o = self.model.ln_final(o)
+ pooled = self.pool(o, text)
+ x["pooled"] = pooled
+ return x
+
+ def pool(self, x, text):
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = (
+ x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
+ @ self.model.text_projection
+ )
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ outputs = {}
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - 1:
+ outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD
+ if (
+ self.model.transformer.grad_checkpointing
+ and not torch.jit.is_scripting()
+ ):
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ outputs["last"] = x.permute(1, 0, 2) # LND -> NLD
+ return outputs
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPEmbedder(AbstractEmbModel):
+ LAYERS = [
+ # "pooled",
+ "last",
+ "penultimate",
+ ]
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ ):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch, device=torch.device("cpu"), pretrained=version
+ )
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if (
+ self.model.transformer.grad_checkpointing
+ and not torch.jit.is_scripting()
+ ):
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
+ """
+ Uses the OpenCLIP vision transformer encoder for images
+ """
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ antialias=True,
+ ucg_rate=0.0,
+ unsqueeze_dim=False,
+ repeat_to_max_len=False,
+ num_image_crops=0,
+ output_tokens=False,
+ init_device=None,
+ ):
+ super().__init__()
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch,
+ device=torch.device(default(init_device, "cpu")),
+ pretrained=version,
+ )
+ del model.transformer
+ self.model = model
+ self.max_crops = num_image_crops
+ self.pad_to_max_len = self.max_crops > 0
+ self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ self.antialias = antialias
+
+ self.register_buffer(
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
+ )
+ self.register_buffer(
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
+ )
+ self.ucg_rate = ucg_rate
+ self.unsqueeze_dim = unsqueeze_dim
+ self.stored_batch = None
+ self.model.visual.output_tokens = output_tokens
+ self.output_tokens = output_tokens
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ # prev_dtype = x.dtype
+ # # move to torch.float32
+ # x = x.to(torch.float32)
+ # print(f"Preprocessing image with dtype {x.dtype}")
+ x = kornia.geometry.resize(
+ x,
+ (224, 224),
+ interpolation="bicubic",
+ align_corners=True,
+ antialias=self.antialias,
+ )
+ # x = torch.nn.functional.interpolate(
+ # x, size=(224, 224), mode="bicubic", align_corners=True, antialias=self.antialias
+ # )
+ # x = x.to(prev_dtype)
+ # print(f"Postprocessing image with dtype {x.dtype}")
+ x = (x + 1.0) / 2.0
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, image, no_dropout=False):
+ z = self.encode_with_vision_transformer(image)
+ tokens = None
+ if self.output_tokens:
+ z, tokens = z[0], z[1]
+ z = z.to(image.dtype)
+ if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
+ z = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
+ )[:, None]
+ * z
+ )
+ if tokens is not None:
+ tokens = (
+ expand_dims_like(
+ torch.bernoulli(
+ (1.0 - self.ucg_rate)
+ * torch.ones(tokens.shape[0], device=tokens.device)
+ ),
+ tokens,
+ )
+ * tokens
+ )
+ if self.unsqueeze_dim:
+ z = z[:, None, :]
+ if self.output_tokens:
+ assert not self.repeat_to_max_len
+ assert not self.pad_to_max_len
+ return tokens, z
+ if self.repeat_to_max_len:
+ if z.dim() == 2:
+ z_ = z[:, None, :]
+ else:
+ z_ = z
+ return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
+ elif self.pad_to_max_len:
+ assert z.dim() == 3
+ z_pad = torch.cat(
+ (
+ z,
+ torch.zeros(
+ z.shape[0],
+ self.max_length - z.shape[1],
+ z.shape[2],
+ device=z.device,
+ ),
+ ),
+ 1,
+ )
+ return z_pad, z_pad[:, 0, ...]
+ return z
+
+ def encode_with_vision_transformer(self, img):
+ # if self.max_crops > 0:
+ # img = self.preprocess_by_cropping(img)
+ if img.dim() == 5:
+ assert self.max_crops == img.shape[1]
+ img = rearrange(img, "b n c h w -> (b n) c h w")
+ img = self.preprocess(img)
+ if not self.output_tokens:
+ assert not self.model.visual.output_tokens
+ x = self.model.visual(img)
+ tokens = None
+ else:
+ assert self.model.visual.output_tokens
+ x, tokens = self.model.visual(img)
+ if self.max_crops > 0:
+ x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
+ # drop out between 0 and all along the sequence axis
+ x = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate)
+ * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
+ )
+ * x
+ )
+ if tokens is not None:
+ tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
+ print(
+ f"You are running very experimental token-concat in {self.__class__.__name__}. "
+ f"Check what you are doing, and then remove this message."
+ )
+ if self.output_tokens:
+ return x, tokens
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPT5Encoder(AbstractEmbModel):
+ def __init__(
+ self,
+ clip_version="openai/clip-vit-large-patch14",
+ t5_version="google/t5-v1_1-xl",
+ device="cuda",
+ clip_max_length=77,
+ t5_max_length=77,
+ ):
+ super().__init__()
+ self.clip_encoder = FrozenCLIPEmbedder(
+ clip_version, device, max_length=clip_max_length
+ )
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+ print(
+ f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.0e-6:.2f} M parameters, "
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.0e-6:.2f} M params."
+ )
+
+ def encode(self, text):
+ return self(text)
+
+ def forward(self, text):
+ clip_z = self.clip_encoder.encode(text)
+ t5_z = self.t5_encoder.encode(text)
+ return [clip_z, t5_z]
+
+
+class SpatialRescaler(nn.Module):
+ def __init__(
+ self,
+ n_stages=1,
+ method="bilinear",
+ multiplier=0.5,
+ in_channels=3,
+ out_channels=None,
+ bias=False,
+ wrap_video=False,
+ kernel_size=1,
+ remap_output=False,
+ ):
+ super().__init__()
+ self.n_stages = n_stages
+ assert self.n_stages >= 0
+ assert method in [
+ "nearest",
+ "linear",
+ "bilinear",
+ "trilinear",
+ "bicubic",
+ "area",
+ ]
+ self.multiplier = multiplier
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
+ self.remap_output = out_channels is not None or remap_output
+ if self.remap_output:
+ print(
+ f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
+ )
+ self.channel_mapper = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ bias=bias,
+ padding=kernel_size // 2,
+ )
+ self.wrap_video = wrap_video
+
+ def forward(self, x):
+ if self.wrap_video and x.ndim == 5:
+ B, C, T, H, W = x.shape
+ x = rearrange(x, "b c t h w -> b t c h w")
+ x = rearrange(x, "b t c h w -> (b t) c h w")
+
+ for stage in range(self.n_stages):
+ x = self.interpolator(x, scale_factor=self.multiplier)
+
+ if self.wrap_video:
+ x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C)
+ x = rearrange(x, "b t c h w -> b c t h w")
+ if self.remap_output:
+ x = self.channel_mapper(x)
+ return x
+
+ def encode(self, x):
+ return self(x)
+
+
+class LowScaleEncoder(nn.Module):
+ def __init__(
+ self,
+ model_config,
+ linear_start,
+ linear_end,
+ timesteps=1000,
+ max_noise_level=250,
+ output_size=64,
+ scale_factor=1.0,
+ ):
+ super().__init__()
+ self.max_noise_level = max_noise_level
+ self.model = instantiate_from_config(model_config)
+ self.augmentation_schedule = self.register_schedule(
+ timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
+ )
+ self.out_size = output_size
+ self.scale_factor = scale_factor
+
+ def register_schedule(
+ self,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ ):
+ betas = make_beta_schedule(
+ beta_schedule,
+ timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s,
+ )
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+ (timesteps,) = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, (
+ "alphas have to be defined for each timestep"
+ )
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer("betas", to_torch(betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+ )
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def forward(self, x):
+ z = self.model.encode(x)
+ if isinstance(z, DiagonalGaussianDistribution):
+ z = z.sample()
+ z = z * self.scale_factor
+ noise_level = torch.randint(
+ 0, self.max_noise_level, (x.shape[0],), device=x.device
+ ).long()
+ z = self.q_sample(z, noise_level)
+ if self.out_size is not None:
+ z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
+ return z, noise_level
+
+ def decode(self, z):
+ z = z / self.scale_factor
+ return self.model.decode(z)
+
+
+class ConcatTimestepEmbedderND(AbstractEmbModel):
+ """embeds each dimension independently and concatenates them"""
+
+ def __init__(self, outdim, is_temporal=False):
+ super().__init__()
+ self.timestep = Timestep(outdim)
+ self.outdim = outdim
+ self.is_temporal = is_temporal
+
+ def forward(self, x):
+ if self.is_temporal:
+ x = rearrange(x, "b t ... -> (b t) () ...")
+ if x.ndim == 1:
+ x = x[:, None]
+ assert len(x.shape) == 2
+ b, dims = x.shape[0], x.shape[1]
+ x = rearrange(x, "b d -> (b d)")
+ emb = self.timestep(x)
+ emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
+ return emb
+
+
+class GaussianEncoder(Encoder, AbstractEmbModel):
+ def __init__(
+ self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.posterior = DiagonalGaussianRegularizer()
+ self.weight = weight
+ self.flatten_output = flatten_output
+
+ def forward(self, x) -> Tuple[Dict, torch.Tensor]:
+ z = super().forward(x)
+ z, log = self.posterior(z)
+ log["loss"] = log["kl_loss"]
+ log["weight"] = self.weight
+ if self.flatten_output:
+ z = rearrange(z, "b c h w -> b (h w ) c")
+ return log, z
+
+
+class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
+ def __init__(
+ self,
+ n_cond_frames: int,
+ n_copies: int,
+ encoder_config: dict,
+ sigma_sampler_config: Optional[dict] = None,
+ sigma_cond_config: Optional[dict] = None,
+ is_ae: bool = False,
+ scale_factor: float = 1.0,
+ disable_encoder_autocast: bool = False,
+ en_and_decode_n_samples_a_time: Optional[int] = None,
+ load_encoder: bool = True,
+ ):
+ super().__init__()
+
+ self.n_cond_frames = n_cond_frames
+ self.n_copies = n_copies
+ if load_encoder:
+ self.encoder = instantiate_from_config(encoder_config)
+ else:
+ self.encoder = None
+ self.sigma_sampler = (
+ instantiate_from_config(sigma_sampler_config)
+ if sigma_sampler_config is not None
+ else None
+ )
+ self.sigma_cond = (
+ instantiate_from_config(sigma_cond_config)
+ if sigma_cond_config is not None
+ else None
+ )
+ self.is_ae = is_ae
+ self.scale_factor = scale_factor
+ self.disable_encoder_autocast = disable_encoder_autocast
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
+
+ def forward(
+ self, vid: torch.Tensor
+ ) -> Union[
+ torch.Tensor,
+ Tuple[torch.Tensor, torch.Tensor],
+ Tuple[torch.Tensor, dict],
+ Tuple[Tuple[torch.Tensor, torch.Tensor], dict],
+ ]:
+ if vid.ndim == 5:
+ vid = rearrange(vid, "b c t h w -> (b t) c h w")
+
+ if vid.shape[1] == 4:
+ if self.encoder is not None:
+ self.encoder = None
+ torch.cuda.empty_cache()
+
+ vid = repeat(vid, "b c h w -> (b t) c h w", t=self.n_copies)
+
+ return (
+ rearrange(
+ vid.squeeze(1), "(b t) c h w -> b (t c) h w", t=self.n_cond_frames
+ )
+ / 0.18215
+ )
+
+ if self.sigma_sampler is not None:
+ b = vid.shape[0] // self.n_cond_frames
+ sigmas = self.sigma_sampler(b).to(vid.device)
+ if self.sigma_cond is not None:
+ sigma_cond = self.sigma_cond(sigmas)
+ sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
+ sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
+ noise = torch.randn_like(vid)
+ vid = vid + noise * append_dims(sigmas, vid.ndim)
+
+ with torch.autocast("cuda", enabled=not self.disable_encoder_autocast):
+ n_samples = (
+ self.en_and_decode_n_samples_a_time
+ if self.en_and_decode_n_samples_a_time is not None
+ else vid.shape[0]
+ )
+ n_rounds = math.ceil(vid.shape[0] / n_samples)
+ all_out = []
+ for n in range(n_rounds):
+ if self.is_ae:
+ out = self.encoder.encode(vid[n * n_samples : (n + 1) * n_samples])
+ else:
+ out = self.encoder(vid[n * n_samples : (n + 1) * n_samples])
+ all_out.append(out)
+
+ vid = torch.cat(all_out, dim=0)
+ vid *= self.scale_factor
+
+ vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
+ vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)
+
+ return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid
+
+ return return_val
+
+
+class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
+ def __init__(
+ self,
+ open_clip_embedding_config: Dict,
+ n_cond_frames: int,
+ n_copies: int,
+ ):
+ super().__init__()
+
+ self.n_cond_frames = n_cond_frames
+ self.n_copies = n_copies
+ self.open_clip = instantiate_from_config(open_clip_embedding_config)
+
+ def forward(self, vid):
+ if vid.ndim == 5:
+ vid = rearrange(vid, "b c t h w -> (b t) c h w")
+ vid = self.open_clip(vid)
+ vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames)
+ vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies)
+
+ return vid
+
+
+class FrozenCLIPImagePredictionEmbedder(AbstractEmbModel):
+ def __init__(
+ self,
+ clip_embedding_config: Dict,
+ n_cond_frames: int,
+ n_copies: int,
+ give_cond_type: str = None,
+ ):
+ super().__init__()
+
+ self.n_cond_frames = n_cond_frames
+ self.n_copies = n_copies
+ if give_cond_type is not None:
+ self.cond_type = give_cond_type
+ self.clip = instantiate_from_config(clip_embedding_config)
+
+ def forward(self, vid):
+ if vid.ndim == 5:
+ vid = rearrange(vid, "b c t h w -> (b t) c h w")
+ vid = self.clip(vid)
+
+ if vid.dim() == 2:
+ vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames)
+ vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies)
+ elif vid.dim() == 3:
+ vid = rearrange(vid, "(b t) d c -> b t d c", t=self.n_cond_frames)
+ vid = repeat(vid, "b t d c -> (b s) t d c", s=self.n_copies).squeeze(1)
+ else:
+ raise ValueError(f"Unsupported input shape {vid.shape}")
+
+ return vid
diff --git a/sgm/modules/utils/draw_utils.py b/sgm/modules/utils/draw_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d13d70f36c30594d7b933d5357f518a82f6202d
--- /dev/null
+++ b/sgm/modules/utils/draw_utils.py
@@ -0,0 +1,219 @@
+import cv2
+import mediapipe as mp
+import numpy as np
+from mediapipe.framework.formats import landmark_pb2
+
+
+class FaceMeshVisualizer:
+ def __init__(
+ self,
+ forehead_edge=False,
+ upface_only=False,
+ draw_eye=True,
+ draw_head=False,
+ draw_iris=True,
+ draw_eyebrow=True,
+ draw_mouse=True,
+ draw_nose=True,
+ draw_pupil=True,
+ ):
+ self.mp_drawing = mp.solutions.drawing_utils
+ mp_face_mesh = mp.solutions.face_mesh
+ self.mp_face_mesh = mp_face_mesh
+ self.forehead_edge = forehead_edge
+
+ DrawingSpec = mp.solutions.drawing_styles.DrawingSpec
+ f_thick = 2
+ f_rad = 1
+ right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad)
+ right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad)
+ right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad)
+ left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad)
+ head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad)
+ nose_draw = DrawingSpec(color=(200, 200, 200), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_obl = DrawingSpec(color=(10, 180, 20), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_obr = DrawingSpec(color=(20, 10, 180), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_ibl = DrawingSpec(color=(100, 100, 30), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_ibr = DrawingSpec(color=(100, 150, 50), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_otl = DrawingSpec(color=(20, 80, 100), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_otr = DrawingSpec(color=(80, 100, 20), thickness=f_thick, circle_radius=f_rad)
+
+ mouth_draw_itl = DrawingSpec(color=(120, 100, 200), thickness=f_thick, circle_radius=f_rad)
+ mouth_draw_itr = DrawingSpec(color=(150, 120, 100), thickness=f_thick, circle_radius=f_rad)
+
+ FACEMESH_LIPS_OUTER_BOTTOM_LEFT = [(61, 146), (146, 91), (91, 181), (181, 84), (84, 17)]
+ FACEMESH_LIPS_OUTER_BOTTOM_RIGHT = [(17, 314), (314, 405), (405, 321), (321, 375), (375, 291)]
+
+ FACEMESH_LIPS_INNER_BOTTOM_LEFT = [(78, 95), (95, 88), (88, 178), (178, 87), (87, 14)]
+ FACEMESH_LIPS_INNER_BOTTOM_RIGHT = [(14, 317), (317, 402), (402, 318), (318, 324), (324, 308)]
+
+ FACEMESH_LIPS_OUTER_TOP_LEFT = [(61, 185), (185, 40), (40, 39), (39, 37), (37, 0)]
+ FACEMESH_LIPS_OUTER_TOP_RIGHT = [(0, 267), (267, 269), (269, 270), (270, 409), (409, 291)]
+
+ FACEMESH_LIPS_INNER_TOP_LEFT = [(78, 191), (191, 80), (80, 81), (81, 82), (82, 13)]
+ FACEMESH_LIPS_INNER_TOP_RIGHT = [(13, 312), (312, 311), (311, 310), (310, 415), (415, 308)]
+
+ FACEMESH_CUSTOM_FACE_OVAL = [
+ (176, 149),
+ (150, 136),
+ (356, 454),
+ (58, 132),
+ (152, 148),
+ (361, 288),
+ (251, 389),
+ (132, 93),
+ (389, 356),
+ (400, 377),
+ (136, 172),
+ (377, 152),
+ (323, 361),
+ (172, 58),
+ (454, 323),
+ (365, 379),
+ (379, 378),
+ (148, 176),
+ (93, 234),
+ (397, 365),
+ (149, 150),
+ (288, 397),
+ (234, 127),
+ (378, 400),
+ (127, 162),
+ (162, 21),
+ ]
+
+ # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about.
+ face_connection_spec = {}
+
+ # from IPython import embed
+ # embed()
+ if self.forehead_edge:
+ for edge in mp_face_mesh.FACEMESH_FACE_OVAL:
+ face_connection_spec[edge] = head_draw
+ else:
+ if draw_head:
+ FACEMESH_CUSTOM_FACE_OVAL_sorted = sorted(FACEMESH_CUSTOM_FACE_OVAL)
+ if upface_only:
+ for edge in [
+ FACEMESH_CUSTOM_FACE_OVAL_sorted[edge_idx] for edge_idx in [1, 2, 9, 12, 13, 16, 22, 25]
+ ]:
+ face_connection_spec[edge] = head_draw
+ else:
+ for edge in FACEMESH_CUSTOM_FACE_OVAL_sorted:
+ face_connection_spec[edge] = head_draw
+
+ if draw_eye:
+ for edge in mp_face_mesh.FACEMESH_LEFT_EYE:
+ face_connection_spec[edge] = left_eye_draw
+ for edge in mp_face_mesh.FACEMESH_RIGHT_EYE:
+ face_connection_spec[edge] = right_eye_draw
+
+ if draw_eyebrow:
+ for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW:
+ face_connection_spec[edge] = left_eyebrow_draw
+
+ for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
+ face_connection_spec[edge] = right_eyebrow_draw
+
+ if draw_iris:
+ for edge in mp_face_mesh.FACEMESH_LEFT_IRIS:
+ face_connection_spec[edge] = left_iris_draw
+ for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
+ face_connection_spec[edge] = right_iris_draw
+
+ # for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW:
+ # face_connection_spec[edge] = right_eyebrow_draw
+ # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS:
+ # face_connection_spec[edge] = right_iris_draw
+
+ # for edge in mp_face_mesh.FACEMESH_LIPS:
+ # face_connection_spec[edge] = mouth_draw
+
+ if draw_mouse:
+ for edge in FACEMESH_LIPS_OUTER_BOTTOM_LEFT:
+ face_connection_spec[edge] = mouth_draw_obl
+ for edge in FACEMESH_LIPS_OUTER_BOTTOM_RIGHT:
+ face_connection_spec[edge] = mouth_draw_obr
+ for edge in FACEMESH_LIPS_INNER_BOTTOM_LEFT:
+ face_connection_spec[edge] = mouth_draw_ibl
+ for edge in FACEMESH_LIPS_INNER_BOTTOM_RIGHT:
+ face_connection_spec[edge] = mouth_draw_ibr
+ for edge in FACEMESH_LIPS_OUTER_TOP_LEFT:
+ face_connection_spec[edge] = mouth_draw_otl
+ for edge in FACEMESH_LIPS_OUTER_TOP_RIGHT:
+ face_connection_spec[edge] = mouth_draw_otr
+ for edge in FACEMESH_LIPS_INNER_TOP_LEFT:
+ face_connection_spec[edge] = mouth_draw_itl
+ for edge in FACEMESH_LIPS_INNER_TOP_RIGHT:
+ face_connection_spec[edge] = mouth_draw_itr
+
+ self.face_connection_spec = face_connection_spec
+
+ self.pupil_landmark_spec = {468: right_iris_draw, 473: left_iris_draw}
+ self.nose_landmark_spec = {4: nose_draw}
+
+ self.draw_pupil = draw_pupil
+ self.draw_nose = draw_nose
+
+ def draw_points(self, image, landmark_list, drawing_spec, halfwidth: int = 2):
+ """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all
+ landmarks. Until our PR is merged into mediapipe, we need this separate method."""
+ if len(image.shape) != 3:
+ raise ValueError("Input image must be H,W,C.")
+ image_rows, image_cols, image_channels = image.shape
+ if image_channels != 3: # BGR channels
+ raise ValueError("Input image must contain three channel bgr data.")
+ for idx, landmark in enumerate(landmark_list.landmark):
+ if idx not in drawing_spec:
+ continue
+
+ if (landmark.HasField("visibility") and landmark.visibility < 0.9) or (
+ landmark.HasField("presence") and landmark.presence < 0.5
+ ):
+ continue
+ if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0:
+ continue
+
+ image_x = int(image_cols * landmark.x)
+ image_y = int(image_rows * landmark.y)
+
+ draw_color = drawing_spec[idx].color
+ image[image_y - halfwidth : image_y + halfwidth, image_x - halfwidth : image_x + halfwidth, :] = draw_color
+
+ def draw_landmarks(self, image_size, keypoints, normed=False, ini_size=[512, 512]):
+ # ini_size = [512, 512]
+ image = np.zeros([ini_size[1], ini_size[0], 3], dtype=np.uint8)
+ if keypoints is not None:
+ new_landmarks = landmark_pb2.NormalizedLandmarkList()
+ for i in range(keypoints.shape[0]):
+ landmark = new_landmarks.landmark.add()
+ if normed:
+ landmark.x = keypoints[i, 0]
+ landmark.y = keypoints[i, 1]
+ else:
+ landmark.x = keypoints[i, 0] / image_size[0]
+ landmark.y = keypoints[i, 1] / image_size[1]
+ landmark.z = 1.0
+
+ self.mp_drawing.draw_landmarks(
+ image=image,
+ landmark_list=new_landmarks,
+ connections=self.face_connection_spec.keys(),
+ landmark_drawing_spec=None,
+ connection_drawing_spec=self.face_connection_spec,
+ )
+
+ if self.draw_pupil:
+ self.draw_points(image, new_landmarks, self.pupil_landmark_spec, 3)
+
+ if self.draw_nose:
+ self.draw_points(image, new_landmarks, self.nose_landmark_spec, 3)
+
+ image = cv2.resize(image, (image_size[0], image_size[1]))
+
+ return image
diff --git a/sgm/modules/video_attention.py b/sgm/modules/video_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a28fca38172d7ae9147a500fa130250e4875517
--- /dev/null
+++ b/sgm/modules/video_attention.py
@@ -0,0 +1,364 @@
+import torch
+from einops import repeat
+from ..modules.attention import *
+from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding
+
+
+class TimeMixSequential(nn.Sequential):
+ def forward(self, x, context=None, timesteps=None):
+ for layer in self:
+ x = layer(x, context, timesteps)
+
+ return x
+
+
+class VideoTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention,
+ "softmax-xformers": MemoryEfficientCrossAttention,
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ timesteps=None,
+ ff_in=False,
+ inner_dim=None,
+ attn_mode="softmax",
+ disable_self_attn=False,
+ disable_temporal_crossattention=False,
+ switch_temporal_ca_to_sa=False,
+ ):
+ super().__init__()
+
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+
+ self.ff_in = ff_in or inner_dim is not None
+ if inner_dim is None:
+ inner_dim = dim
+
+ assert int(n_heads * d_head) == inner_dim
+
+ self.is_res = inner_dim == dim
+
+ if self.ff_in:
+ self.norm_in = nn.LayerNorm(dim)
+ self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff)
+
+ self.timesteps = timesteps
+ self.disable_self_attn = disable_self_attn
+ if self.disable_self_attn:
+ self.attn1 = attn_cls(
+ query_dim=inner_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ context_dim=context_dim,
+ dropout=dropout,
+ ) # is a cross-attention
+ else:
+ self.attn1 = attn_cls(
+ query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+
+ self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)
+
+ if disable_temporal_crossattention:
+ if switch_temporal_ca_to_sa:
+ raise ValueError
+ else:
+ self.attn2 = None
+ else:
+ self.norm2 = nn.LayerNorm(inner_dim)
+ if switch_temporal_ca_to_sa:
+ self.attn2 = attn_cls(
+ query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+ else:
+ self.attn2 = attn_cls(
+ query_dim=inner_dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ ) # is self-attn if context is none
+
+ self.norm1 = nn.LayerNorm(inner_dim)
+ self.norm3 = nn.LayerNorm(inner_dim)
+ self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
+
+ self.checkpoint = checkpoint
+ if self.checkpoint:
+ print(f"{self.__class__.__name__} is using checkpointing")
+
+ def forward(
+ self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None, skip_attention=False
+ ) -> torch.Tensor:
+ if self.checkpoint:
+ return checkpoint(self._forward, x, context, timesteps, skip_attention)
+ else:
+ return self._forward(x, context, timesteps=timesteps, skip_attention=skip_attention)
+
+ def _forward(self, x, context=None, timesteps=None, skip_attention=None):
+ assert self.timesteps or timesteps
+ assert not (self.timesteps and timesteps) or self.timesteps == timesteps
+ timesteps = self.timesteps or timesteps
+ B, S, C = x.shape
+ x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
+
+ if skip_attention is not None:
+ skip_attention = repeat(skip_attention[: B // timesteps], "b ... -> (b s) ...", s=S)
+ if self.ff_in:
+ x_skip = x
+ x = self.ff_in(self.norm_in(x))
+ if self.is_res:
+ x += x_skip
+
+ if self.disable_self_attn:
+ x = self.attn1(self.norm1(x), context=context) + x
+ else:
+ x = self.attn1(self.norm1(x), skip_attention=skip_attention) + x
+
+ if self.attn2 is not None:
+ if self.switch_temporal_ca_to_sa:
+ x = self.attn2(self.norm2(x), skip_attention=skip_attention) + x
+ else:
+ x = self.attn2(self.norm2(x), context=context) + x
+ x_skip = x
+ x = self.ff(self.norm3(x))
+ if self.is_res:
+ x += x_skip
+
+ x = rearrange(x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps)
+ return x
+
+ def get_last_layer(self):
+ return self.ff.net[-1].weight
+
+
+class SpatialVideoTransformer(SpatialTransformer):
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ use_linear=False,
+ context_dim=None,
+ use_spatial_context=False,
+ timesteps=None,
+ merge_strategy: str = "fixed",
+ merge_factor: float = 0.5,
+ merge_audio_factor: float = 5.0, # Almost 0 audio at first
+ time_context_dim=None,
+ audio_context_dim=None,
+ ff_in=False,
+ checkpoint=False,
+ time_depth=1,
+ attn_mode="softmax",
+ disable_self_attn=False,
+ disable_temporal_crossattention=False,
+ max_time_embed_period: int = 10000,
+ skip_time=False,
+ reference_to=None,
+ ):
+ super().__init__(
+ in_channels,
+ n_heads,
+ d_head,
+ depth=depth,
+ dropout=dropout,
+ attn_type=attn_mode,
+ use_checkpoint=checkpoint,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ disable_self_attn=disable_self_attn,
+ reference_to=reference_to,
+ )
+ self.time_depth = time_depth
+ self.depth = depth
+ self.max_time_embed_period = max_time_embed_period
+ self.skip_time = skip_time
+
+ time_mix_d_head = d_head
+ n_time_mix_heads = n_heads
+
+ time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
+
+ inner_dim = n_heads * d_head
+ if use_spatial_context:
+ time_context_dim = context_dim
+
+ if not self.skip_time:
+ self.time_stack = nn.ModuleList(
+ [
+ VideoTransformerBlock(
+ inner_dim,
+ n_time_mix_heads,
+ time_mix_d_head,
+ dropout=dropout,
+ context_dim=time_context_dim,
+ timesteps=timesteps,
+ checkpoint=checkpoint,
+ ff_in=ff_in,
+ inner_dim=time_mix_inner_dim,
+ attn_mode=attn_mode,
+ disable_self_attn=disable_self_attn,
+ disable_temporal_crossattention=disable_temporal_crossattention,
+ )
+ for _ in range(self.depth)
+ ]
+ )
+ else:
+ self.time_stack = None
+
+ self.audio_stack = None
+ if audio_context_dim is not None:
+ self.audio_stack = nn.ModuleList(
+ [
+ VideoTransformerBlock(
+ inner_dim,
+ n_time_mix_heads,
+ time_mix_d_head,
+ dropout=dropout,
+ context_dim=audio_context_dim,
+ timesteps=timesteps,
+ checkpoint=checkpoint,
+ ff_in=ff_in,
+ inner_dim=time_mix_inner_dim,
+ attn_mode=attn_mode,
+ disable_self_attn=disable_self_attn,
+ disable_temporal_crossattention=disable_temporal_crossattention or self.skip_time,
+ )
+ for _ in range(self.depth)
+ ]
+ )
+ self.audio_mixer = AlphaBlender(alpha=merge_audio_factor, merge_strategy=merge_strategy)
+
+ if self.time_stack is None:
+ self.time_stack = [None] * len(self.transformer_blocks)
+ assert len(self.time_stack) == len(self.transformer_blocks)
+
+ self.use_spatial_context = use_spatial_context
+ self.in_channels = in_channels
+
+ time_embed_dim = self.in_channels * 4
+ self.time_pos_embed = nn.Sequential(
+ linear(self.in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, self.in_channels),
+ )
+
+ self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: Optional[torch.Tensor] = None,
+ reference_context: Optional[torch.Tensor] = None,
+ time_context: Optional[torch.Tensor] = None,
+ audio_context: Optional[torch.Tensor] = None,
+ timesteps: Optional[int] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ skip_spatial_attention: bool = False,
+ skip_temporal_attention: bool = False,
+ ) -> torch.Tensor:
+ _, _, h, w = x.shape
+ x_in = x
+ spatial_context = None
+ if exists(context):
+ spatial_context = context
+
+ if not isinstance(spatial_context, list):
+ spatial_context = [spatial_context]
+ if reference_context is not None and not isinstance(reference_context, list):
+ reference_context = [reference_context]
+ # else:
+ # # spatial_context.reverse()
+ # print([c.shape for c in spatial_context])
+
+ if self.use_spatial_context and not self.skip_time:
+ assert (
+ isinstance(context, list) or context.ndim == 3
+ ), f"n dims of spatial context should be 3 but are {context.ndim}"
+ time_context = context
+ if not isinstance(context, list):
+ time_context_first_timestep = time_context[::timesteps]
+ time_context = repeat(time_context_first_timestep, "b ... -> (b n) ...", n=h * w)
+ elif time_context is not None and not self.use_spatial_context:
+ time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
+ if time_context.ndim == 2:
+ time_context = rearrange(time_context, "b c -> b 1 c")
+
+ if audio_context is not None:
+ audio_context = repeat(audio_context, "b ... -> (b n) ...", n=h * w)
+ if audio_context.ndim == 2:
+ audio_context = rearrange(audio_context, "b c -> b 1 c")
+
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c")
+ if self.use_linear:
+ x = self.proj_in(x)
+
+ if not self.skip_time:
+ num_frames = torch.arange(timesteps, device=x.device, dtype=x.dtype)
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+ num_frames = rearrange(num_frames, "b t -> (b t)")
+ t_emb = timestep_embedding(
+ num_frames,
+ self.in_channels,
+ repeat_only=False,
+ max_period=self.max_time_embed_period,
+ )
+ emb = self.time_pos_embed(t_emb)
+ emb = emb[:, None, :]
+
+ for it_, (block, mix_block) in enumerate(zip(self.transformer_blocks, self.time_stack)):
+ if it_ > 0 and len(spatial_context) == 1:
+ it_ = 0 # use same context for each block
+
+ x = block(
+ x,
+ context=spatial_context[it_],
+ reference_context=reference_context[it_] if reference_context is not None else None,
+ skip_attention=skip_spatial_attention,
+ )
+
+ if not self.skip_time:
+ x_mix = x
+ x_mix = x_mix + emb
+
+ x_mix = mix_block(
+ x_mix,
+ context=time_context,
+ timesteps=timesteps,
+ skip_attention=skip_temporal_attention,
+ )
+ x = self.time_mixer(
+ x_spatial=x,
+ x_temporal=x_mix,
+ image_only_indicator=image_only_indicator,
+ )
+
+ if self.audio_stack is not None:
+ audio_mix_block = self.audio_stack[it_]
+ x_audio = x
+ # x_audio = x_audio + emb
+ x_audio = audio_mix_block(x_audio, context=audio_context, timesteps=timesteps)
+ x = self.audio_mixer(x, x_audio, image_only_indicator=image_only_indicator)
+
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ if not self.use_linear:
+ x = self.proj_out(x)
+ out = x + x_in
+ return out
diff --git a/sgm/util.py b/sgm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..b314ebc3f1a96383ba1d09f8ab349342f8744af4
--- /dev/null
+++ b/sgm/util.py
@@ -0,0 +1,406 @@
+import functools
+import importlib
+import os
+from functools import partial
+from inspect import isfunction
+
+import fsspec
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from safetensors.torch import load_file as load_safetensors
+import torchaudio
+import math
+from einops import rearrange
+import torchvision
+import moviepy.editor as mpy
+
+import contextlib
+import io
+from functools import wraps
+import warnings
+
+
+def save_audio_video(
+ video,
+ audio=None,
+ frame_rate=25,
+ sample_rate=16000,
+ save_path="temp.mp4",
+ keep_intermediate=False,
+):
+ """Save audio and video to a single file.
+ video: (t, c, h, w)
+ audio: (channels t)
+ """
+ save_path = str(save_path)
+ video_tensor = rearrange(video, "t c h w -> t h w c").astype(np.uint8)
+ if audio is not None:
+ # Assuming audio is a tensor of shape (channels, samples)
+ audio_tensor = audio
+ torchvision.io.write_video(
+ save_path,
+ video_tensor,
+ fps=frame_rate,
+ audio_array=audio_tensor,
+ audio_fps=sample_rate,
+ video_codec="h264", # Specify a codec to address the error
+ audio_codec="aac",
+ )
+ else:
+ torchvision.io.write_video(
+ save_path,
+ video_tensor,
+ fps=frame_rate,
+ video_codec="h264", # Specify a codec to address the error
+ audio_codec="aac",
+ )
+ return 1
+
+
+def get_raw_audio(audio_path, audio_rate, fps=25):
+ audio, sr = torchaudio.load(audio_path, channels_first=True)
+ if audio.shape[0] > 1:
+ audio = audio.mean(0, keepdim=True)
+ audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=audio_rate)[0]
+ samples_per_frame = math.ceil(audio_rate / fps)
+ n_frames = audio.shape[-1] / samples_per_frame
+ if not n_frames.is_integer():
+ print("Audio shape before trim_pad_audio: ", audio.shape)
+ audio = trim_pad_audio(
+ audio, audio_rate, max_len_raw=math.ceil(n_frames) * samples_per_frame
+ )
+ print("Audio shape after trim_pad_audio: ", audio.shape)
+ audio = rearrange(audio, "(f s) -> f s", s=samples_per_frame)
+ return audio
+
+
+def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
+ len_file = audio.shape[-1]
+
+ if max_len_sec or max_len_raw:
+ max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
+ if len_file < int(max_len):
+ # dummy = np.zeros((1, int(max_len_sec * sr) - len_file))
+ # extened_wav = np.concatenate((audio_data, dummy[0]))
+ extened_wav = torch.nn.functional.pad(
+ audio, (0, int(max_len) - len_file), "constant"
+ )
+ else:
+ extened_wav = audio[:, : int(max_len)]
+ else:
+ extened_wav = audio
+
+ return extened_wav
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def get_string_from_tuple(s):
+ try:
+ # Check if the string starts and ends with parentheses
+ if s[0] == "(" and s[-1] == ")":
+ # Convert the string to a tuple
+ t = eval(s)
+ # Check if the type of t is tuple
+ if type(t) == tuple:
+ return t[0]
+ else:
+ pass
+ except:
+ pass
+ return s
+
+
+def is_power_of_two(n):
+ """
+ chat.openai.com/chat
+ Return True if n is a power of 2, otherwise return False.
+
+ The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
+ The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
+ If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
+ Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
+
+ """
+ if n <= 0:
+ return False
+ return (n & (n - 1)) == 0
+
+
+def autocast(f, enabled=True):
+ def do_autocast(*args, **kwargs):
+ with torch.cuda.amp.autocast(
+ enabled=enabled,
+ dtype=torch.get_autocast_gpu_dtype(),
+ cache_enabled=torch.is_autocast_cache_enabled(),
+ ):
+ return f(*args, **kwargs)
+
+ return do_autocast
+
+
+def load_partial_from_config(config):
+ return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
+ nc = int(40 * (wh[0] / 256))
+ if isinstance(xc[bi], list):
+ text_seq = xc[bi][0]
+ else:
+ text_seq = xc[bi]
+ lines = "\n".join(
+ text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
+ )
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def partialclass(cls, *args, **kwargs):
+ class NewCls(cls):
+ __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
+
+ return NewCls
+
+
+def make_path_absolute(path):
+ fs, p = fsspec.core.url_to_fs(path)
+ if fs.protocol == "file":
+ return os.path.abspath(p)
+ return path
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def isheatmap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+
+ return x.ndim == 2
+
+
+def isneighbors(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def expand_dims_like(x, y):
+ while x.dim() != y.dim():
+ x = x.unsqueeze(-1)
+ return x
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.0e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == "__is_first_stage__":
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False, invalidate_cache=True):
+ module, cls = string.rsplit(".", 1)
+ if invalidate_cache:
+ importlib.invalidate_caches()
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def append_zero(x):
+ return torch.cat([x, x.new_zeros([1])])
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
+ )
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def load_model_from_config(config, ckpt, verbose=True, freeze=True):
+ print(f"Loading model from {ckpt}")
+ if ckpt.endswith("ckpt"):
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ elif ckpt.endswith("safetensors"):
+ sd = load_safetensors(ckpt)
+ else:
+ raise NotImplementedError
+
+ model = instantiate_from_config(config.model)
+
+ m, u = model.load_state_dict(sd, strict=False)
+
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ if freeze:
+ for param in model.parameters():
+ param.requires_grad = False
+
+ model.eval()
+ return model
+
+
+def get_configs_path() -> str:
+ """
+ Get the `configs` directory.
+ For a working copy, this is the one in the root of the repository,
+ but for an installed copy, it's in the `sgm` package (see pyproject.toml).
+ """
+ this_dir = os.path.dirname(__file__)
+ candidates = (
+ os.path.join(this_dir, "configs"),
+ os.path.join(this_dir, "..", "configs"),
+ )
+ for candidate in candidates:
+ candidate = os.path.abspath(candidate)
+ if os.path.isdir(candidate):
+ return candidate
+ raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
+
+
+def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
+ """
+ Will return the result of a recursive get attribute call.
+ E.g.:
+ a.b.c
+ = getattr(getattr(a, "b"), "c")
+ = get_nested_attribute(a, "b.c")
+ If any part of the attribute call is an integer x with current obj a, will
+ try to call a[x] instead of a.x first.
+ """
+ attributes = attribute_path.split(".")
+ if depth is not None and depth > 0:
+ attributes = attributes[:depth]
+ assert len(attributes) > 0, "At least one attribute should be selected"
+ current_attribute = obj
+ current_key = None
+ for level, attribute in enumerate(attributes):
+ current_key = ".".join(attributes[: level + 1])
+ try:
+ id_ = int(attribute)
+ current_attribute = current_attribute[id_]
+ except ValueError:
+ current_attribute = getattr(current_attribute, attribute)
+
+ return (current_attribute, current_key) if return_key else current_attribute
+
+
+def suppress_output(f):
+ @wraps(f)
+ def wrapper(*args, **kwargs):
+ with (
+ contextlib.redirect_stdout(io.StringIO()),
+ contextlib.redirect_stderr(io.StringIO()),
+ warnings.catch_warnings(),
+ ):
+ warnings.simplefilter("ignore")
+ return f(*args, **kwargs)
+
+ return wrapper
+
+
+def calculate_splits(tensor, min_last_size, dim=1):
+ # Check the total number of elements in the tensor
+ total_size = tensor.size(dim) # size along the second dimension
+
+ # If total size is less than the minimum size for the last split, return the tensor as a single split
+ if total_size <= min_last_size:
+ return [tensor]
+
+ # Calculate number of splits and size of each split
+ num_splits = (total_size - min_last_size) // min_last_size + 1
+ base_size = (total_size - min_last_size) // num_splits
+
+ # Create split sizes list
+ split_sizes = [base_size] * (num_splits - 1)
+ split_sizes.append(
+ total_size - sum(split_sizes)
+ ) # Ensure the last split has at least min_last_size
+
+ # Adjust sizes to ensure they sum exactly to total_size
+ sum_sizes = sum(split_sizes)
+ while sum_sizes != total_size:
+ for i in range(num_splits):
+ if sum_sizes < total_size:
+ split_sizes[i] += 1
+ sum_sizes += 1
+ if sum_sizes >= total_size:
+ break
+
+ # Split the tensor
+ splits = torch.split(tensor, split_sizes, dim=dim)
+
+ return splits
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e0b155b3252cf979a9639e73cba409a408e0552
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,343 @@
+import torchvision
+from einops import rearrange
+import numpy as np
+import math
+import torchaudio
+import torch
+import importlib
+from data_utils import create_masks_from_landmarks_box
+import torch.nn.functional as F
+
+
+def save_audio_video(
+ video,
+ audio=None,
+ frame_rate=25,
+ sample_rate=16000,
+ save_path="temp.mp4",
+):
+ """Save audio and video to a single file.
+ video: (t, c, h, w)
+ audio: (channels t)
+ """
+ save_path = str(save_path)
+ if isinstance(video, torch.Tensor):
+ video = video.cpu().numpy()
+ video_tensor = rearrange(video, "t c h w -> t h w c").astype(np.uint8)
+ print("video_tensor shape", video_tensor.shape)
+ print("audio shape", audio.shape)
+
+ if audio is not None:
+ # Assuming audio is a tensor of shape (channels, samples)
+ audio_tensor = audio
+ torchvision.io.write_video(
+ save_path,
+ video_tensor,
+ fps=frame_rate,
+ audio_array=audio_tensor,
+ audio_fps=sample_rate,
+ video_codec="h264", # Specify a codec to address the error
+ audio_codec="aac",
+ )
+ else:
+ torchvision.io.write_video(
+ save_path,
+ video_tensor,
+ fps=frame_rate,
+ video_codec="h264", # Specify a codec to address the error
+ audio_codec="aac",
+ )
+ return save_path
+
+
+def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
+ len_file = audio.shape[-1]
+
+ if max_len_sec or max_len_raw:
+ max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
+ if len_file < int(max_len):
+ # dummy = np.zeros((1, int(max_len_sec * sr) - len_file))
+ # extened_wav = np.concatenate((audio_data, dummy[0]))
+ extened_wav = torch.nn.functional.pad(
+ audio, (0, int(max_len) - len_file), "constant"
+ )
+ else:
+ extened_wav = audio[:, : int(max_len)]
+ else:
+ extened_wav = audio
+
+ return extened_wav
+
+
+def get_raw_audio(audio_path, audio_rate, fps=25):
+ audio, sr = torchaudio.load(audio_path, channels_first=True)
+ if audio.shape[0] > 1:
+ audio = audio.mean(0, keepdim=True)
+ audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=audio_rate)[0]
+ samples_per_frame = math.ceil(audio_rate / fps)
+ n_frames = audio.shape[-1] / samples_per_frame
+ if not n_frames.is_integer():
+ audio = trim_pad_audio(
+ audio, audio_rate, max_len_raw=math.ceil(n_frames) * samples_per_frame
+ )
+ audio = rearrange(audio, "(f s) -> f s", s=samples_per_frame)
+ return audio
+
+
+def calculate_splits(tensor, min_last_size):
+ # Check the total number of elements in the tensor
+ total_size = tensor.size(1) # size along the second dimension
+
+ # If total size is less than the minimum size for the last split, return the tensor as a single split
+ if total_size <= min_last_size:
+ return [tensor]
+
+ # Calculate number of splits and size of each split
+ num_splits = (total_size - min_last_size) // min_last_size + 1
+ base_size = (total_size - min_last_size) // num_splits
+
+ # Create split sizes list
+ split_sizes = [base_size] * (num_splits - 1)
+ split_sizes.append(
+ total_size - sum(split_sizes)
+ ) # Ensure the last split has at least min_last_size
+
+ # Adjust sizes to ensure they sum exactly to total_size
+ sum_sizes = sum(split_sizes)
+ while sum_sizes != total_size:
+ for i in range(num_splits):
+ if sum_sizes < total_size:
+ split_sizes[i] += 1
+ sum_sizes += 1
+ if sum_sizes >= total_size:
+ break
+
+ # Split the tensor
+ splits = torch.split(tensor, split_sizes, dim=1)
+
+ return splits
+
+
+def make_into_multiple_of(x, multiple, dim=0):
+ """Make the torch tensor into a multiple of the given number."""
+ if x.shape[dim] % multiple != 0:
+ x = torch.cat(
+ [
+ x,
+ torch.zeros(
+ *x.shape[:dim],
+ multiple - (x.shape[dim] % multiple),
+ *x.shape[dim + 1 :],
+ ).to(x.device),
+ ],
+ dim=dim,
+ )
+ return x
+
+
+def default(value, default_value):
+ return default_value if value is None else value
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == "__is_first_stage__":
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False, invalidate_cache=True):
+ module, cls = string.rsplit(".", 1)
+ if invalidate_cache:
+ importlib.invalidate_caches()
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def load_landmarks(
+ landmarks: np.ndarray,
+ original_size,
+ target_size=(64, 64),
+ nose_index=28,
+):
+ """
+ Load and process facial landmarks to create masks.
+
+ Args:
+ landmarks: Facial landmarks array
+ original_size: Original size of the video frames
+ index: Index for non-dub mode
+ target_size: Target size for the output mask
+ is_dub: Whether this is for dubbing mode
+ what_mask: Type of mask to create ("full", "box", "heart", "mouth")
+ nose_index: Index of the nose landmark
+
+ Returns:
+ Processed landmarks mask
+ """
+ expand_box = 0.0
+ if len(landmarks.shape) == 2:
+ landmarks = landmarks[None, ...]
+
+ mask = create_masks_from_landmarks_box(
+ landmarks,
+ (original_size[0], original_size[1]),
+ box_expand=expand_box,
+ nose_index=nose_index,
+ )
+
+ mask = F.interpolate(mask.unsqueeze(1).float(), size=target_size, mode="nearest")
+ return mask
+
+
+def create_pipeline_inputs(
+ audio: torch.Tensor,
+ audio_interpolation: torch.Tensor,
+ num_frames: int,
+ video_emb: torch.Tensor,
+ landmarks: np.ndarray,
+ overlap: int = 1,
+ add_zero_flag: bool = False,
+ mask_arms: bool = None,
+ nose_index: int = 28,
+):
+ """
+ Create inputs for the keyframe generation and interpolation pipeline.
+
+ Args:
+ video: Input video tensor
+ audio: Audio embeddings for keyframe generation
+ audio_interpolation: Audio embeddings for interpolation
+ num_frames: Number of frames per segment
+ video_emb: Optional video embeddings
+ landmarks: Facial landmarks for mask generation
+ overlap: Number of frames to overlap between segments
+ add_zero_flag: Whether to add zero flag every num_frames
+ what_mask: Type of mask to generate ("box" or other options)
+ mask_arms: Optional mask for arms region
+ nose_index: Index of the nose landmark point
+
+ Returns:
+ Tuple containing all necessary inputs for the pipeline
+ """
+ audio_interpolation_chunks = []
+ audio_image_preds = []
+ gt_chunks = []
+ gt_keyframes_chunks = []
+ # Adjustment for overlap to ensure segments are created properly
+ step = num_frames - overlap
+
+ # Ensure there's at least one step forward on each iteration
+ if step < 1:
+ step = 1
+
+ audio_image_preds_idx = []
+ audio_interp_preds_idx = []
+ masks_chunks = []
+ masks_interpolation_chunks = []
+ for i in range(0, audio.shape[0] - num_frames + 1, step):
+ try:
+ audio[i + num_frames - 1]
+ except IndexError:
+ break # Last chunk is smaller than num_frames
+ segment_end = i + num_frames
+ gt_chunks.append(video_emb[i:segment_end])
+ masks = load_landmarks(
+ landmarks[i:segment_end],
+ (512, 512),
+ target_size=(64, 64),
+ nose_index=nose_index,
+ )
+ if mask_arms is not None:
+ masks = np.logical_and(
+ masks, np.logical_not(mask_arms[i:segment_end, None, ...])
+ )
+ masks_interpolation_chunks.append(masks)
+
+ if i not in audio_image_preds_idx:
+ audio_image_preds.append(audio[i])
+ masks_chunks.append(masks[0])
+ gt_keyframes_chunks.append(video_emb[i])
+ audio_image_preds_idx.append(i)
+
+ if segment_end - 1 not in audio_image_preds_idx:
+ audio_image_preds_idx.append(segment_end - 1)
+
+ audio_image_preds.append(audio[segment_end - 1])
+ masks_chunks.append(masks[-1])
+ gt_keyframes_chunks.append(video_emb[segment_end - 1])
+
+ audio_interpolation_chunks.append(audio_interpolation[i:segment_end])
+ audio_interp_preds_idx.append([i, segment_end - 1])
+
+ # If the flag is on, add element 0 every 14 audio elements
+ if add_zero_flag:
+ first_element = audio_image_preds[0]
+
+ len_audio_image_preds = (
+ len(audio_image_preds) + (len(audio_image_preds) + 1) % num_frames
+ )
+ for i in range(0, len_audio_image_preds, num_frames):
+ audio_image_preds.insert(i, first_element)
+ audio_image_preds_idx.insert(i, None)
+ masks_chunks.insert(i, masks_chunks[0])
+ gt_keyframes_chunks.insert(i, gt_keyframes_chunks[0])
+
+ to_remove = [idx is None for idx in audio_image_preds_idx]
+ audio_image_preds_idx_clone = [idx for idx in audio_image_preds_idx]
+ if add_zero_flag:
+ # Remove the added elements from the list
+ audio_image_preds_idx = [
+ sample for i, sample in zip(to_remove, audio_image_preds_idx) if not i
+ ]
+
+ interpolation_cond_list = []
+ for i in range(0, len(audio_image_preds_idx) - 1, overlap if overlap > 0 else 2):
+ interpolation_cond_list.append(
+ [audio_image_preds_idx[i], audio_image_preds_idx[i + 1]]
+ )
+
+ # Since we generate num_frames at a time, we need to ensure that the last chunk is of size num_frames
+ # Calculate the number of frames needed to make audio_image_preds a multiple of num_frames
+ frames_needed = (num_frames - (len(audio_image_preds) % num_frames)) % num_frames
+
+ # Extend from the start of audio_image_preds
+ audio_image_preds = audio_image_preds + [audio_image_preds[-1]] * frames_needed
+ masks_chunks = masks_chunks + [masks_chunks[-1]] * frames_needed
+ gt_keyframes_chunks = (
+ gt_keyframes_chunks + [gt_keyframes_chunks[-1]] * frames_needed
+ )
+
+ to_remove = to_remove + [True] * frames_needed
+ audio_image_preds_idx_clone = (
+ audio_image_preds_idx_clone + [audio_image_preds_idx_clone[-1]] * frames_needed
+ )
+
+ print(
+ f"Added {frames_needed} frames from the start to make audio_image_preds a multiple of {num_frames}"
+ )
+
+ # random_cond_idx = np.random.randint(0, len(video_emb))
+ random_cond_idx = 0
+
+ assert len(to_remove) == len(audio_image_preds), (
+ "to_remove and audio_image_preds must have the same length"
+ )
+
+ return (
+ gt_chunks,
+ gt_keyframes_chunks,
+ audio_interpolation_chunks,
+ audio_image_preds,
+ video_emb[random_cond_idx],
+ masks_chunks,
+ masks_interpolation_chunks,
+ to_remove,
+ audio_interp_preds_idx,
+ audio_image_preds_idx_clone,
+ )
diff --git a/vae_wrapper.py b/vae_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..0907a38975df5a385a62ca6a052be441909cb3f3
--- /dev/null
+++ b/vae_wrapper.py
@@ -0,0 +1,184 @@
+import os
+import torch
+import torch.nn as nn
+from einops import rearrange
+from diffusers import (
+ AutoencoderKL,
+ AutoencoderKLTemporalDecoder,
+ StableDiffusionPipeline,
+)
+
+
+def default(value, default_value):
+ return default_value if value is None else value
+
+
+def load_stable_model(model_path):
+ vae_model = StableDiffusionPipeline.from_pretrained(model_path)
+ vae_model.set_use_memory_efficient_attention_xformers(True)
+ return vae_model.vae
+
+
+def process_image(image: torch.Tensor, resolution=None) -> torch.Tensor:
+ """
+ Process image tensor by resizing and normalizing.
+
+ Args:
+ image: Input image tensor
+ resolution: Target resolution for resizing
+
+ Returns:
+ Processed image tensor
+ """
+ if resolution is not None:
+ image = torch.nn.functional.interpolate(
+ image.float(), size=resolution, mode="bilinear", align_corners=False
+ )
+ return image / 127.5 - 1.0
+
+
+def encode_video_chunk(
+ model,
+ video,
+ target_resolution,
+) -> torch.Tensor:
+ """
+ Encode a chunk of video frames into latent space.
+
+ Args:
+ model: VAE model for encoding
+ video: Video tensor to encode
+ target_resolution: Target resolution for processing
+
+ Returns:
+ Encoded latent tensor
+ """
+ video = rearrange(video, "t h w c -> c t h w")
+ vid_rez = min(video.shape[-1], video.shape[-2])
+ to_rez = default(target_resolution, vid_rez)
+ video = process_image(video, to_rez)
+
+ encoded = model.encode_video(video.cuda().unsqueeze(0)).squeeze(0)
+ return rearrange(encoded, "c t h w -> t c h w")
+
+
+class VaeWrapper(nn.Module):
+ def __init__(self, latent_type, max_chunk_decode=16, variant="fp16"):
+ super().__init__()
+ self.vae_model = self.get_vae(latent_type, variant)
+ # self.latent_scale = latent_scale
+ self.latent_type = latent_type
+ self.max_chunk_decode = max_chunk_decode
+
+ def get_vae(self, latent_type, variant="fp16"):
+ if latent_type == "stable":
+ vae_model = load_stable_model("stabilityai/stable-diffusion-x4-upscaler")
+ vae_model.enable_slicing()
+ vae_model.set_use_memory_efficient_attention_xformers(True)
+ self.down_factor = 4
+ elif latent_type == "video":
+ vae_model = AutoencoderKLTemporalDecoder.from_pretrained(
+ "stabilityai/stable-video-diffusion-img2vid",
+ subfolder="vae",
+ torch_dtype=torch.float16 if variant == "fp16" else torch.float32,
+ variant="fp16" if variant == "fp16" else None,
+ )
+ vae_model.set_use_memory_efficient_attention_xformers(True)
+ self.down_factor = 8
+ elif latent_type == "refiner":
+ vae_model = AutoencoderKL.from_pretrained(
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
+ subfolder="vae",
+ revision=None,
+ )
+ vae_model.enable_slicing()
+ vae_model.set_use_memory_efficient_attention_xformers(True)
+ self.down_factor = 8
+
+ vae_model.eval()
+ vae_model.requires_grad_(False)
+ vae_model.cuda()
+
+ vae_model = torch.compile(vae_model)
+ return vae_model
+
+ # def accelerate_model(self, example_shape):
+ # self.vae_model = torch.jit.trace(self.vae_model, torch.randn(example_shape).cuda())
+ # self.vae_model = torch.compile(self.vae_model)
+ # self.is_accelerated = True
+ def disable_slicing(self):
+ self.vae_model.disable_slicing()
+
+ @torch.no_grad()
+ def encode_video(self, video):
+ """
+ video: (B, C, T, H, W)
+ """
+ is_video = False
+ if len(video.shape) == 5:
+ is_video = True
+ T = video.shape[2]
+ video = rearrange(video, "b c t h w -> (b t) c h w")
+ or_dtype = video.dtype
+ # if not self.is_accelerated:
+ # self.accelerate_model(video.shape)
+ if self.latent_type in ["stable", "refiner", "video"]:
+ encoded_video = (
+ self.vae_model.encode(video.to(dtype=self.vae_model.dtype))
+ .latent_dist.sample()
+ .to(dtype=or_dtype)
+ * self.vae_model.config.scaling_factor
+ )
+ elif self.latent_type == "ldm":
+ encoded_video = self.vae_model.encode_first_stage(video) * 0.18215
+ if not is_video:
+ return encoded_video
+ return rearrange(encoded_video, "(b t) c h w -> b c t h w", t=T)
+
+ @torch.no_grad()
+ def decode_video(self, encoded_video):
+ """
+ encoded_video: (B, C, T, H, W)
+ """
+ is_video = False
+ B, T = encoded_video.shape[0], 1
+ if len(encoded_video.shape) == 5:
+ is_video = True
+ T = encoded_video.shape[2]
+ encoded_video = rearrange(encoded_video, "b c t h w -> (b t) c h w")
+ decoded_full = []
+ or_dtype = encoded_video.dtype
+
+ for i in range(0, T * B, self.max_chunk_decode): # Slow but no memory issues
+ if self.latent_type in ["stable", "refiner"]:
+ decoded_full.append(
+ self.vae_model.decode(
+ (1 / self.vae_model.config.scaling_factor)
+ * encoded_video[i : i + self.max_chunk_decode]
+ ).sample
+ )
+ elif self.latent_type == "video":
+ chunk = encoded_video[i : i + self.max_chunk_decode].to(
+ dtype=self.vae_model.dtype
+ )
+ num_frames_in = chunk.shape[0]
+ decode_kwargs = {}
+ decode_kwargs["num_frames"] = num_frames_in
+ decoded_full.append(
+ self.vae_model.decode(
+ 1 / self.vae_model.config.scaling_factor * chunk,
+ **decode_kwargs,
+ ).sample.to(or_dtype)
+ )
+ elif self.latent_type == "ldm":
+ decoded_full.append(
+ self.vae_model.decode_first_stage(
+ 1 / 0.18215 * encoded_video[i : i + self.max_chunk_decode]
+ )
+ )
+ decoded_video = torch.cat(decoded_full, dim=0)
+ if not is_video:
+ return decoded_video.clamp(-1.0, 1.0)
+ return rearrange(decoded_video, "(b t) c h w -> b c t h w", t=T).clamp(
+ -1.0, 1.0
+ )
diff --git a/wordle_game.py b/wordle_game.py
new file mode 100644
index 0000000000000000000000000000000000000000..123fe688e86e2a181a73af0d3350c68662349478
--- /dev/null
+++ b/wordle_game.py
@@ -0,0 +1,203 @@
+import random
+from collections import Counter
+
+
+class WordleGame:
+ # A more extensive word list with common 5-letter words
+ WORD_LIST = [
+ "apple",
+ "beach",
+ "crane",
+ "doubt",
+ "eagle",
+ "flame",
+ "grape",
+ "house",
+ "igloo",
+ "joker",
+ "knife",
+ "lemon",
+ "mango",
+ "night",
+ "ocean",
+ "piano",
+ "queen",
+ "river",
+ "stone",
+ "tiger",
+ "vivid",
+ "waste",
+ "yacht",
+ "zebra",
+ "about",
+ "above",
+ "actor",
+ "adapt",
+ "admit",
+ "adopt",
+ "after",
+ "again",
+ "album",
+ "alert",
+ "alike",
+ "alive",
+ "allow",
+ "alone",
+ "along",
+ "alter",
+ "among",
+ "anger",
+ "angle",
+ "angry",
+ "ankle",
+ "apart",
+ "apple",
+ "apply",
+ "arena",
+ "argue",
+ "arise",
+ "armor",
+ "array",
+ "arrow",
+ "asset",
+ "avoid",
+ "award",
+ "aware",
+ "awful",
+ "bacon",
+ "badge",
+ "badly",
+ "basic",
+ "basis",
+ "beach",
+ "beard",
+ "beast",
+ "begin",
+ "being",
+ "below",
+ "bench",
+ "birth",
+ "black",
+ "blade",
+ "blame",
+ "blank",
+ "blast",
+ "bleed",
+ "blend",
+ "bless",
+ ]
+
+ def __init__(self):
+ """Initialize the game state."""
+ self.target_word = None
+ self.guesses = []
+ self.feedbacks = []
+ self.game_over = True
+ self.won = False
+ self.max_attempts = 6
+ self.difficulty = "normal" # Can be "easy", "normal", or "hard"
+
+ def new_game(self, difficulty="normal"):
+ """Start a new game by resetting state and picking a random word."""
+ self.difficulty = difficulty
+ self.target_word = random.choice(self.WORD_LIST)
+ self.guesses = []
+ self.feedbacks = []
+ self.game_over = False
+ self.won = False
+
+ # Adjust max attempts based on difficulty
+ if difficulty == "easy":
+ self.max_attempts = 8
+ elif difficulty == "hard":
+ self.max_attempts = 4
+ else: # normal
+ self.max_attempts = 6
+
+ return f"New game started ({self.difficulty} mode). Guess a five-letter word. You have {self.max_attempts} attempts."
+
+ def generate_feedback(self, guess):
+ """Generate HTML feedback for a guess (green, yellow, gray)."""
+ target_count = Counter(self.target_word)
+ feedback = [""] * 5
+ # Mark correct letters in correct positions (green)
+ for i in range(5):
+ if guess[i] == self.target_word[i]:
+ feedback[i] = (
+ f'{guess[i].upper()}'
+ )
+ target_count[guess[i]] -= 1
+ # Mark letters in wrong positions (yellow) or not in word (gray)
+ for i in range(5):
+ if feedback[i] == "":
+ if guess[i] in target_count and target_count[guess[i]] > 0:
+ feedback[i] = (
+ f'{guess[i].upper()}'
+ )
+ target_count[guess[i]] -= 1
+ else:
+ feedback[i] = f'{guess[i].upper()}'
+ return "".join(feedback)
+
+ def submit_guess(self, guess):
+ """Process a player's guess and update game state."""
+ if self.game_over:
+ return "Please start a new game."
+
+ if len(guess) != 5:
+ return "Please enter a five-letter word."
+
+ guess = guess.lower()
+
+ # Check if the guess is a valid word (optional validation)
+ if self.difficulty == "hard" and guess not in self.WORD_LIST:
+ return "Not in word list. Try again."
+
+ feedback = self.generate_feedback(guess)
+ self.guesses.append(guess)
+ self.feedbacks.append(feedback)
+
+ if guess == self.target_word:
+ self.game_over = True
+ self.won = True
+ attempts = len(self.guesses)
+ message = f"Congratulations! You guessed the word in {attempts}/{'attempt' if attempts == 1 else 'attempts'}."
+ elif len(self.guesses) >= self.max_attempts:
+ self.game_over = True
+ message = f"Game over. The word was {self.target_word.upper()}."
+ else:
+ message = f"You have {self.max_attempts - len(self.guesses)} guesses left."
+
+ # Give hints in easy mode
+ if self.difficulty == "easy" and len(self.guesses) >= 3:
+ # Find a letter position that hasn't been guessed correctly yet
+ for i in range(5):
+ if all(g[i] != self.target_word[i] for g in self.guesses):
+ message += (
+ f" Hint: Letter {i + 1} is '{self.target_word[i].upper()}'."
+ )
+ break
+
+ return message
+
+ def get_feedback_history(self):
+ """Return the history of guesses and feedbacks as an HTML string."""
+ if not self.feedbacks:
+ return "No guesses yet."
+
+ history = []
+ for i, fb in enumerate(self.feedbacks):
+ history.append(f"Guess {i + 1}/{self.max_attempts}: {fb}")
+
+ return "
".join(history)
+
+ def get_game_stats(self):
+ """Return current game statistics."""
+ return {
+ "target_word": self.target_word if self.game_over else "???",
+ "guesses_made": len(self.guesses),
+ "max_attempts": self.max_attempts,
+ "won": self.won,
+ "game_over": self.game_over,
+ "difficulty": self.difficulty,
+ }