"""
Hugging Face compatible implementation of Open-MAGVIT2
Code reference: https://github.com/TencentARC/Open-MAGVIT2
"""

from math import log2, ceil
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, pack, unpack
from torch import einsum
from torch.nn import Module
from transformers import PreTrainedModel

from .configuration_lfq_tokenizer import LFQTokenizerConfig


def swish(x):
    """Return the swish activation ``x * sigmoid(x)``."""
    return x * torch.sigmoid(x)


class ResBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> swish -> 3x3 conv) twice.

    When ``in_filters != out_filters`` the skip path is projected with either a
    3x3 convolution (``use_conv_shortcut=True``) or a 1x1 convolution.
    Both norms use 32 groups, so ``nn.GroupNorm`` requires the channel counts
    to be divisible by 32.
    """

    def __init__(self, in_filters, out_filters, use_conv_shortcut=False) -> None:
        super().__init__()

        self.in_filters = in_filters
        self.out_filters = out_filters
        self.use_conv_shortcut = use_conv_shortcut

        self.norm1 = nn.GroupNorm(32, in_filters, eps=1e-6)
        self.norm2 = nn.GroupNorm(32, out_filters, eps=1e-6)

        self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)
        self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False)

        # A projection on the skip path is only needed when the channel count changes.
        if in_filters != out_filters:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(
                    in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False
                )
            else:
                self.nin_shortcut = nn.Conv2d(
                    in_filters, out_filters, kernel_size=(1, 1), padding=0, bias=False
                )

    def forward(self, x, **kwargs):
        """Apply the block to ``x``; extra keyword arguments are accepted and ignored."""
        residual = x

        x = self.conv1(swish(self.norm1(x)))
        x = self.conv2(swish(self.norm2(x)))

        if self.in_filters != self.out_filters:
            shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
            residual = shortcut(residual)

        return x + residual


class Encoder(nn.Module):
    """Convolutional encoder: stem conv, ``len(ch_mult)`` residual stages, mid blocks, 1x1 head.

    Every stage except the last ends with a stride-2 convolution, so the spatial
    size is reduced ``len(ch_mult) - 1`` times.

    Args (keyword-only):
        ch: base channel count; stage ``i`` uses ``ch * ch_mult[i]`` channels.
        out_ch: accepted so the encoder and decoder can share one config dict;
            not used by the encoder.
        in_channels: channels of the input passed to ``conv_in``.
        num_res_blocks: residual blocks per stage, and also in the mid section.
        z_channels: channels of the returned latent.
        ch_mult: per-stage channel multipliers.
    """

    def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4)):
        super().__init__()

        self.in_channels = in_channels
        self.z_channels = z_channels

        self.num_res_blocks = num_res_blocks
        self.num_blocks = len(ch_mult)

        self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=(3, 3), padding=1, bias=False)

        # Downsampling stages. Stage i maps ch * in_ch_mult[i] -> ch * ch_mult[i].
        self.down = nn.ModuleList()
        in_ch_mult = (1,) + tuple(ch_mult)
        for i_level in range(self.num_blocks):
            block = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResBlock(block_in, block_out))
                block_in = block_out

            down = nn.Module()
            down.block = block
            if i_level < self.num_blocks - 1:
                down.downsample = nn.Conv2d(
                    block_out, block_out, kernel_size=(3, 3), stride=(2, 2), padding=1
                )
            self.down.append(down)

        # Mid section at the final channel width.
        self.mid_block = nn.ModuleList()
        for _ in range(self.num_res_blocks):
            self.mid_block.append(ResBlock(block_in, block_in))

        # Output head.
        self.norm_out = nn.GroupNorm(32, block_out, eps=1e-6)
        self.conv_out = nn.Conv2d(block_out, z_channels, kernel_size=(1, 1))

    def forward(self, x):
        """Encode ``x`` and return the ``z_channels``-channel latent."""
        x = self.conv_in(x)

        for i_level, level in enumerate(self.down):
            for res_block in level.block:
                x = res_block(x)
            if i_level < self.num_blocks - 1:
                x = level.downsample(x)

        for res_block in self.mid_block:
            x = res_block(x)

        x = self.norm_out(x)
        x = swish(x)
        x = self.conv_out(x)
        return x


class Decoder(nn.Module):
    """Convolutional decoder mirroring :class:`Encoder`.

    Runs a stem conv, the mid blocks, then the stages from the widest
    (``ch_mult[-1]``) down to ``ch_mult[0]``; every stage except stage 0 is
    followed by an :class:`Upsampler`.

    Args (keyword-only):
        ch: base channel count.
        out_ch: channels of the decoded output.
        in_channels: stored on the module; not otherwise used by the decoder.
        num_res_blocks: residual blocks per stage, and also in the mid section.
        z_channels: channels of the latent passed to ``forward``.
        ch_mult: per-stage channel multipliers.
    """

    def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4)) -> None:
        super().__init__()

        self.ch = ch
        self.num_blocks = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels

        block_in = ch * ch_mult[self.num_blocks - 1]

        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=(3, 3), padding=1, bias=True)

        self.mid_block = nn.ModuleList()
        for _ in range(self.num_res_blocks):
            self.mid_block.append(ResBlock(block_in, block_in))

        # Stages are built from the deepest level upwards but inserted at the
        # front, so that self.up[i] holds the modules of level i.
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_blocks)):
            block = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResBlock(block_in, block_out))
                block_in = block_out

            up = nn.Module()
            up.block = block
            if i_level > 0:
                up.upsample = Upsampler(block_in)
            self.up.insert(0, up)

        self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=(3, 3), padding=1)

    def forward(self, z):
        """Decode the latent ``z`` and return the ``out_ch``-channel output."""
        z = self.conv_in(z)

        for res_block in self.mid_block:
            z = res_block(z)

        for i_level in reversed(range(self.num_blocks)):
            level = self.up[i_level]
            for res_block in level.block:
                z = res_block(z)
            if i_level > 0:
                z = level.upsample(z)

        z = self.norm_out(z)
        z = swish(z)
        z = self.conv_out(z)
        return z


def depth_to_space(x: torch.Tensor, block_size: int) -> torch.Tensor:
    """Depth-to-Space DCR mode (depth-column-row) core implementation.

    Args:
        x (torch.Tensor): input tensor. The channels-first (*CHW) layout is supported.
        block_size (int): block side size

    Returns:
        Tensor of shape ``(*outer, C // block_size**2, H * block_size, W * block_size)``.

    Raises:
        ValueError: if ``x`` has fewer than 3 dimensions or ``C`` is not
            divisible by ``block_size ** 2``.
    """
    # check inputs
    if x.dim() < 3:
        raise ValueError(
            "Expecting a channels-first (*CHW) tensor of at least 3 dimensions"
        )
    c, h, w = x.shape[-3:]

    s = block_size ** 2
    if c % s != 0:
        raise ValueError(
            f"Expecting a channels-first (*CHW) tensor with C divisible by {s}, but got C={c} channels"
        )

    outer_dims = x.shape[:-3]

    # splitting two additional dimensions from the channel dimension
    x = x.view(-1, block_size, block_size, c // s, h, w)

    # putting the two new dimensions along H and W
    x = x.permute(0, 3, 4, 1, 5, 2)

    # merging the two new dimensions with H and W
    x = x.contiguous().view(*outer_dims, c // s, h * block_size, w * block_size)

    return x


class Upsampler(nn.Module):
    """2x spatial upsampling: 3x3 conv to ``4 * dim`` channels, then depth-to-space.

    The output has ``dim`` channels again because depth-to-space with
    ``block_size=2`` divides the channel count by 4.

    NOTE: ``dim_out`` is accepted for signature compatibility but its value is
    ignored; the convolution always produces ``dim * 4`` channels.
    """

    def __init__(self, dim, dim_out=None):
        super().__init__()
        dim_out = dim * 4
        self.conv1 = nn.Conv2d(dim, dim_out, (3, 3), padding=1)
        self.depth2space = depth_to_space

    def forward(self, x):
        """
        input_image: [B C H W]
        """
        out = self.conv1(x)
        out = self.depth2space(out, block_size=2)
        return out


class AdaptiveGroupNorm(nn.Module):
    """Group normalisation whose scale and bias are predicted from a second tensor.

    ``x`` is normalised by a non-affine ``GroupNorm``; the per-channel scale and
    bias come from two linear layers fed with the spatial standard deviation
    and the spatial mean of ``quantizer``.

    Args:
        z_channel: channel count of the conditioning tensor (input features of
            both linear layers).
        in_filters: channel count of ``x``.
        num_groups: number of groups for the normalisation (default 32).
        eps: epsilon used by the group norm and added to the variance before
            the square root.
    """

    def __init__(self, z_channel, in_filters, num_groups=32, eps=1e-6):
        super().__init__()
        # Honour ``num_groups``; it used to be silently replaced by a hard-coded 32.
        self.gn = nn.GroupNorm(num_groups=num_groups, num_channels=in_filters, eps=eps, affine=False)
        self.gamma = nn.Linear(z_channel, in_filters)
        self.beta = nn.Linear(z_channel, in_filters)
        self.eps = eps

    def forward(self, x, quantizer):
        """Normalise ``x`` (b c h w) and modulate it with statistics of ``quantizer`` (b c h w)."""
        B, C, _, _ = x.shape

        flat = rearrange(quantizer, "b c h w -> b c (h w)")

        # Scale: per-channel spatial standard deviation.
        # NOTE: Tensor.var applies Bessel's correction by default, so this is
        # the *unbiased* variance estimate.
        scale = flat.var(dim=-1) + self.eps
        scale = scale.sqrt()
        scale = self.gamma(scale).view(B, C, 1, 1)

        # Bias: per-channel spatial mean.
        bias = flat.mean(dim=-1)
        bias = self.beta(bias).view(B, C, 1, 1)

        x = self.gn(x)
        x = scale * x + bias

        return x


# constants

LossBreakdown = namedtuple(
    'LossBreakdown', ['per_sample_entropy', 'codebook_entropy', 'commitment', 'avg_probs']
)


# helper functions

def exists(v):
    """Return True when ``v`` is not None."""
    return v is not None


def default(*args):
    """Return the first non-None argument, calling it first if it is callable."""
    for arg in args:
        if exists(arg):
            return arg() if callable(arg) else arg
    return None


def pack_one(t, pattern):
    """Pack a single tensor with einops; returns ``(packed, packed_shapes)``."""
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    """Inverse of :func:`pack_one`: unpack and return the single tensor."""
    return unpack(t, ps, pattern)[0]


# entropy

def entropy(prob):
    """Entropy over the last dimension, with 1e-5 added inside the log for stability."""
    return (-prob * torch.log(prob + 1e-5)).sum(dim=-1)


def mult_along_first_dims(x, y):
    """
    returns x * y elementwise along the leading dimensions of y
    """
    # Append trailing singleton dims to y so it broadcasts against x.
    ndim_to_expand = x.ndim - y.ndim
    for _ in range(ndim_to_expand):
        y = y.unsqueeze(-1)
    return x * y


def masked_mean(x, m):
    """
    takes the mean of the elements of x that are not masked
    the mean is taken along the shared leading dims of m
    equivalent to: x[m].mean(tuple(range(m.ndim)))

    The benefit of using masked_mean rather than using
    tensor indexing is that masked_mean is much faster
    for torch-compile on batches.

    The drawback is larger floating point errors
    """
    x = mult_along_first_dims(x, m)
    x = x / m.sum()
    return x.sum(tuple(range(m.ndim)))


def entropy_loss(
    logits,
    mask=None,
    temperature=0.01,
    sample_minimization_weight=1.0,
    batch_maximization_weight=1.0,
    eps=1e-5,
):
    """
    Entropy loss of unnormalized logits

    logits: Affinities are over the last dimension

    https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279
    LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024)

    Returns:
        ``(sample_entropy, avg_entropy, loss)`` where
        ``loss = sample_minimization_weight * sample_entropy
        - batch_maximization_weight * avg_entropy``.
    """
    probs = F.softmax(logits / temperature, -1)
    # NOTE: adding the scalar eps to every logit does not change log_softmax
    # mathematically (it is shift invariant); kept as in the reference code.
    log_probs = F.log_softmax(logits / temperature + eps, -1)

    if mask is not None:
        avg_probs = masked_mean(probs, mask)
    else:
        avg_probs = reduce(probs, "... D -> D", "mean")

    # Entropy of the averaged code distribution (to be maximised).
    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps))

    # Mean entropy of each individual distribution (to be minimised).
    sample_entropy = -torch.sum(probs * log_probs, -1)
    if mask is not None:
        sample_entropy = masked_mean(sample_entropy, mask).mean()
    else:
        sample_entropy = torch.mean(sample_entropy)

    loss = (sample_minimization_weight * sample_entropy) - (
        batch_maximization_weight * avg_entropy
    )

    return sample_entropy, avg_entropy, loss


class LFQ(Module):
    """Lookup-free quantizer: every feature is binarised to -1 / +1 by its sign.

    Args (keyword-only):
        dim: feature dimension of the input; defaults to
            ``log2(codebook_size) * num_codebooks``.
        codebook_size: number of codes, must be a power of two; defaults to
            ``2 ** dim``. At least one of ``dim`` / ``codebook_size`` is required.
        num_codebooks: number of independent codebooks the feature dim is split into.
        sample_minimization_weight: weight of the per-sample entropy term.
        batch_maximization_weight: weight of the averaged (codebook) entropy term.
        token_factorization: when True, each code is split into two halves of
            ``codebook_dim // 2`` bits and two index tensors are returned.
    """

    def __init__(
        self,
        *,
        dim=None,
        codebook_size=None,
        num_codebooks=1,
        sample_minimization_weight=1.0,
        batch_maximization_weight=1.0,
        token_factorization=False,
    ):
        super().__init__()

        # some assert validations

        assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
        assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'

        # Resolve the codebook size *before* using it, so that constructing
        # with only ``dim`` works instead of failing on ``log2(None)``.
        codebook_size = default(codebook_size, lambda: 2 ** dim)
        self.codebook_size = codebook_size
        self.codebook_dim = int(log2(codebook_size))

        codebook_dims = self.codebook_dim * num_codebooks
        dim = default(dim, codebook_dims)

        has_projections = dim != codebook_dims
        self.has_projections = has_projections

        self.dim = dim
        self.num_codebooks = num_codebooks

        # for entropy loss
        self.sample_minimization_weight = sample_minimization_weight
        self.batch_maximization_weight = batch_maximization_weight

        # for no auxiliary loss, during inference
        self.token_factorization = token_factorization  # only utilized in second stage
        # Bit weights, most-significant bit first, used by forward() to turn
        # the sign pattern into integer indices.
        if not self.token_factorization:  # for first stage model
            self.register_buffer('mask', 2 ** torch.arange(self.codebook_dim - 1, -1, -1), persistent=False)
        else:
            k = self.codebook_dim // 2
            self.register_buffer("mask", 2 ** torch.arange(k - 1, -1, -1), persistent=False)

        self.register_buffer('zero', torch.tensor(0.), persistent=False)

        # codes: every index expanded to its -1 / +1 bit pattern

        all_codes = torch.arange(codebook_size)
        bits = self.indices_to_bits(all_codes)
        codebook = bits * 2.0 - 1.0

        self.register_buffer('codebook', codebook, persistent=False)

    @property
    def dtype(self):
        """dtype of the codebook buffer."""
        return self.codebook.dtype

    def indices_to_bits(self, x):
        """
        x: long tensor of indices for constructing codebook, but actually not utilized in all the experiments.

        returns a bool tensor with a trailing bit dimension of size
        ``codebook_dim``; position ``i`` holds bit ``2 ** i`` (least-significant
        bit first).

        NOTE(review): this is the opposite bit order from ``self.mask`` used in
        ``forward`` and from ``get_codebook_entry`` (most-significant bit
        first) - confirm before mixing the two index conventions.
        """
        mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long)
        x = (x.unsqueeze(-1) & mask) != 0
        return x

    def get_codebook_entry(self, x, bhwc):
        """Turn indices back into a -1 / +1 feature map.

        x: long tensor of indices, expected as ``(b, h * w)``.
        bhwc: ``(b, h, w, c)`` target shape; ``c`` must equal the number of bits
            per index (``codebook_dim``, or ``codebook_dim // 2`` with token
            factorization).

        returns a float tensor of shape ``(b, c, h, w)``. Bits are read
        most-significant first, matching the indices produced by ``forward``.
        """
        if self.token_factorization:
            k = self.codebook_dim // 2
            mask = 2 ** torch.arange(k - 1, -1, -1, device=x.device, dtype=torch.long)
        else:
            mask = 2 ** torch.arange(self.codebook_dim - 1, -1, -1, device=x.device, dtype=torch.long)

        x = (x.unsqueeze(-1) & mask) != 0
        x = x * 2.0 - 1.0  # back to the float

        # scale back to the desired shape
        b, h, w, c = bhwc
        x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c)
        x = rearrange(x, "b h w c -> b c h w")
        return x

    def bits_to_indices(self, bits):
        """
        bits: bool tensor of bits, where the last dimension is the bit
            dimension and position ``i`` carries weight ``2 ** i``
            (the layout produced by ``indices_to_bits``)

        returns indices, which are long integers from 0 to self.codebook_size
        """
        assert bits.shape[-1] == self.codebook_dim
        indices = 2 ** torch.arange(
            0,
            self.codebook_dim,
            1,
            dtype=torch.long,
            device=bits.device,
        )
        return (bits * indices).sum(-1)

    def decode(self, x):
        """
        x: ... NH
            where NH is number of codebook heads
            A longtensor of codebook indices, containing values from 0 to self.codebook_size

        returns the -1 / +1 codes with the head and bit dimensions merged.
        Uses the ``indices_to_bits`` bit order (least-significant bit first).
        """
        x = self.indices_to_bits(x)
        # to some sort of float
        x = x.to(self.dtype)
        # -1 or 1
        x = x * 2 - 1
        x = rearrange(x, "... NC Z-> ... (NC Z)")
        return x

    def forward(
        self,
        x,
        return_loss_breakdown=False,
        mask=None,
        return_loss=True,
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebook dim

        Returns ``(quantized, entropy_aux_loss, indices)``; when
        ``return_loss_breakdown`` is True, returns that tuple together with a
        ``LossBreakdown``. ``indices`` is a flattened long tensor, or a pair of
        flattened tensors ``(indices_pre, indices_post)`` with token
        factorization. ``mask`` only selects the entries that contribute to the
        commitment loss.
        """

        # channels-first -> channels-last, then flatten all spatial dims into n
        x = rearrange(x, 'b d ... -> b ... d')
        x, ps = pack_one(x, 'b * d')

        # split out number of codebooks
        x = rearrange(x, 'b n (c d) -> b n c d', c=self.num_codebooks)

        # sign quantisation: strictly positive -> +1, everything else -> -1
        codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype)
        quantized = torch.where(x > 0, codebook_value, -codebook_value)

        # calculate indices
        if self.token_factorization:
            k = self.codebook_dim // 2
            indices_pre = reduce((quantized[..., :k] > 0).int() * self.mask.int(), "b n c d -> b n c", "sum")
            indices_post = reduce((quantized[..., k:] > 0).int() * self.mask.int(), "b n c d -> b n c", "sum")
        else:
            indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')

        # entropy aux loss

        if self.training and return_loss:
            # the same as euclidean distance up to a constant
            logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook)

            per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss(
                logits=logits,
                sample_minimization_weight=self.sample_minimization_weight,
                batch_maximization_weight=self.batch_maximization_weight,
            )

            avg_probs = self.zero
        else:
            # if not training, just return dummy 0
            # (codebook usage statistics are not computed on this path)
            per_sample_entropy = codebook_entropy = self.zero
            entropy_aux_loss = self.zero
            avg_probs = self.zero

        # commit loss

        if self.training:
            commit_loss = F.mse_loss(x, quantized.detach(), reduction='none')

            if exists(mask):
                commit_loss = commit_loss[mask]

            commit_loss = commit_loss.mean()
        else:
            commit_loss = self.zero

        # straight-through estimator: forward value is `quantized`, gradient flows to x
        quantized = x + (quantized - x).detach()

        # merge back codebook dim
        quantized = rearrange(quantized, 'b n c d -> b n (c d)')

        # reconstitute image or video dimensions
        quantized = unpack_one(quantized, ps, 'b * d')
        quantized = rearrange(quantized, 'b ... d -> b d ...')

        if self.token_factorization:
            indices_pre = unpack_one(indices_pre, ps, "b * c")
            indices_post = unpack_one(indices_post, ps, "b * c")
            indices_pre = indices_pre.flatten()
            indices_post = indices_post.flatten()
            indices = (indices_pre, indices_post)
        else:
            indices = unpack_one(indices, ps, 'b * c')
            indices = indices.flatten()

        ret = (quantized, entropy_aux_loss, indices)

        if not return_loss_breakdown:
            return ret

        return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss, avg_probs)


class LFQTokenizer(PreTrainedModel):
    """Image tokenizer: :class:`Encoder` -> :class:`LFQ` -> :class:`Decoder`.

    The encoder and decoder are both built from ``config.encoder_decoder_config``
    and the quantizer from ``config.quantizer_config``.
    """

    config_class = LFQTokenizerConfig

    def __init__(self, config: LFQTokenizerConfig):
        super().__init__(config)

        self.encoder = Encoder(**config.encoder_decoder_config)
        self.decoder = Decoder(**config.encoder_decoder_config)
        self.quantize = LFQ(**config.quantizer_config)

    def encode(self, x):
        """Encode and quantize ``x``.

        Returns ``(quant, emb_loss, info, loss_breakdown)`` where ``info`` is
        the index output of the quantizer.
        """
        h = self.encoder(x)
        (quant, emb_loss, info), loss_breakdown = self.quantize(h, return_loss_breakdown=True)
        return quant, emb_loss, info, loss_breakdown

    def decode(self, quant):
        """Decode a quantized latent with the decoder."""
        return self.decoder(quant)

    def forward(self, input):
        """Reconstruct ``input``; returns ``(reconstruction, quantizer_loss, loss_breakdown)``."""
        quant, diff, _, loss_breakdown = self.encode(input)
        dec = self.decoder(quant)
        return dec, diff, loss_breakdown

    def tokenize(self, input):
        """Return only the token indices produced by :meth:`encode`."""
        _, _, tokens, _ = self.encode(input)
        return tokens

    def get_last_layer(self):
        """Return the weight of the decoder's final convolution."""
        return self.decoder.conv_out.weight

    def decode_tokens(self, tokens, shape: tuple):
        """Decode token indices back to the decoder output.

        tokens: a long tensor of indices, or with token factorization a pair
            ``(tokens_pre, tokens_post)``. Flattened 1-D tensors (as returned
            by :meth:`tokenize`) are reshaped to ``(shape[0], -1)``.
        shape: ``(b, h, w, c)`` passed to ``LFQ.get_codebook_entry``.
            NOTE(review): with token factorization ``c`` has to be the
            per-half bit count for the reshape to succeed - confirm with callers.
        """
        batch_size = shape[0]

        if self.quantize.token_factorization:
            tokens_pre, tokens_post = tokens[0], tokens[1]
            # tokenize() returns flattened indices; restore the batch dimension
            # exactly as the non-factorized branch does.
            if tokens_pre.ndim == 1:
                tokens_pre = tokens_pre.view(batch_size, -1)
            if tokens_post.ndim == 1:
                tokens_post = tokens_post.view(batch_size, -1)
            quant_pre = self.quantize.get_codebook_entry(tokens_pre, shape)
            quant_post = self.quantize.get_codebook_entry(tokens_post, shape)
            quant = torch.concat([quant_pre, quant_post], dim=1)
            return self.decode(quant)
        else:
            if tokens.ndim == 1:
                tokens = tokens.view(batch_size, -1)
            quant = self.quantize.get_codebook_entry(tokens, shape)
            return self.decode(quant)