import pdb
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import utils.dist_adapter as dist
import sys
[sys.path.append(i) for i in ['.', '..']]
from utils.torch_utils import parse_args
args = parse_args()
mydevice = t.device('cuda:' + args.gpu)
class BottleneckBlock(nn.Module):
    def __init__(self, k_bins, emb_width, mu):
        super().__init__()
        self.k_bins = k_bins
        self.emb_width = emb_width
        self.mu = mu
        self.reset_k()
        self.threshold = 1.0
    def reset_k(self):
        self.init = False
        self.k_sum = None
        self.k_elem = None
        self.register_buffer('k', t.zeros(self.k_bins, self.emb_width).cuda())
    def _tile(self, x):
        d, ew = x.shape  # 960, 512
        if d < self.k_bins:
            n_repeats = (self.k_bins + d - 1) // d
            std = 0.01 / np.sqrt(ew)
            x = x.repeat(n_repeats, 1)
            x = x + t.randn_like(x) * std
        return x
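    # _tile pads x when there are fewer input vectors than codebook entries:
    # the rows are repeated (with a little Gaussian noise so the repeats are
    # not identical) until at least k_bins rows are available to sample from.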
    def init_k(self, x):
        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins  # mu=0.99, emb_width=512, k_bins=512
        self.init = True
        # init k_w using random vectors from x
        y = self._tile(x)
        _k_rand = y[t.randperm(y.shape[0])][:k_bins]  # (512, 512): shuffle the rows of y and keep the first k_bins
        # dist.broadcast(_k_rand, 0)
        self.k = _k_rand
        assert self.k.shape == (k_bins, emb_width)
        self.k_sum = self.k
        self.k_elem = t.ones(k_bins, device=self.k.device)
    def restore_k(self, num_tokens=None, threshold=1.0):
        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
        self.init = True
        assert self.k.shape == (k_bins, emb_width)
        self.k_sum = self.k.clone()
        self.k_elem = t.ones(k_bins, device=self.k.device)
        if num_tokens is not None:
            expected_usage = num_tokens / k_bins
            self.k_elem.data.mul_(expected_usage)
            self.k_sum.data.mul_(expected_usage)
        self.threshold = threshold
    def update_k(self, x, x_l):  # x: (960, 512), x_l: (960)
        mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins  # mu=0.99, emb_width=512, k_bins=512
        with t.no_grad():
            # Calculate new centres
            x_l_onehot = t.zeros(k_bins, x.shape[0], device=x.device)  # (512(k_bins), 960(N * L))
            x_l_onehot.scatter_(0, x_l.view(1, x.shape[0]), 1)  # one-hot codes: put a 1 in row x_l[i] of column i -> (512, 960)
            _k_sum = t.matmul(x_l_onehot, x)  # (512(k_bins), 512(w)): sum of the vectors assigned to each bin
            _k_elem = x_l_onehot.sum(dim=-1)  # (512(k_bins)): number of vectors assigned to each bin
            y = self._tile(x)  # (960, 512)
            _k_rand = y[t.randperm(y.shape[0])][:k_bins]  # (512, 512)
            # dist.broadcast(_k_rand, 0)
            # dist.all_reduce(_k_sum)
            # dist.all_reduce(_k_elem)
            # Update centres
            old_k = self.k
            self.k_sum = mu * self.k_sum + (1. - mu) * _k_sum  # (k_bins, w)
            self.k_elem = mu * self.k_elem + (1. - mu) * _k_elem  # (k_bins)
            usage = (self.k_elem.view(k_bins, 1) >= self.threshold).float()
            self.k = usage * (self.k_sum.view(k_bins, emb_width) / self.k_elem.view(k_bins, 1)) \
                     + (1 - usage) * _k_rand
            _k_prob = _k_elem / t.sum(_k_elem)  # x_l_onehot.mean(dim=-1): probability of each bin
            entropy = -t.sum(_k_prob * t.log(_k_prob + 1e-8))  # entropy, i.e. how diverse the code usage is
            used_curr = (_k_elem >= self.threshold).sum()
            usage = t.sum(usage)
            dk = t.norm(self.k - old_k) / np.sqrt(np.prod(old_k.shape))
        return dict(entropy=entropy,
                    used_curr=used_curr,
                    usage=usage,
                    dk=dk)
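    # The update above is an exponential-moving-average (EMA) k-means step:
    #   k_sum_j  <- mu * k_sum_j  + (1 - mu) * sum of vectors assigned to bin j
    #   k_elem_j <- mu * k_elem_j + (1 - mu) * count of vectors assigned to bin j
    #   k_j      <- k_sum_j / k_elem_j        if k_elem_j >= threshold
    #   k_j      <- a random input vector     otherwise (restart of a dead code)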
    def preprocess(self, x):
        # NCT -> NTC -> [NT, C]
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(-1, x.shape[-1])  # x_en = (N * L, w), k_j = (w, k_bins)
        if x.shape[-1] == self.emb_width:
            prenorm = t.norm(x - t.mean(x)) / np.sqrt(np.prod(x.shape))  # np.prod(x.shape): total number of elements
        elif x.shape[-1] == 2 * self.emb_width:
            x1, x2 = x[..., :self.emb_width], x[..., self.emb_width:]
            prenorm = (t.norm(x1 - t.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (t.norm(x2 - t.mean(x2)) / np.sqrt(np.prod(x2.shape)))
            # Normalise
            x = x1 + x2
        else:
            assert False, f"Expected {x.shape[-1]} to be (1 or 2) * {self.emb_width}"
        return x, prenorm
    def postprocess(self, x_l, x_d, x_shape):
        # [NT, C] -> NTC -> NCT
        N, T = x_shape
        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()
        x_l = x_l.view(N, T)
        return x_l, x_d
    def quantise(self, x):
        # Calculate latent code x_l
        k_w = self.k.t()  # (512, 512)
        distance = t.sum(x ** 2, dim=-1, keepdim=True) - 2 * t.matmul(x, k_w) + t.sum(k_w ** 2, dim=0, keepdim=True)  # (960(N * L), 512(b))
        min_distance, x_l = t.min(distance, dim=-1)  # (960), (960)
        fit = t.mean(min_distance)
        return x_l, fit
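    # The distance above is the squared Euclidean distance to every codebook
    # vector, expanded so it can be computed with a single matmul:
    #   ||x_i - k_j||^2 = ||x_i||^2 - 2 * x_i . k_j + ||k_j||^2
    # x_l holds, for each input row, the index of the nearest codebook vector.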
    def dequantise(self, x_l):
        x = F.embedding(x_l, self.k)  # look up each index in the (k_bins, emb_width) codebook self.k
        return x
    def encode(self, x):
        N, width, T = x.shape
        # Preprocess.
        x, prenorm = self.preprocess(x)
        # Quantise
        x_l, fit = self.quantise(x)
        # Postprocess.
        x_l = x_l.view(N, T)
        return x_l
    def decode(self, x_l):
        N, T = x_l.shape
        width = self.emb_width
        # Dequantise
        x_d = self.dequantise(x_l)
        # Postprocess
        x_d = x_d.view(N, T, width).permute(0, 2, 1).contiguous()
        return x_d
    def forward(self, x, update_k=True):
        N, width, T = x.shape  # 32, 512, 30
        # Preprocess
        x, prenorm = self.preprocess(x)  # (960, 512), 0.2888
        # Init k if not inited
        if update_k and not self.init:
            self.init_k(x)
        # Quantise and dequantise through bottleneck
        x_l, fit = self.quantise(x)  # (960), 34.1081
        x_d = self.dequantise(x_l)  # (960, 512)
        # Update embeddings
        if update_k:
            update_metrics = self.update_k(x, x_l)
        else:
            update_metrics = {}
        # Loss
        commit_loss = t.norm(x_d.detach() - x) ** 2 / np.prod(x.shape)  # squared-error commitment loss, averaged over all elements
        # Passthrough
        x_d = x + (x_d - x).detach()
        # Postprocess
        x_l, x_d = self.postprocess(x_l, x_d, (N, T))
        return x_l, x_d, commit_loss, dict(fit=fit,
                                           pn=prenorm,
                                           **update_metrics)
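# Note on the forward pass above: `x_d = x + (x_d - x).detach()` is the
# straight-through estimator, so the decoder's gradient flows back into the
# encoder as if quantisation were the identity; the commitment loss pulls
# encoder outputs towards their assigned codes, while the codebook itself is
# only moved by the EMA update in update_k, never by backprop.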
class Bottleneck(nn.Module):
    def __init__(self, l_bins, emb_width, mu, levels):
        super().__init__()
        self.levels = levels
        level_block = lambda level: BottleneckBlock(l_bins, emb_width, mu)
        self.level_blocks = nn.ModuleList()
        for level in range(self.levels):
            self.level_blocks.append(level_block(level))
    def encode(self, xs):
        zs = [level_block.encode(x) for (level_block, x) in zip(self.level_blocks, xs)]
        return zs
    def decode(self, zs, start_level=0, end_level=None):
        if end_level is None:
            end_level = self.levels
        xs_quantised = [level_block.decode(z) for (level_block, z) in zip(self.level_blocks[start_level:end_level], zs)]
        return xs_quantised
    def forward(self, xs):
        zs, xs_quantised, commit_losses, metrics = [], [], [], []
        for level in range(self.levels):
            level_block = self.level_blocks[level]
            x = xs[level]  # (32, 512, 30)
            z, x_quantised, commit_loss, metric = level_block(x, update_k=self.training)
            '''
            z: (32, 30)
            x_quantised: (32, 512, 30)
            commit_loss: 0.0666
            metric: same as models/vqvae.py
            '''
            zs.append(z)
            if not self.training:
                # Be extra paranoid and make sure the encoder weights can't
                # change from straight-through estimator
                x_quantised = x_quantised.detach()
            xs_quantised.append(x_quantised)
            commit_losses.append(commit_loss)
            if self.training:
                metrics.append(metric)
        return zs, xs_quantised, commit_losses, metrics
class Residual_Bottleneck(nn.Module):
    def __init__(self, l_bins, emb_width, mu, levels):
        super().__init__()
        self.levels = levels
        self.residuals = 4
        level_block = lambda level: BottleneckBlock(l_bins, emb_width, mu)
        self.level_blocks = nn.ModuleList()
        for level in range(self.levels):
            self.level_blocks.append(level_block(level))
        self.residual_blocks = nn.ModuleList()
        for residual in range(self.residuals):
            self.residual_blocks.append(level_block(residual))
    def encode(self, xs):
        zs = [level_block.encode(x) for (level_block, x) in zip(self.level_blocks, xs)]
        return zs
    def decode(self, zs, start_level=0, end_level=None):
        if end_level is None:
            end_level = self.levels
        xs_quantised = [level_block.decode(z) for (level_block, z) in zip(self.level_blocks[start_level:end_level], zs)]
        return xs_quantised
    def forward(self, xs):
        zs, xs_quantised, commit_losses, metrics = [], [], [], []
        for level in range(self.levels):
            level_block = self.level_blocks[level]
            x = xs[level]  # (32, 512, 30)
            # Quantise with the level block first, then let each residual block
            # quantise whatever the previous quantisers failed to represent.
            z, x_quantised, commit_loss, metric = level_block(x, update_k=self.training)
            residual = x - x_quantised.detach()
            quantized_out = x_quantised
            for residual_num in range(self.residuals):
                residual_block = self.residual_blocks[residual_num]
                _, r_quantised, r_commit_loss, _ = residual_block(residual, update_k=self.training)
                residual = residual - r_quantised.detach()
                quantized_out = quantized_out + r_quantised
                commit_loss = commit_loss + r_commit_loss
            x_quantised = quantized_out
            '''
            z: (32, 30)
            x_quantised: (32, 512, 30)
            commit_loss: 0.0666
            metric: same as models/vqvae.py
            '''
            zs.append(z)
            if not self.training:
                # Be extra paranoid and make sure the encoder weights can't
                # change from straight-through estimator
                x_quantised = x_quantised.detach()
            xs_quantised.append(x_quantised)
            commit_losses.append(commit_loss)
            if self.training:
                metrics.append(metric)
        return zs, xs_quantised, commit_losses, metrics
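# Note on Residual_Bottleneck: this is a residual-VQ layout, where each extra
# quantiser encodes what the previous ones could not represent, so the output
# is a sum of codewords:
#   x ~= q_0(x) + q_1(r_1) + ... + q_R(r_R),  with  r_1 = x - q_0(x)  and
#   r_{i+1} = r_i - q_i(r_i).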
class NoBottleneckBlock(nn.Module):
    def restore_k(self):
        pass
class NoBottleneck(nn.Module):
    def __init__(self, levels):
        super().__init__()
        self.level_blocks = nn.ModuleList()
        self.levels = levels
        for level in range(levels):
            self.level_blocks.append(NoBottleneckBlock())
    def encode(self, xs):
        return xs
    def decode(self, zs, start_level=0, end_level=None):
        if end_level is None:
            end_level = self.levels
        return zs
    def forward(self, xs):
        zero = t.zeros(()).cuda()
        commit_losses = [zero for _ in range(self.levels)]
        metrics = [dict(entropy=zero, usage=zero, used_curr=zero, pn=zero, dk=zero) for _ in range(self.levels)]
        return xs, xs, commit_losses, metrics
if __name__ == '__main__':
    '''
    python -m models.bottleneck --config configs/sep_vqvae.yaml --train --no_cuda 2 --gpu 2
    '''
    # x = [t.rand(32, 512, 30)]
    # bottleneck = Bottleneck(512, 512, 0.99, 1).to(mydevice)
    # zs, xs_quantised, commit_losses, quantiser_metrics = bottleneck(x)
    x = t.rand(32, 512, 30, device=mydevice)
    model = BottleneckBlock(k_bins=512, emb_width=512, mu=0.99).to(mydevice)
    z, x_quantised, commit_loss, quantiser_metrics = model(x)
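    # A quick extra check (a sketch; it assumes the forward pass above has
    # already initialised the codebook via init_k): encode/decode should
    # round-trip the expected shapes.
    codes = model.encode(x)        # (32, 30) integer codebook indices
    decoded = model.decode(codes)  # (32, 512, 30) quantised embeddings
    print(z.shape, x_quantised.shape, codes.shape, decoded.shape)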