# Spaces: Running on Zero
import torch | |
from Architectures.ControllabilityGAN.wgan.init_wgan import create_wgan | |
class GanWrapper:
    """Wrap a pretrained WGAN generator with latent-space controls.

    On construction the wrapper:

    * loads the checkpoint,
    * fits a linear map (``self.U``) from PCA coordinates of the generator's
      intermediate features to its latents,
    * pre-samples a bank of latents that callers select by index with
      :meth:`set_latent`.
    """

    # Number of latents pre-sampled into ``self.z_list`` at construction time.
    _NUM_PRESAMPLED_LATENTS = 1100
    # Temperature passed to ``G.sample_latent`` for every latent drawn by this wrapper.
    _LATENT_TEMPERATURE = 0.8
    # Key prefix removed from checkpoint state dicts before loading.
    _MODULE_PREFIX = 'module.'

    def __init__(self, path_wgan, device):
        """Load the checkpoint at *path_wgan*, fit the control map and sample latents.

        Args:
            path_wgan: Checkpoint path; see :meth:`load_model` for the expected keys.
            device: Device passed to ``create_wgan``. Latents are moved to it
                before the generator forward pass.
        """
        self.device = device
        self.path_wgan = path_wgan
        self.mean = None
        self.std = None
        self.wgan = None
        # When True, generator outputs are mapped back with the checkpoint's dataset mean/std.
        self.normalize = True
        self.load_model(path_wgan)
        self.U = self.compute_controllability()
        self.z_list = [self._sample_latent() for _ in range(self._NUM_PRESAMPLED_LATENTS)]
        self.z = self.z_list[0]

    def set_latent(self, seed):
        """Make the pre-sampled latent at index *seed* the current latent.

        Raises:
            IndexError: If *seed* is outside ``self.z_list``.
        """
        self.z = self.z_list[seed]

    def reset_default_latent(self):
        """Replace the current latent with a freshly sampled one.

        The new latent is not added to ``self.z_list``.
        """
        self.z = self._sample_latent()

    def _sample_latent(self):
        """Draw one latent via ``G.sample_latent`` at the wrapper's fixed temperature."""
        return self.wgan.G.sample_latent(1, self.wgan.G.z_dim, temperature=self._LATENT_TEMPERATURE)

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of *state_dict* with leading ``'module.'`` removed from every key."""
        prefix = GanWrapper._MODULE_PREFIX
        stripped = {}
        for key, value in state_dict.items():
            # Only strip leading occurrences: a substring replace would also
            # mangle keys such as 'x.submodule.w' into 'x.subw'.
            while key.startswith(prefix):
                key = key[len(prefix):]
            stripped[key] = value
        return stripped

    def load_model(self, path):
        """Build the WGAN from the checkpoint at *path* and load its weights and normalization stats.

        The checkpoint must provide the keys ``model_parameters``,
        ``generator_state_dict``, ``critic_state_dict``, ``dataset_mean`` and
        ``dataset_std``.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        gan_checkpoint = torch.load(path, map_location="cpu")
        self.wgan = create_wgan(parameters=gan_checkpoint['model_parameters'], device=self.device)
        # Saved keys may carry a 'module.' prefix (e.g. if the model was wrapped
        # at save time); strip it so they match the bare generator/critic.
        self.wgan.G.load_state_dict(self._strip_module_prefix(gan_checkpoint['generator_state_dict']))
        self.wgan.D.load_state_dict(self._strip_module_prefix(gan_checkpoint['critic_state_dict']))
        self.mean = gan_checkpoint["dataset_mean"]
        self.std = gan_checkpoint["dataset_std"]

    def compute_controllability(self, n_samples=100000):
        """Sample the generator *n_samples* times and fit the latent control map.

        Returns:
            The ``torch.linalg.lstsq`` result from :meth:`controllable_speakers`.
            ``modify_embed`` uses its ``.solution``.
        """
        _, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
        return self.controllable_speakers(intermediate.cpu(), z.cpu())

    def controllable_speakers(self, intermediate, z):
        """Least-squares regress latents *z* onto PCA coordinates of *intermediate*.

        ``torch.pca_lowrank`` runs with its default rank ``q`` (at most 6), so
        the fitted ``.solution`` has one row per retained component.

        Args:
            intermediate: 2-D feature matrix, one row per sample (assumed; TODO confirm).
            z: Latents with the same number of rows as *intermediate*.

        Returns:
            ``torch.linalg.lstsq`` result; ``.solution`` has shape ``(q, z.shape[-1])``.
        """
        _, _, components = torch.pca_lowrank(intermediate)
        # NOTE(review): this centres with one global scalar mean, while
        # pca_lowrank centres per column. It is kept as-is because changing it
        # would alter the fitted control directions.
        mu = intermediate.mean()
        X = torch.matmul(intermediate - mu, components)
        return torch.linalg.lstsq(X, z)

    def _generate(self, z):
        """Run the generator on latents *z* and optionally undo the dataset normalization."""
        generator = self.wgan.G
        generator.eval()
        # No autograd graph is needed for inference through the generator.
        with torch.no_grad():
            embed = generator.forward(z.to(self.device))
        if self.normalize:
            embed = inverse_normalize(
                embed.cpu(),
                self.mean.cpu().unsqueeze(0),
                self.std.cpu().unsqueeze(0)
            )
        return embed

    def get_original_embed(self):
        """Return the generator output for the current latent ``self.z``.

        The output is de-normalized to CPU when ``self.normalize`` is set.
        """
        # G is the bare network (load_model loads prefix-free weights into it),
        # so call it directly rather than through a non-existent `.module`.
        return self._generate(self.z)

    def modify_embed(self, x):
        """Return the generator output for ``self.z`` shifted along the fitted control directions.

        Args:
            x: Control weights; its length must equal the number of rows of
                ``self.U.solution``.
        """
        z_new = self.z.squeeze() + torch.matmul(self.U.solution.t(), x)
        return self._generate(z_new.unsqueeze(0))
def inverse_normalize(tensor, mean, std):
    """Undo a standard-score normalization by computing ``tensor * std + mean``.

    The operation is element-wise with the usual broadcasting rules for torch
    tensors, and it also works on plain numbers.
    """
    rescaled = tensor * std
    return rescaled + mean