|
from contextlib import nullcontext |
|
from functools import partial |
|
from typing import Dict, List, Optional, Tuple, Union |
|
|
|
import kornia |
|
import numpy as np |
|
import open_clip |
|
import torch |
|
import torch.nn as nn |
|
from einops import rearrange, repeat |
|
from omegaconf import ListConfig |
|
from torch.utils.checkpoint import checkpoint |
|
from transformers import ( |
|
ByT5Tokenizer, |
|
CLIPTextModel, |
|
CLIPTokenizer, |
|
T5EncoderModel, |
|
T5Tokenizer, |
|
) |
|
|
|
from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer |
|
from ...modules.diffusionmodules.model import Encoder |
|
from ...modules.diffusionmodules.openaimodel import Timestep |
|
from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule |
|
from ...modules.distributions.distributions import DiagonalGaussianDistribution |
|
from ...util import ( |
|
autocast, |
|
count_params, |
|
default, |
|
disabled_train, |
|
expand_dims_like, |
|
instantiate_from_config, |
|
) |
|
|
|
import math |
|
import string |
|
import pytorch_lightning as pl |
|
from torchvision import transforms |
|
from timm.models.vision_transformer import VisionTransformer |
|
from safetensors.torch import load_file as load_safetensors |
|
|
|
|
|
from transformers import logging |
|
logging.set_verbosity_error() |
|
|
|
class AbstractEmbModel(nn.Module):
    """Base class for conditioning embedders.

    Holds three pieces of bookkeeping state, each exposed as a property with
    getter, setter and deleter and each initialised to ``None``:
    ``is_trainable``, ``ucg_rate`` and ``input_key``. The ``is_add_embedder``
    flag passed to the constructor is stored as a plain attribute.
    """

    def __init__(self, is_add_embedder=False):
        super().__init__()
        self.is_add_embedder = is_add_embedder
        self._input_key = None
        self._ucg_rate = None
        self._is_trainable = None

    # ---- is_trainable -------------------------------------------------
    @property
    def is_trainable(self) -> bool:
        """Value stored in ``_is_trainable`` (``None`` until assigned)."""
        return self._is_trainable

    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    # ---- ucg_rate -----------------------------------------------------
    @property
    def ucg_rate(self) -> Union[float, torch.Tensor]:
        """Value stored in ``_ucg_rate`` (``None`` until assigned)."""
        return self._ucg_rate

    @ucg_rate.setter
    def ucg_rate(self, value: Union[float, torch.Tensor]):
        self._ucg_rate = value

    @ucg_rate.deleter
    def ucg_rate(self):
        del self._ucg_rate

    # ---- input_key ----------------------------------------------------
    @property
    def input_key(self) -> str:
        """Value stored in ``_input_key`` (``None`` until assigned)."""
        return self._input_key

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @input_key.deleter
    def input_key(self):
        del self._input_key
|
|
|
|
|
class GeneralConditioner(nn.Module):
    """Run a list of embedders over a batch and merge their outputs.

    Every embedder is created from its config entry with
    ``instantiate_from_config`` and must derive from ``AbstractEmbModel``.
    In ``forward`` each embedder output is routed to an output key:
    ``"add_crossattn"`` when the embedder has ``is_add_embedder`` set,
    otherwise the key selected by the tensor rank through ``OUTPUT_DIM2KEYS``.
    Outputs that share a key are concatenated along ``KEY2CATDIM[key]``.
    """

    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
    # NOTE(review): there is no "add_crossattn" entry, so a second embedder
    # with is_add_embedder=True would hit a KeyError when its output is
    # concatenated in forward() - confirm at most one such embedder is used.
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}

    def __init__(self, emb_models: Union[List, ListConfig]):
        """Instantiate and configure one embedder per entry of ``emb_models``.

        Raises:
            AssertionError: if an instantiated object is not an
                ``AbstractEmbModel``.
            KeyError: if a config has neither ``input_key`` nor ``input_keys``.
        """
        super().__init__()
        embedders = []
        for n, embconfig in enumerate(emb_models):
            embedder = instantiate_from_config(embconfig)
            assert isinstance(
                embedder, AbstractEmbModel
            ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.is_trainable = embconfig.get("is_trainable", False)
            embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
            if not embedder.is_trainable:
                # Non-trainable embedders get their train() replaced and are
                # asked to freeze themselves.
                embedder.train = disabled_train
                embedder.freeze()
            print(
                f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
            )

            if "input_key" in embconfig:
                embedder.input_key = embconfig["input_key"]
            elif "input_keys" in embconfig:
                embedder.input_keys = embconfig["input_keys"]
            else:
                raise KeyError(
                    f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
                )

            embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
            if embedder.legacy_ucg_val is not None:
                embedder.ucg_prng = np.random.RandomState()

            embedders.append(embedder)
        self.embedders = nn.ModuleList(embedders)

    def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
        """Randomly overwrite entries of ``batch[embedder.input_key]``.

        Each entry is replaced by ``embedder.legacy_ucg_val`` with probability
        ``embedder.ucg_rate``. The batch is modified in place and returned.
        """
        assert embedder.legacy_ucg_val is not None
        p = embedder.ucg_rate
        val = embedder.legacy_ucg_val
        for i in range(len(batch[embedder.input_key])):
            if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                batch[embedder.input_key][i] = val
        return batch

    def forward(
        self, batch: Dict, force_zero_embeddings: Optional[List] = None
    ) -> Dict:
        """Embed ``batch`` with every embedder and collect the results.

        Args:
            batch: mapping from input keys to the values the embedders consume.
            force_zero_embeddings: input keys whose embeddings are replaced by
                zeros of the same shape.

        Returns:
            Dict mapping output keys to the (concatenated) embeddings.
        """
        output = dict()
        if force_zero_embeddings is None:
            force_zero_embeddings = []
        for embedder in self.embedders:
            # Frozen embedders are evaluated without gradient tracking.
            embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
            with embedding_context():
                if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                    if embedder.legacy_ucg_val is not None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    emb_out = embedder(batch[embedder.input_key])
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(*[batch[k] for k in embedder.input_keys])
            assert isinstance(
                emb_out, (torch.Tensor, list, tuple)
            ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
            if not isinstance(emb_out, (list, tuple)):
                emb_out = [emb_out]
            for emb in emb_out:
                if embedder.is_add_embedder:
                    out_key = "add_crossattn"
                else:
                    out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
                if embedder.input_key == "mask":
                    # Resize the mask embedding to 1/8 of the image height/width.
                    H, W = batch["image"].shape[-2:]
                    emb = nn.functional.interpolate(emb, (H//8, W//8))
                if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                    # Per-sample dropout: multiply each sample by a Bernoulli
                    # keep-mask drawn with probability (1 - ucg_rate).
                    emb = (
                        expand_dims_like(
                            torch.bernoulli(
                                (1.0 - embedder.ucg_rate)
                                * torch.ones(emb.shape[0], device=emb.device)
                            ),
                            emb,
                        )
                        * emb
                    )
                if (
                    hasattr(embedder, "input_key")
                    and embedder.input_key in force_zero_embeddings
                ):
                    emb = torch.zeros_like(emb)
                if out_key in output:
                    output[out_key] = torch.cat(
                        (output[out_key], emb), self.KEY2CATDIM[out_key]
                    )
                else:
                    output[out_key] = emb
        return output

    def get_unconditional_conditioning(
        self, batch_c, batch_uc=None, force_uc_zero_embeddings=None
    ):
        """Compute conditional and unconditional embeddings without dropout.

        All embedders' ``ucg_rate`` values are set to 0.0 for the two forward
        passes and restored afterwards, including when a forward pass raises,
        so a failure here cannot leave dropout permanently disabled.

        Args:
            batch_c: batch used for the conditional embeddings.
            batch_uc: batch used for the unconditional embeddings; ``batch_c``
                is reused when this is ``None``.
            force_uc_zero_embeddings: input keys zeroed in the unconditional
                pass.

        Returns:
            Tuple ``(c, uc)`` of output dictionaries.
        """
        if force_uc_zero_embeddings is None:
            force_uc_zero_embeddings = []
        ucg_rates = [embedder.ucg_rate for embedder in self.embedders]
        for embedder in self.embedders:
            embedder.ucg_rate = 0.0
        try:
            c = self(batch_c)
            uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
        finally:
            for embedder, rate in zip(self.embedders, ucg_rates):
                embedder.ucg_rate = rate
        return c, uc
|
|
|
|
|
class DualConditioner(GeneralConditioner):
    """Conditioner that produces up to two unconditional embeddings."""

    def get_unconditional_conditioning(
        self, batch_c, batch_uc_1=None, batch_uc_2=None, force_uc_zero_embeddings=None
    ):
        """Compute conditional and two optional unconditional embeddings.

        All embedders' ``ucg_rate`` values are set to 0.0 for the forward
        passes and restored afterwards, including when a forward pass raises.

        Args:
            batch_c: batch used for the conditional embeddings.
            batch_uc_1: batch for the first unconditional pass, which zeroes
                every key in ``force_uc_zero_embeddings``.
            batch_uc_2: batch for the second unconditional pass, which zeroes
                only the first key in ``force_uc_zero_embeddings``.
            force_uc_zero_embeddings: input keys to zero.

        Returns:
            Tuple ``(c, uc_1, uc_2)``; an unconditional entry is ``None`` when
            its batch was not given.
        """
        if force_uc_zero_embeddings is None:
            force_uc_zero_embeddings = []
        ucg_rates = [embedder.ucg_rate for embedder in self.embedders]
        for embedder in self.embedders:
            embedder.ucg_rate = 0.0

        try:
            c = self(batch_c)
            uc_1 = self(batch_uc_1, force_uc_zero_embeddings) if batch_uc_1 is not None else None
            uc_2 = self(batch_uc_2, force_uc_zero_embeddings[:1]) if batch_uc_2 is not None else None
        finally:
            for embedder, rate in zip(self.embedders, ucg_rates):
                embedder.ucg_rate = rate

        return c, uc_1, uc_2
|
|
|
|
|
class InceptionV3(nn.Module):
    """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
    port with an additional squeeze at the end"""

    def __init__(self, normalize_input=False, **kwargs):
        """Build the wrapped model; ``resize_input`` is always forced to True."""
        super().__init__()
        # Imported here so the dependency is only needed when this class is used.
        from pytorch_fid import inception

        model_kwargs = {**kwargs, "resize_input": True}
        self.model = inception.InceptionV3(
            normalize_input=normalize_input, **model_kwargs
        )

    def forward(self, inp):
        """Run the wrapped model.

        When the model yields exactly one output, that output is returned
        squeezed; otherwise the outputs are returned unchanged.
        """
        features = self.model(inp)
        if len(features) != 1:
            return features
        return features[0].squeeze()
|
|
|
|
|
class IdentityEncoder(AbstractEmbModel):
    """Embedder that hands its input back unchanged."""

    def encode(self, x):
        """Return ``x`` as is."""
        return x

    def freeze(self):
        """No-op; returns ``None``."""
        return None

    def forward(self, x):
        """Return ``x`` as is."""
        return x
|
|
|
|
|
class ClassEmbedder(AbstractEmbModel):
    """Embeds integer class labels with an ``nn.Embedding`` table."""

    def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
        """
        Args:
            embed_dim: size of each class embedding.
            n_classes: number of rows in the embedding table.
            add_sequence_dim: when true, ``forward`` inserts a singleton
                dimension at position 1 of the embedding.
        """
        super().__init__()
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.add_sequence_dim = add_sequence_dim

    def forward(self, c):
        """Look up the embeddings for the class indices ``c``."""
        c = self.embedding(c)
        if self.add_sequence_dim:
            c = c[:, None, :]
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        """Build a batch of ``bs`` labels all equal to the last class index.

        The result is keyed by ``self.key`` when such an attribute has been
        set, otherwise by ``self.input_key``. The class never assigns
        ``self.key`` itself, so reading it unconditionally raised
        AttributeError.

        Returns:
            Dict with one entry: a long tensor of shape ``(bs,)`` filled with
            ``n_classes - 1`` on ``device``.
        """
        uc_class = self.n_classes - 1
        uc = torch.ones((bs,), device=device) * uc_class
        key = getattr(self, "key", None)
        if key is None:
            key = self.input_key
        return {key: uc.long()}
|
|
|
|
|
class ClassEmbedderForMultiCond(ClassEmbedder):
    """Class embedder that reads labels from, and writes embeddings into, a batch dict."""

    def forward(self, batch, key=None, disable_dropout=False):
        """Replace ``batch[key]`` with its class embedding.

        Args:
            batch: mapping holding the class labels under ``key``; it is
                modified in place and also returned.
            key: entry to embed. When ``None``, ``self.key`` is used if it has
                been set, otherwise ``self.input_key``.
            disable_dropout: accepted for interface compatibility; unused.

        Returns:
            ``batch`` with ``batch[key]`` replaced by the embedding. If the
            entry was a list, its first element is embedded and the result is
            wrapped in a one-element list.
        """
        if key is None:
            # The class hierarchy never assigns self.key, so fall back to the
            # input_key property instead of raising AttributeError.
            key = getattr(self, "key", None)
            if key is None:
                key = self.input_key
        out = batch
        islist = isinstance(batch[key], list)
        if islist:
            batch[key] = batch[key][0]
        # ClassEmbedder.forward takes the label tensor only; the previous call
        # passed (batch, key, disable_dropout) and raised TypeError.
        c_out = super().forward(batch[key])
        out[key] = [c_out] if islist else c_out
        return out
|
|
|
|
|
class FrozenT5Embedder(AbstractEmbModel):
    """Uses the T5 transformer encoder for text"""

    def __init__(
        self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
    ):
        """
        Args:
            version: pretrained identifier passed to the tokenizer and model.
            device: stored on the instance as ``self.device``.
            max_length: tokenizer truncation/padding length.
            freeze: when true, ``freeze()`` is called after loading.
        """
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the transformer in eval mode and disable gradients on all parameters."""
        self.transformer = self.transformer.eval()

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the encoder's last hidden state."""
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        # Send the tokens to the device the weights actually live on rather
        # than the configured device string, matching the other embedders in
        # this file and avoiding a device mismatch after the module is moved.
        device = next(self.transformer.parameters()).device
        tokens = batch_encoding["input_ids"].to(device)
        with torch.autocast("cuda", enabled=False):
            outputs = self.transformer(input_ids=tokens)
        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
|
|
|
|
|
class FrozenByT5Embedder(AbstractEmbModel):
    """
    Uses the ByT5 transformer encoder for text. Is character-aware.
    """

    def __init__(
        self, version="google/byt5-base", device="cuda", max_length=77, freeze=True, *args, **kwargs
    ):
        """Load tokenizer and encoder for ``version``; extra args go to the base class."""
        super().__init__(*args, **kwargs)
        self.tokenizer = ByT5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        """Switch the transformer to eval mode and stop gradients on every parameter."""
        self.transformer = self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the encoder's last hidden state."""
        tokenizer_options = dict(
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        encoded = self.tokenizer(text, **tokenizer_options)
        # Tokens follow whichever device the module's parameters are on.
        target_device = next(self.parameters()).device
        input_ids = encoded["input_ids"].to(target_device)
        with torch.autocast("cuda", enabled=False):
            encoder_out = self.transformer(input_ids=input_ids)
        return encoder_out.last_hidden_state

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
|
|
|
|
|
class FrozenCLIPEmbedder(AbstractEmbModel):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        layer_idx=None,
        always_return_pooled=False,
    ):
        """
        Args:
            version: pretrained identifier for tokenizer and text model.
            device: stored on the instance as ``self.device``.
            max_length: tokenizer truncation/padding length.
            freeze: when true, ``freeze()`` is called after loading.
            layer: one of ``LAYERS``; selects which output ``forward`` returns.
            layer_idx: index into the hidden states, required for
                ``layer="hidden"``.
            always_return_pooled: when true, ``forward`` also returns the
                pooler output.
        """
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        self.return_pooled = always_return_pooled
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        """Switch the transformer to eval mode and stop gradients on every parameter."""
        self.transformer = self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    @autocast
    def forward(self, text):
        """Tokenize ``text`` and return the configured layer's output.

        Returns the selected tensor, or ``(tensor, pooler_output)`` when
        ``return_pooled`` is set.
        """
        tokenizer_options = dict(
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        encoded = self.tokenizer(text, **tokenizer_options)
        target_device = next(self.transformer.parameters()).device
        input_ids = encoded["input_ids"].to(target_device)
        outputs = self.transformer(
            input_ids=input_ids, output_hidden_states=self.layer == "hidden"
        )
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            # Insert a singleton dimension at position 1 of the pooled vector.
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        if not self.return_pooled:
            return z
        return z, outputs.pooler_output

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
|
|
|
|
|
class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
    """
    Uses the OpenCLIP transformer encoder for text
    """

    LAYERS = ["pooled", "last", "penultimate"]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        always_return_pooled=False,
        legacy=True,
    ):
        """Create the OpenCLIP model on CPU, drop its visual tower and configure outputs.

        Raises:
            AssertionError: if ``layer`` is not in ``LAYERS``.
            NotImplementedError: if ``layer`` is neither ``"last"`` nor
                ``"penultimate"``.
        """
        super().__init__()
        assert layer in self.LAYERS
        clip_model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        # Only the text side is used by this embedder.
        del clip_model.visual
        self.model = clip_model

        self.device = device
        self.max_length = max_length
        self.return_pooled = always_return_pooled
        if freeze:
            self.freeze()
        self.layer = layer
        layer_to_idx = {"last": 0, "penultimate": 1}
        if self.layer not in layer_to_idx:
            raise NotImplementedError()
        self.layer_idx = layer_to_idx[self.layer]
        self.legacy = legacy

    def freeze(self):
        """Switch the model to eval mode and stop gradients on every parameter."""
        self.model = self.model.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    @autocast
    def forward(self, text):
        """Tokenize and encode ``text``.

        Returns the tensor from ``encode_with_transformer`` in legacy mode
        without pooling, ``(layer_output, pooled)`` when ``return_pooled`` is
        set (which asserts non-legacy mode), and the selected layer output
        otherwise.
        """
        target_device = next(self.model.parameters()).device
        token_ids = open_clip.tokenize(text)
        z = self.encode_with_transformer(token_ids.to(target_device))
        if self.return_pooled:
            assert not self.legacy
            return z[self.layer], z["pooled"]
        if self.legacy:
            return z
        return z[self.layer]

    def encode_with_transformer(self, text):
        """Embed token ids and run them through the text transformer.

        In legacy mode the selected layer output is layer-normalised and
        returned as a tensor. Otherwise the dict of layer outputs is returned
        with an extra ``"pooled"`` entry computed from the normalised last layer.
        """
        x = self.model.token_embedding(text)
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)
        layer_outputs = self.text_transformer_forward(
            x, attn_mask=self.model.attn_mask
        )
        if not self.legacy:
            normed_last = self.model.ln_final(layer_outputs["last"])
            layer_outputs["pooled"] = self.pool(normed_last, text)
            return layer_outputs
        selected = layer_outputs[self.layer]
        return self.model.ln_final(selected)

    def pool(self, x, text):
        """Pick, per sample, the feature at ``text.argmax(dim=-1)`` and project it."""
        batch_index = torch.arange(x.shape[0])
        picked = x[batch_index, text.argmax(dim=-1)]
        return picked @ self.model.text_projection

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run all residual blocks, recording the last and penultimate outputs.

        The recorded tensors have their first two dimensions swapped relative
        to the block inputs.
        """
        outputs = {}
        blocks = self.model.transformer.resblocks
        final_index = len(blocks) - 1
        for index, block in enumerate(blocks):
            if index == final_index:
                # Input to the final block is the penultimate layer's output.
                outputs["penultimate"] = x.permute(1, 0, 2)
            use_checkpoint = (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            )
            if use_checkpoint:
                x = checkpoint(block, x, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        outputs["last"] = x.permute(1, 0, 2)
        return outputs

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
|
|
|
|
|
class FrozenOpenCLIPEmbedder(AbstractEmbModel):
    """OpenCLIP text encoder returning the last or penultimate layer output."""

    LAYERS = [
        "last",
        "penultimate",
    ]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
    ):
        """Create the OpenCLIP model on CPU, drop its visual tower and pick the layer.

        Raises:
            AssertionError: if ``layer`` is not in ``LAYERS``.
            NotImplementedError: if ``layer`` is neither ``"last"`` nor
                ``"penultimate"``.
        """
        super().__init__()
        assert layer in self.LAYERS
        clip_model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device("cpu"), pretrained=version
        )
        # Only the text side is used by this embedder.
        del clip_model.visual
        self.model = clip_model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        layer_to_idx = {"last": 0, "penultimate": 1}
        if self.layer not in layer_to_idx:
            raise NotImplementedError()
        # Number of trailing residual blocks that are skipped in forward.
        self.layer_idx = layer_to_idx[self.layer]

    def freeze(self):
        """Switch the model to eval mode and stop gradients on every parameter."""
        self.model = self.model.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the encoded features."""
        target_device = next(self.model.parameters()).device
        token_ids = open_clip.tokenize(text)
        return self.encode_with_transformer(token_ids.to(target_device))

    def encode_with_transformer(self, text):
        """Embed token ids, run the transformer blocks and apply the final layer norm."""
        hidden = self.model.token_embedding(text)
        hidden = hidden + self.model.positional_embedding
        hidden = hidden.permute(1, 0, 2)
        hidden = self.text_transformer_forward(hidden, attn_mask=self.model.attn_mask)
        hidden = hidden.permute(1, 0, 2)
        return self.model.ln_final(hidden)

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run the residual blocks, stopping ``layer_idx`` blocks before the end."""
        blocks = self.model.transformer.resblocks
        stop_index = len(blocks) - self.layer_idx
        for index, block in enumerate(blocks):
            if index == stop_index:
                break
            use_checkpoint = (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            )
            if use_checkpoint:
                x = checkpoint(block, x, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
|
|
|
|
|
class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
    ):
        """Create the OpenCLIP model on CPU, drop its text transformer and store options.

        ``pad_to_max_len`` is enabled whenever ``num_image_crops`` is positive,
        and ``repeat_to_max_len`` is only honoured when padding is off.
        """
        super().__init__()
        clip_model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device("cpu"),
            pretrained=version,
        )
        # Only the visual side is used by this embedder.
        del clip_model.transformer
        self.model = clip_model
        self.max_crops = num_image_crops
        self.pad_to_max_len = self.max_crops > 0
        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

        self.antialias = antialias

        # Per-channel statistics used by preprocess(); kept out of state_dict.
        channel_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073])
        channel_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711])
        self.register_buffer("mean", channel_mean, persistent=False)
        self.register_buffer("std", channel_std, persistent=False)
        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        self.model.visual.output_tokens = output_tokens
        self.output_tokens = output_tokens

    def preprocess(self, x):
        """Resize to 224x224, map values with ``(x + 1) / 2`` and normalise by mean/std."""
        resized = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias,
        )
        rescaled = (resized + 1.0) / 2.0
        return kornia.enhance.normalize(rescaled, self.mean, self.std)

    def freeze(self):
        """Switch the model to eval mode and stop gradients on every parameter."""
        self.model = self.model.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    @autocast
    def forward(self, image, no_dropout=False):
        """Encode ``image`` and post-process according to the configured options.

        Returns:
            ``(tokens, z)`` when ``output_tokens`` is set;
            ``(repeated, z)`` when ``repeat_to_max_len`` is set;
            ``(padded, padded[:, 0])`` when ``pad_to_max_len`` is set;
            otherwise ``z``.
        """
        z = self.encode_with_vision_transformer(image)
        tokens = None
        if self.output_tokens:
            z, tokens = z[0], z[1]
        z = z.to(image.dtype)
        apply_dropout = (
            self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0)
        )
        if apply_dropout:
            # Per-sample Bernoulli keep-mask with probability (1 - ucg_rate).
            keep_z = torch.bernoulli(
                (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
            )
            z = keep_z[:, None] * z
            if tokens is not None:
                # Tokens get their own independently drawn mask.
                keep_tokens = torch.bernoulli(
                    (1.0 - self.ucg_rate)
                    * torch.ones(tokens.shape[0], device=tokens.device)
                )
                tokens = expand_dims_like(keep_tokens, tokens) * tokens
        if self.unsqueeze_dim:
            z = z[:, None, :]
        if self.output_tokens:
            assert not self.repeat_to_max_len
            assert not self.pad_to_max_len
            return tokens, z
        if self.repeat_to_max_len:
            z_seq = z[:, None, :] if z.dim() == 2 else z
            return repeat(z_seq, "b 1 d -> b n d", n=self.max_length), z
        if self.pad_to_max_len:
            assert z.dim() == 3
            padding = torch.zeros(
                z.shape[0],
                self.max_length - z.shape[1],
                z.shape[2],
                device=z.device,
            )
            z_pad = torch.cat((z, padding), 1)
            return z_pad, z_pad[:, 0, ...]
        return z

    def encode_with_vision_transformer(self, img):
        """Preprocess ``img`` and run the visual model.

        A 5-D input is treated as ``(b, n, c, h, w)`` with ``n == max_crops``
        and flattened to a 4-D batch first. With ``max_crops > 0`` the features
        are regrouped per sample and multiplied by a per-crop Bernoulli mask.

        Returns:
            ``(x, tokens)`` when ``output_tokens`` is set, else ``x``.
        """
        if img.dim() == 5:
            assert self.max_crops == img.shape[1]
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)
        if self.output_tokens:
            assert self.model.visual.output_tokens
            x, tokens = self.model.visual(img)
        else:
            assert not self.model.visual.output_tokens
            x = self.model.visual(img)
            tokens = None
        if self.max_crops > 0:
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            crop_keep = torch.bernoulli(
                (1.0 - self.ucg_rate)
                * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
            )
            x = crop_keep * x
            if tokens is not None:
                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message."
                )
        if self.output_tokens:
            return x, tokens
        return x

    def encode(self, text):
        """Alias for calling the module on its argument."""
        return self(text)
|
|
|
|
|
class FrozenCLIPT5Encoder(AbstractEmbModel):
    """Encodes text with both a CLIP embedder and a T5 embedder."""

    def __init__(
        self,
        clip_version="openai/clip-vit-large-patch14",
        t5_version="google/t5-v1_1-xl",
        device="cuda",
        clip_max_length=77,
        t5_max_length=77,
    ):
        """Build both sub-encoders and print their parameter counts."""
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(
            clip_version, device, max_length=clip_max_length
        )
        self.t5_encoder = FrozenT5Embedder(
            t5_version, device, max_length=t5_max_length
        )
        summary = (
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
        )
        print(summary)

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)

    def forward(self, text):
        """Return ``[clip_embedding, t5_embedding]`` for ``text``."""
        clip_embedding = self.clip_encoder.encode(text)
        t5_embedding = self.t5_encoder.encode(text)
        return [clip_embedding, t5_embedding]
|
|
|
|
|
class SpatialRescaler(nn.Module):
    """Rescales inputs by repeated interpolation, optionally remapping channels."""

    def __init__(
        self,
        n_stages=1,
        method="bilinear",
        multiplier=0.5,
        in_channels=3,
        out_channels=None,
        bias=False,
        wrap_video=False,
        kernel_size=1,
        remap_output=False,
    ):
        """
        Args:
            n_stages: number of interpolation passes (must be >= 0).
            method: interpolation mode passed to ``torch.nn.functional.interpolate``.
            multiplier: ``scale_factor`` applied at every stage.
            in_channels: input channels of the optional channel mapper.
            out_channels: output channels of the channel mapper; giving a
                value enables the mapper.
            bias: whether the channel mapper convolution has a bias.
            wrap_video: when true, 5-D inputs are folded into a 4-D batch for
                interpolation and unfolded afterwards.
            kernel_size: kernel size of the channel mapper convolution.
            remap_output: enables the channel mapper even without
                ``out_channels``.
        """
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in [
            "nearest",
            "linear",
            "bilinear",
            "trilinear",
            "bicubic",
            "area",
        ]
        self.multiplier = multiplier
        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None or remap_output
        if self.remap_output:
            print(
                f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
            )
            self.channel_mapper = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                bias=bias,
                padding=kernel_size // 2,
            )
        self.wrap_video = wrap_video

    def forward(self, x):
        """Interpolate ``x`` ``n_stages`` times and optionally remap channels.

        The video fold/unfold is only applied when ``wrap_video`` is set and
        the input is 5-D. Previously the unfold ran for every input when
        ``wrap_video`` was set and failed with NameError on non-5-D inputs,
        because the shape variables were never assigned.
        """
        is_video = self.wrap_video and x.ndim == 5
        if is_video:
            B, C, T, H, W = x.shape
            x = rearrange(x, "b c t h w -> b t c h w")
            x = rearrange(x, "b t c h w -> (b t) c h w")

        for _ in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        if is_video:
            x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C)
            x = rearrange(x, "b t c h w -> b c t h w")
        # NOTE(review): the channel mapper is a Conv2d applied after the video
        # unfold, i.e. to a 5-D tensor on the video path - confirm this
        # combination is actually used.
        if self.remap_output:
            x = self.channel_mapper(x)
        return x

    def encode(self, x):
        """Alias for calling the module on ``x``."""
        return self(x)
|
|
|
|
|
class LowScaleEncoder(nn.Module):
    """Encodes inputs with a wrapped model and noises the result at a random level."""

    def __init__(
        self,
        model_config,
        linear_start,
        linear_end,
        timesteps=1000,
        max_noise_level=250,
        output_size=64,
        scale_factor=1.0,
    ):
        """Instantiate the wrapped model and register the noise schedule buffers."""
        super().__init__()
        self.max_noise_level = max_noise_level
        self.model = instantiate_from_config(model_config)
        # register_schedule has no return statement, so this attribute is None.
        self.augmentation_schedule = self.register_schedule(
            timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
        )
        self.out_size = output_size
        self.scale_factor = scale_factor

    def register_schedule(
        self,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
    ):
        """Compute the beta schedule and register derived float32 buffers.

        Sets ``num_timesteps``, ``linear_start`` and ``linear_end`` and
        registers ``betas``, ``alphas_cumprod``, ``alphas_cumprod_prev`` and
        several square-root / log transforms of the cumulative product.
        """
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)

        # Registered in this order so the buffer ordering stays unchanged.
        schedule_buffers = (
            ("betas", betas),
            ("alphas_cumprod", alphas_cumprod),
            ("alphas_cumprod_prev", alphas_cumprod_prev),
            ("sqrt_alphas_cumprod", np.sqrt(alphas_cumprod)),
            ("sqrt_one_minus_alphas_cumprod", np.sqrt(1.0 - alphas_cumprod)),
            ("log_one_minus_alphas_cumprod", np.log(1.0 - alphas_cumprod)),
            ("sqrt_recip_alphas_cumprod", np.sqrt(1.0 / alphas_cumprod)),
            ("sqrt_recipm1_alphas_cumprod", np.sqrt(1.0 / alphas_cumprod - 1)),
        )
        for buffer_name, values in schedule_buffers:
            self.register_buffer(buffer_name, to_torch(values))

    def q_sample(self, x_start, t, noise=None):
        """Mix ``x_start`` with noise using the schedule coefficients at step ``t``.

        Gaussian noise shaped like ``x_start`` is drawn when ``noise`` is not given.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        signal_coeff = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_coeff = extract_into_tensor(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        return signal_coeff * x_start + noise_coeff * noise

    def forward(self, x):
        """Encode ``x``, scale, add noise at a random level and optionally resize.

        Returns:
            ``(z, noise_level)`` where ``noise_level`` holds one integer in
            ``[0, max_noise_level)`` per batch element.
        """
        z = self.model.encode(x)
        if isinstance(z, DiagonalGaussianDistribution):
            z = z.sample()
        z = z * self.scale_factor
        batch_size = x.shape[0]
        noise_level = torch.randint(
            0, self.max_noise_level, (batch_size,), device=x.device
        ).long()
        z = self.q_sample(z, noise_level)
        if self.out_size is not None:
            z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")

        return z, noise_level

    def decode(self, z):
        """Undo the scale factor and decode with the wrapped model."""
        unscaled = z / self.scale_factor
        return self.model.decode(unscaled)
|
|
|
|
|
class ConcatTimestepEmbedderND(AbstractEmbModel):
    """embeds each dimension independently and concatenates them"""

    def __init__(self, outdim):
        """Create a ``Timestep`` embedder producing ``outdim`` features per value."""
        super().__init__()
        self.outdim = outdim
        self.timestep = Timestep(outdim)

    def freeze(self):
        """Put the module in eval mode."""
        self.eval()

    def forward(self, x):
        """Embed every entry of ``x`` and concatenate the embeddings per row.

        A 1-D input is first given a trailing singleton dimension; the input
        must then be 2-D.
        """
        if x.ndim == 1:
            x = x[:, None]
        assert len(x.shape) == 2
        batch, n_dims = x.shape[0], x.shape[1]
        flat = rearrange(x, "b d -> (b d)")
        embedded = self.timestep(flat)
        return rearrange(
            embedded, "(b d) d2 -> b (d d2)", b=batch, d=n_dims, d2=self.outdim
        )
|
|
|
|
|
class GaussianEncoder(Encoder, AbstractEmbModel):
    """Encoder whose output is passed through a diagonal Gaussian regularizer."""

    def __init__(
        self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
    ):
        """
        Args:
            weight: value stored under ``"weight"`` in the returned log dict.
            flatten_output: when true, the spatial dimensions of the output
                are merged into one sequence dimension.
            *args, **kwargs: forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.posterior = DiagonalGaussianRegularizer()
        self.weight = weight
        self.flatten_output = flatten_output

    def forward(self, x) -> Tuple[Dict, torch.Tensor]:
        """Encode ``x`` and regularize it.

        Returns:
            ``(log, z)`` where ``log`` carries ``"loss"`` (copied from
            ``"kl_loss"``) and ``"weight"``.
        """
        encoded = super().forward(x)
        z, log = self.posterior(encoded)
        log["loss"] = log["kl_loss"]
        log["weight"] = self.weight
        if not self.flatten_output:
            return log, z
        z = rearrange(z, "b c h w -> b (h w ) c")
        return log, z
|
|
|
|
|
class LatentEncoder(AbstractEmbModel):
    """Wraps a configured model and returns its scaled encoding."""

    def __init__(self, scale_factor, config, *args, **kwargs):
        """Instantiate the model from ``config`` in eval mode and replace its ``train``."""
        super().__init__(*args, **kwargs)
        self.scale_factor = scale_factor
        wrapped = instantiate_from_config(config).eval()
        self.model = wrapped
        self.model.train = disabled_train

    def freeze(self):
        """Stop gradients on every parameter of the wrapped model."""
        for weight in self.model.parameters():
            weight.requires_grad = False

    def forward(self, x):
        """Return ``scale_factor * model.encode(x)``."""
        latent = self.model.encode(x)
        return self.scale_factor * latent
|
|
|
|
|
class ViTSTREncoder(VisionTransformer):
    '''
    ViTSTREncoder is basically a ViT that uses ViTSTR weights
    '''

    def __init__(self, size=224, ckpt_path=None, freeze=True, *args, **kwargs):
        """Build the ViT, reset its head and optionally load a checkpoint and freeze.

        The classifier is sized to the character set (``string.printable``
        without its last six characters) plus two extra classes.
        """
        super().__init__(*args, **kwargs)

        self.grayscale = transforms.Grayscale()
        self.resize = transforms.Resize(
            (size, size), transforms.InterpolationMode.BICUBIC, antialias=True
        )

        self.character = string.printable[:-6]
        self.reset_classifier(num_classes=len(self.character)+2)

        if ckpt_path is not None:
            # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
            state = torch.load(ckpt_path, map_location="cpu")
            self.load_state_dict(state, strict=False)

        if freeze:
            self.freeze()

    def reset_classifier(self, num_classes):
        """Replace the head with a linear layer, or an identity when ``num_classes <= 0``."""
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = nn.Linear(self.embed_dim, num_classes)
        else:
            self.head = nn.Identity()

    def freeze(self):
        """Stop gradients on every parameter."""
        for weight in self.parameters():
            weight.requires_grad_(False)

    def forward_features(self, x):
        """Patch-embed ``x``, prepend the class token, add positions and run all blocks."""
        batch_size = x.shape[0]
        patches = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        hidden = torch.cat((cls_tokens, patches), dim=1)
        hidden = hidden + self.pos_embed
        hidden = self.pos_drop(hidden)

        for block in self.blocks:
            hidden = block(hidden)

        return self.norm(hidden)

    def forward(self, x):
        """Return the normalized token features; the head is not applied."""
        return self.forward_features(x)

    def encode(self, x):
        """Alias for calling the module on ``x``."""
        return self(x)
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a ``(batch, seq, d_model)`` input."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        """
        Args:
            d_model: feature size of the encoding table (odd sizes are supported).
            dropout: dropout probability applied after adding the encoding.
            max_len: number of positions in the precomputed table.
        """
        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so the angle table is trimmed to fit.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.shape[1]`` positions, then apply dropout.

        Sequences shorter than ``max_len`` are supported by slicing the table;
        the table is broadcast over the batch instead of being tiled.
        """
        pe = self.pe[: x.shape[1]].to(x.device)
        x = x + pe[None, ...]
        return self.dropout(x)
|
|
|
|
|
class LabelEncoder(AbstractEmbModel, pl.LightningModule): |
|
|
|
def __init__(self, max_len, emb_dim, n_heads=8, n_trans_layers=12, ckpt_path=None, trainable=False,
             lr=1e-4, lambda_cls=0.1, lambda_pos=0.1, clip_dim=1024, visual_len=197, visual_dim=768, visual_config=None, *args, **kwargs):
    """Build the label embedding, positional encoding and transformer encoder.

    A checkpoint is loaded non-strictly when ``ckpt_path`` is given. When
    ``trainable`` is set, the logit scale, the visual encoder and the four
    projection heads are created as well.

    NOTE(review): the extent of the ``if trainable:`` body is assumed to run
    through ``visual_head`` - confirm against the original indentation.
    """
    super().__init__(*args, **kwargs)

    self.max_len = max_len
    self.emd_dim = emb_dim
    self.n_heads = n_heads
    self.n_trans_layers = n_trans_layers
    self.character = string.printable[:-6]
    # One class per character plus one extra class.
    self.num_cls = len(self.character) + 1

    self.label_embedding = nn.Embedding(self.num_cls, self.emd_dim)
    self.pos_embedding = PositionalEncoding(d_model=self.emd_dim, max_len=self.max_len)
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=self.emd_dim, nhead=self.n_heads, batch_first=True
    )
    self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=self.n_trans_layers)

    if ckpt_path is not None:
        # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
        checkpoint_state = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        self.load_state_dict(checkpoint_state, strict=False)

    if trainable:

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.visual_encoder = instantiate_from_config(visual_config)

        self.learning_rate = lr
        self.clip_dim = clip_dim
        self.visual_len = visual_len
        self.visual_dim = visual_dim
        self.lambda_cls = lambda_cls
        self.lambda_pos = lambda_pos

        self.cls_head = nn.Sequential(
            nn.InstanceNorm1d(self.max_len),
            nn.Linear(self.emd_dim, self.emd_dim),
            nn.GELU(),
            nn.Linear(self.emd_dim, self.num_cls),
        )

        self.pos_head = nn.Sequential(
            nn.InstanceNorm1d(self.max_len),
            nn.Linear(self.emd_dim, self.max_len, bias=False),
        )

        self.text_head = nn.Sequential(
            nn.InstanceNorm1d(self.max_len),
            nn.Linear(self.emd_dim, self.clip_dim, bias=False),
            nn.Conv1d(in_channels=self.max_len, out_channels=1, kernel_size=1),
        )

        self.visual_head = nn.Sequential(
            nn.InstanceNorm1d(self.visual_len),
            nn.Linear(self.visual_dim, self.clip_dim, bias=False),
            nn.Conv1d(in_channels=self.visual_len, out_channels=1, kernel_size=1),
        )
|
|
|
def freeze(self):
    """Turn off gradient tracking for every parameter of this module."""
    for weight in self.parameters():
        weight.requires_grad_(False)
|
|
|
def get_index(self, labels):
    """Convert label strings into a padded tensor of character indices.

    Every character maps to its 1-based position in ``self.character``.
    Characters missing from that vocabulary map to 0, which is also the value
    used to right-pad each row up to ``self.max_len``.

    Args:
        labels: Iterable of strings, each at most ``self.max_len`` long.

    Returns:
        A ``torch.long`` tensor of shape ``(len(labels), self.max_len)``
        placed on the device of this module's first parameter.

    Raises:
        AssertionError: If any label is longer than ``self.max_len``.
    """
    rows = []
    for label in labels:
        if len(label) > self.max_len:
            # Raised explicitly so the check is not stripped under `python -O`;
            # the exception type stays AssertionError for existing callers.
            raise AssertionError(
                f"label {label!r} has {len(label)} characters, "
                f"which exceeds max_len={self.max_len}"
            )
        # str.find returns -1 for unknown characters, so "+ 1" sends them to 0.
        codes = [self.character.find(ch) + 1 for ch in label]
        codes.extend([0] * (self.max_len - len(codes)))
        rows.append(codes)

    device = next(self.parameters()).device
    # Explicit dtype and shape keep an empty batch well-formed: a bare
    # torch.tensor([]) would be a float tensor of shape (0,).
    indices = torch.tensor(rows, dtype=torch.long, device=device)
    return indices.view(len(rows), self.max_len)
|
|
|
def get_embeddings(self, x):
    """Run index input through the embedding, positional encoding and encoder.

    Args:
        x: Input accepted by ``self.label_embedding`` (presumably the index
            tensor produced by ``get_index`` — confirm against callers).

    Returns:
        The output of ``self.encoder``.
    """
    return self.encoder(self.pos_embedding(self.label_embedding(x)))
|
|
|
def forward(self, labels):
    """Encode a batch of label strings.

    Args:
        labels: Label strings, passed unchanged to ``get_index``.

    Returns:
        Whatever ``get_embeddings`` returns for the indexed labels.
    """
    return self.get_embeddings(self.get_index(labels))
|
|
|
def get_loss(self, text_out, visual_out, clip_target, cls_out, pos_out, cls_target, pos_target):
    """Compute the contrastive, character and position losses.

    Args:
        text_out: Text features, normalised here along dim 1.
        visual_out: Visual features, normalised here along dim 1.
        clip_target: Class targets for both directions of the contrastive loss.
        cls_out: Character logits with the class axis last.
        pos_out: Position logits with the class axis last.
        cls_target: Targets for ``cls_out``.
        pos_target: Targets for ``pos_out``.

    Returns:
        Tuple ``(clip_loss, cls_loss, pos_loss, logits_per_text)``.
    """
    cross_entropy = nn.functional.cross_entropy

    # Scale every feature row to unit L2 norm (out of place, inputs untouched).
    text_unit = text_out / text_out.norm(dim=1, keepdim=True)
    visual_unit = visual_out / visual_out.norm(dim=1, keepdim=True)

    # The stored parameter is a log-scale; exponentiate before use.
    scale = self.logit_scale.exp()
    logits_per_image = (scale * visual_unit) @ text_unit.T
    logits_per_text = logits_per_image.T

    # Symmetric contrastive loss: average of both matching directions.
    image_side = cross_entropy(logits_per_image, clip_target)
    text_side = cross_entropy(logits_per_text, clip_target)
    clip_loss = (image_side + text_side) / 2

    # cross_entropy wants the class axis at position 1, so swap the last two.
    cls_loss = cross_entropy(cls_out.permute(0, 2, 1), cls_target)
    pos_loss = cross_entropy(pos_out.permute(0, 2, 1), pos_target)

    return clip_loss, cls_loss, pos_loss, logits_per_text
|
|
|
def training_step(self, batch, batch_idx):
    """Run one optimisation step and log losses and accuracies.

    Args:
        batch: Mapping with a ``"text"`` entry (label strings) and an
            ``"image"`` entry (input for ``self.visual_encoder``).
        batch_idx: Index of the batch; not used here.

    Returns:
        ``clip_loss + lambda_cls * cls_loss + lambda_pos * pos_loss``.
    """

    def top1_accuracy(logits, target):
        # Fraction of entries whose highest-scoring class equals the target.
        predicted = torch.max(logits, dim=-1).indices
        return (predicted == target).to(dtype=torch.float32).mean()

    text, image = batch["text"], batch["image"]

    # Both towers.
    idx = self.get_index(text)
    text_emb = self.get_embeddings(idx)
    visual_emb = self.visual_encoder(image)

    # Heads; the two projection heads leave a singleton axis 1 that is dropped.
    cls_out = self.cls_head(text_emb)
    pos_out = self.pos_head(text_emb)
    text_out = self.text_head(text_emb).squeeze(1)
    visual_out = self.visual_head(visual_emb).squeeze(1)

    # Targets: characters predict themselves, slot i predicts i, and the i-th
    # text is matched with the i-th image. `.to(cls_target)` copies the dtype
    # and device of the index tensor.
    n_samples = idx.shape[0]
    cls_target = idx
    slot_ids = torch.arange(start=0, end=self.max_len, step=1)
    pos_target = slot_ids[None].tile((n_samples, 1)).to(cls_target)
    clip_target = torch.arange(0, n_samples, 1).to(cls_target)

    clip_loss, cls_loss, pos_loss, logits_per_text = self.get_loss(
        text_out, visual_out, clip_target, cls_out, pos_out, cls_target, pos_target
    )
    loss = clip_loss + self.lambda_cls * cls_loss + self.lambda_pos * pos_loss

    loss_dict = {
        "loss/clip_loss": clip_loss,
        "loss/cls_loss": cls_loss,
        "loss/pos_loss": pos_loss,
        "loss/full_loss": loss,
        "acc/clip_acc": top1_accuracy(logits_per_text, clip_target),
        "acc/cls_acc": top1_accuracy(cls_out, cls_target),
        "acc/pos_acc": top1_accuracy(pos_out, pos_target),
    }

    self.log_dict(
        loss_dict,
        prog_bar=True,
        batch_size=len(text),
        logger=True,
        on_step=True,
        on_epoch=True,
        sync_dist=True,
    )

    return loss
|
|
|
def configure_optimizers(self):
    """Create the optimizer for training.

    Returns:
        A ``torch.optim.AdamW`` over the parameters that currently require
        gradients, using ``self.learning_rate``.
    """
    # Parameters with requires_grad=False are left out of the optimizer.
    trainable_params = (p for p in self.parameters() if p.requires_grad)
    return torch.optim.AdamW(trainable_params, lr=self.learning_rate)
|
|
|
|
|
|