"""Hierarchical VQ-VAE over pose sequences, with encoder/decoder-only variants.

All four public model classes share one construction routine (see
``_VQVAEBase``), so their sub-module names (``encoders``, ``decoders``,
``bottleneck``) and therefore their state-dict keys are identical.
"""
import pdb
import sys

# Make the project-local packages importable regardless of the working dir.
for _extra_path in ['.', '..']:
    sys.path.append(_extra_path)
sys.path.append("./models/qp_vqvae")
sys.path.append("./models/qp_vqvae/utils")

import numpy as np
import torch as t
import torch.nn as nn
from .qp_vqvae.encdec import Encoder, Decoder, assert_shape
from .qp_vqvae.bottleneck import NoBottleneck, Bottleneck
from .qp_vqvae.utils.logger import average_metrics
from .qp_vqvae.utils.torch_utils import parse_args
import torch.nn.functional as F

# NOTE: the command line is parsed at import time; ``args.gpu`` is expected to
# be a string because it is concatenated to 'cuda:'.
args = parse_args()
mydevice = t.device('cuda:' + args.gpu)


def dont_update(params):
    """Freeze every parameter in *params* (sets ``requires_grad = False``)."""
    for param in params:
        param.requires_grad = False


def update(params):
    """Unfreeze every parameter in *params* (sets ``requires_grad = True``)."""
    for param in params:
        param.requires_grad = True


def calculate_strides(strides, downs):
    """Return ``stride ** down`` for each ``(stride, down)`` pair.

    The two sequences are zipped, so the result has the length of the shorter.
    """
    return [stride ** down for stride, down in zip(strides, downs)]


def _loss_fn(x_target, x_pred):
    """Return the mean smooth-L1 (Huber) loss between prediction and target."""
    # Functional form avoids building a new nn.SmoothL1Loss module per call.
    return F.smooth_l1_loss(x_pred, x_target, reduction='none').mean()


def _second_difference(seq):
    """Return the discrete second difference of *seq* along dim 1."""
    return seq[:, 2:] + seq[:, :-2] - 2 * seq[:, 1:-1]


class _VQVAEBase(nn.Module):
    """Shared construction and encode/decode plumbing for the VQ-VAE variants.

    Args:
        hps: hyper-parameter object read by attribute. Required attributes:
            ``pose_dims``, ``sample_length``, ``levels``, ``downs_t``,
            ``strides_t``, ``emb_width``, ``l_bins``, ``l_mu``, ``commit``,
            ``hvqvae_multipliers``, ``use_bottleneck``, ``width``, ``depth``,
            ``m_conv``, ``dilation_growth_rate``,
            ``vqvae_reverse_decoder_dilation``. Optional: ``dilation_cycle``
            (set to ``None`` on *hps* itself when missing), ``reg``, ``vel``,
            ``acc`` (each defaults to 0).
        input_dim: kept for backward compatibility only; the channel count is
            always taken from ``hps.pose_dims``.
    """

    def __init__(self, hps, input_dim=72):
        super().__init__()
        self.hps = hps
        input_dim = hps.pose_dims  # the argument is intentionally overridden
        input_shape = (hps.sample_length, input_dim)

        levels = hps.levels
        downs_t = hps.downs_t
        strides_t = hps.strides_t
        emb_width = hps.emb_width
        l_bins = hps.l_bins
        multipliers = hps.hvqvae_multipliers

        if hps.use_bottleneck:
            print('We use bottleneck!')
        else:
            print('We do not use bottleneck!')

        # Mutates hps on purpose so later readers of hps see the attribute.
        if not hasattr(hps, 'dilation_cycle'):
            hps.dilation_cycle = None
        block_kwargs = dict(
            width=hps.width,
            depth=hps.depth,
            m_conv=hps.m_conv,
            dilation_growth_rate=hps.dilation_growth_rate,
            dilation_cycle=hps.dilation_cycle,
            reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation,
        )

        self.sample_length = input_shape[0]
        x_shape, x_channels = input_shape[:-1], input_shape[-1]
        self.x_shape = x_shape

        self.downsamples = calculate_strides(strides_t, downs_t)
        # Cumulative product gives the total temporal reduction at each level.
        self.hop_lengths = np.cumprod(self.downsamples)
        self.z_shapes = [
            (x_shape[0] // self.hop_lengths[level],) for level in range(levels)
        ]
        self.levels = levels

        if multipliers is None:
            self.multipliers = [1] * levels
        else:
            assert len(multipliers) == levels, "Invalid number of multipliers"
            self.multipliers = multipliers

        def _block_kwargs(level):
            # Per-level copy with width/depth scaled by that level's multiplier.
            scaled = dict(block_kwargs)
            scaled["width"] *= self.multipliers[level]
            scaled["depth"] *= self.multipliers[level]
            return scaled

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        # Encoder and decoder of one level are built back to back, as in the
        # original code, so seeded weight initialisation is unchanged.
        for level in range(levels):
            self.encoders.append(
                Encoder(x_channels, emb_width, level + 1,
                        downs_t[:level + 1], strides_t[:level + 1],
                        **_block_kwargs(level)))
            self.decoders.append(
                Decoder(x_channels, emb_width, level + 1,
                        downs_t[:level + 1], strides_t[:level + 1],
                        **_block_kwargs(level)))

        if hps.use_bottleneck:
            self.bottleneck = Bottleneck(l_bins, emb_width, hps.l_mu, levels)
        else:
            self.bottleneck = NoBottleneck(levels)

        self.downs_t = downs_t
        self.strides_t = strides_t
        self.l_bins = l_bins
        self.commit = hps.commit
        # Weights of the optional smoothness / velocity / acceleration terms.
        self.reg = getattr(hps, 'reg', 0)
        self.acc = getattr(hps, 'acc', 0)
        self.vel = getattr(hps, 'vel', 0)
        if self.reg == 0:
            print('No motion regularization!')

    # ------------------------------------------------------------------ #
    # Layout helpers
    # ------------------------------------------------------------------ #
    def preprocess(self, x):
        """Swap the last two axes of a 3-D tensor and cast it to float.

        The original author documents this as NTC -> NCT.
        """
        assert len(x.shape) == 3
        return x.permute(0, 2, 1).float()

    def postprocess(self, x):
        """Swap the last two axes back (documented upstream as NCT -> NTC)."""
        return x.permute(0, 2, 1)

    def _module_device(self):
        """Return the device of the first parameter, else the global default."""
        try:
            return next(self.parameters()).device
        except StopIteration:
            return mydevice

    # ------------------------------------------------------------------ #
    # Per-level encoder / decoder passes
    # ------------------------------------------------------------------ #
    def _encode_levels(self, x_in):
        """Run every level's encoder on *x_in*; keep each one's last output."""
        return [self.encoders[level](x_in)[-1] for level in range(self.levels)]

    def _decode_levels(self, xs_quantised, expected_shape=None):
        """Decode each level from its own quantised latent.

        When *expected_shape* is given, every output is checked against it
        with ``assert_shape``.
        """
        x_outs = []
        for level in range(self.levels):
            x_out = self.decoders[level](
                xs_quantised[level:level + 1], all_levels=False)
            if expected_shape is not None:
                assert_shape(x_out, expected_shape)
            x_outs.append(x_out)
        return x_outs

    # ------------------------------------------------------------------ #
    # Code-level API
    # ------------------------------------------------------------------ #
    def _decode(self, zs, start_level=0, end_level=None):
        """Decode codes *zs* using only the decoder of ``start_level``."""
        if end_level is None:
            end_level = self.levels
        assert len(zs) == end_level - start_level
        xs_quantised = self.bottleneck.decode(
            zs, start_level=start_level, end_level=end_level)
        assert len(xs_quantised) == end_level - start_level
        # Only the lowest requested level is actually decoded.
        decoder = self.decoders[start_level]
        x_out = decoder(xs_quantised[0:1], all_levels=False)
        return self.postprocess(x_out)

    def decode(self, zs, start_level=0, end_level=None, bs_chunks=1):
        """Decode *zs* in ``bs_chunks`` batch chunks and concatenate on dim 0."""
        z_chunks = [t.chunk(z, bs_chunks, dim=0) for z in zs]
        x_outs = []
        for i in range(bs_chunks):
            zs_i = [z_chunk[i] for z_chunk in z_chunks]
            x_outs.append(
                self._decode(zs_i, start_level=start_level, end_level=end_level))
        return t.cat(x_outs, dim=0)

    def _encode(self, x, start_level=0, end_level=None):
        """Encode *x* at every level and return codes ``[start_level:end_level]``."""
        if end_level is None:
            end_level = self.levels
        xs = self._encode_levels(self.preprocess(x))
        zs = self.bottleneck.encode(xs)
        return zs[start_level:end_level]

    def encode(self, x, start_level=0, end_level=None, bs_chunks=1):
        """Encode *x* in ``bs_chunks`` batch chunks; returns one tensor per level."""
        zs_list = [
            self._encode(x_i, start_level=start_level, end_level=end_level)
            for x_i in t.chunk(x, bs_chunks, dim=0)
        ]
        return [t.cat(zs_level_list, dim=0) for zs_level_list in zip(*zs_list)]

    def sample(self, n_samples):
        """Decode ``n_samples`` sequences of uniformly random codes."""
        device = self._module_device()
        zs = [
            t.randint(0, self.l_bins, size=(n_samples, *z_shape), device=device)
            for z_shape in self.z_shapes
        ]
        return self.decode(zs)

    # ------------------------------------------------------------------ #
    # Training forward pass shared by VQVAE and Residual_VQVAE
    # ------------------------------------------------------------------ #
    def _forward_with_losses(self, x, scale_factor=None):
        """Auto-encode *x* and compute the training loss and metrics.

        Args:
            x: 3-D input tensor (see ``preprocess``).
            scale_factor: optional tensor multiplied into both the target and
                every reconstruction before the losses are computed. It is
                applied to the target exactly once.

        Returns:
            dict with keys ``rec_pose`` (level-0 reconstruction, scaled if
            *scale_factor* was given), ``loss``, ``metrics`` (detached
            tensors) and ``embedding_loss`` (commit loss times ``self.commit``).
        """
        metrics = {}

        x_in = self.preprocess(x)
        xs = self._encode_levels(x_in)
        zs, xs_quantised, commit_losses, quantiser_metrics = self.bottleneck(xs)
        x_outs = self._decode_levels(xs_quantised, expected_shape=x_in.shape)

        # Accumulators live on the input's device (works on CPU and any GPU).
        device = x.device
        recons_loss = t.zeros((), device=device)
        regularization = t.zeros((), device=device)
        velocity_loss = t.zeros((), device=device)
        acceleration_loss = t.zeros((), device=device)

        x_target = x.float()
        if scale_factor is not None:
            x_target = x_target * scale_factor

        # Iterating in reverse leaves x_out bound to the level-0 output.
        for level in reversed(range(self.levels)):
            x_out = self.postprocess(x_outs[level])
            if scale_factor is not None:
                x_out = x_out * scale_factor

            this_recons_loss = _loss_fn(x_target, x_out)
            metrics[f'recons_loss_l{level + 1}'] = this_recons_loss
            recons_loss = recons_loss + this_recons_loss

            # Smoothness penalty on the reconstruction alone.
            regularization = regularization + t.mean(
                _second_difference(x_out) ** 2)
            # First/second temporal differences matched against the target.
            velocity_loss = velocity_loss + _loss_fn(
                x_out[:, 1:] - x_out[:, :-1],
                x_target[:, 1:] - x_target[:, :-1])
            acceleration_loss = acceleration_loss + _loss_fn(
                _second_difference(x_out), _second_difference(x_target))

        commit_loss = sum(commit_losses)
        loss = (recons_loss
                + commit_loss * self.commit
                + self.reg * regularization
                + self.vel * velocity_loss
                + self.acc * acceleration_loss)

        # Perplexity of the level-0 code usage within this batch.
        encodings = F.one_hot(zs[0].reshape(-1), self.hps.l_bins).float()
        avg_probs = t.mean(encodings, dim=0)
        perplexity = t.exp(-t.sum(avg_probs * t.log(avg_probs + 1e-10)))

        with t.no_grad():
            l1_loss = _loss_fn(x_target, x_out)

        quantiser_metrics = average_metrics(quantiser_metrics)

        metrics.update(dict(
            loss=loss,
            recons_loss=recons_loss,
            l1_loss=l1_loss,
            commit_loss=commit_loss,
            regularization=regularization,
            velocity_loss=velocity_loss,
            acceleration_loss=acceleration_loss,
            perplexity=perplexity,
            **quantiser_metrics))

        for key, val in metrics.items():
            metrics[key] = val.detach()

        return {
            "rec_pose": x_out,
            "loss": loss,
            "metrics": metrics,
            "embedding_loss": commit_loss * self.commit,
        }


class VQVAE(_VQVAEBase):
    """Full auto-encoder: encode, quantise, decode and compute losses."""

    def forward(self, x):
        """Return the reconstruction/loss dict for input batch *x*."""
        return self._forward_with_losses(x)


class VQVAE_Encoder(_VQVAEBase):
    """Encoder-side view: same sub-modules, forward stops after quantisation."""

    def forward(self, x):
        """Return ``(codes, encoder output, quantised output)`` of level 0."""
        x_in = self.preprocess(x)
        xs = self._encode_levels(x_in)
        zs, xs_quantised, _commit_losses, _quantiser_metrics = self.bottleneck(xs)
        return zs[0], xs[0], xs_quantised[0]


class VQVAE_Decoder(_VQVAEBase):
    """Decoder-side view: quantise a given latent and decode it."""

    def forward(self, xs):
        """Quantise latent *xs* and return the level-0 decoded sequence.

        *xs* is wrapped in a one-element list before being passed to the
        bottleneck, i.e. it is treated as the latent of a single level.
        """
        xs = [xs]
        _zs, xs_quantised, _commit_losses, _quantiser_metrics = self.bottleneck(xs)
        x_outs = self._decode_levels(xs_quantised)
        # The original reversed loop ended on level 0; return that directly.
        return self.postprocess(x_outs[0])


class Residual_VQVAE(_VQVAEBase):
    """VQ-VAE whose losses weight the first three channels by ``root_weight``.

    Requires ``hps.root_weight`` in addition to the base hyper-parameters.
    """

    def __init__(self, hps, input_dim=72):
        # Read first so a missing attribute fails before anything is built.
        root_weight = hps.root_weight
        super().__init__(hps, input_dim)
        self.root_weight = root_weight

    def forward(self, x):
        """Return the reconstruction/loss dict computed on scaled channels.

        Note that ``rec_pose`` in the result is the *scaled* reconstruction,
        as in the original implementation.
        """
        scale_factor = t.ones(self.hps.pose_dims, device=x.device)
        scale_factor[:3] = self.root_weight
        return self._forward_with_losses(x, scale_factor=scale_factor)


if __name__ == '__main__':
    # Usage (from the original author):
    #   cd codebook/
    #   python vqvae.py --config=./codebook.yml --train --no_cuda 2 --gpu 2
    import yaml
    from pprint import pprint
    from easydict import EasyDict

    with open(args.config) as f:
        config = yaml.safe_load(f)

    # Command-line values override the YAML file.
    for k, v in vars(args).items():
        config[k] = v
    pprint(config)

    config = EasyDict(config)

    x = t.rand(32, 40, 15 * 9).to(mydevice)
    model = VQVAE(config.VQVAE, 15 * 9)  # n_joints * n_channels
    # int() instead of eval(): device ids come from the command line.
    model = nn.DataParallel(
        model, device_ids=[int(i) for i in config.no_cuda])
    model = model.to(mydevice)
    model = model.train()

    # forward() returns a dict, not a 3-tuple.
    outputs = model(x)
    print('loss:', outputs['loss'])
    print('rec_pose shape:', tuple(outputs['rec_pose'].shape))