# NOTE(review): the lines below are web-scrape residue (Hugging Face Spaces page
# chrome, git blame hashes, and a line-number gutter) that was pasted above the
# module and made the file unparseable; preserved here as comments.
#   Spaces: Running on T4 | File size: 2,759 Bytes
#   9e275b8 ab12c36 9e275b8 ab12c36 9e275b8 4daea3f 9e275b8 4daea3f 9e275b8
import torch
from Architectures.ControllabilityGAN.wgan.init_wgan import create_wgan
class GanWrapper(torch.nn.Module):
    """Convenience wrapper around a pretrained WGAN generator.

    Loads a generator/critic checkpoint, precomputes a controllability basis
    via PCA + least squares over generator samples, and exposes helpers to
    select a latent vector, generate the corresponding embedding, and generate
    a modified embedding shifted along the controllable directions.
    """

    def __init__(self, path_wgan, device, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = device
        self.path_wgan = path_wgan
        # Dataset statistics for de-normalizing generator output; populated
        # by load_model() from the checkpoint.
        self.mean = None
        self.std = None
        self.wgan = None
        self.normalize = True
        self.load_model(path_wgan)
        # Least-squares mapping from PCA space back to latent space
        # (a torch.linalg.lstsq result; .solution is used downstream).
        self.U = self.compute_controllability()
        # Pre-sample a fixed pool of latents so a "seed" index deterministically
        # selects the same latent vector for the lifetime of this wrapper.
        self.z_list = [
            self.wgan.G.module.sample_latent(1, 32).to("cpu")
            for _ in range(1100)
        ]
        self.z = self.z_list[0]

    def set_latent(self, seed):
        """Select the pre-sampled latent vector at index *seed*.

        Raises IndexError if seed is outside the pre-sampled pool.
        """
        # Fixed: original had a redundant double assignment
        # (self.z = self.z = self.z_list[seed]).
        self.z = self.z_list[seed]

    def reset_default_latent(self):
        """Replace the current latent with a freshly sampled one (on CPU)."""
        self.z = self.wgan.G.module.sample_latent(1, 32).to("cpu")

    def load_model(self, path):
        """Load generator/critic weights and dataset stats from a checkpoint."""
        gan_checkpoint = torch.load(path, map_location="cpu")
        self.wgan = create_wgan(parameters=gan_checkpoint['model_parameters'], device=self.device)
        self.wgan.G.load_state_dict(gan_checkpoint['generator_state_dict'])
        self.wgan.D.load_state_dict(gan_checkpoint['critic_state_dict'])
        self.mean = gan_checkpoint["dataset_mean"]
        self.std = gan_checkpoint["dataset_std"]

    def compute_controllability(self, n_samples=50000):
        """Sample the generator and fit the latent-to-PCA controllability map.

        Returns the torch.linalg.lstsq result whose .solution maps PCA
        coordinates of intermediate activations to latent offsets.
        """
        _, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
        intermediate = intermediate.cpu()
        z = z.cpu()
        return self.controllable_speakers(intermediate, z)

    def controllable_speakers(self, intermediate, z):
        """Project *intermediate* onto its principal components and solve
        the least-squares system mapping those coordinates to latents *z*."""
        pca = torch.pca_lowrank(intermediate)
        # NOTE(review): .mean() is the scalar mean over ALL elements, not a
        # per-feature mean(dim=0); preserved as-is — confirm this is intended.
        mu = intermediate.mean()
        X = torch.matmul((intermediate - mu), pca[2])
        U = torch.linalg.lstsq(X, z)
        return U

    def get_original_embed(self):
        """Generate the embedding for the currently selected latent,
        de-normalized with the dataset stats when self.normalize is set."""
        self.wgan.G.eval()
        embed_original = self.wgan.G.module.forward(self.z.to(self.device))
        if self.normalize:
            embed_original = inverse_normalize(
                embed_original.cpu(),
                self.mean.cpu().unsqueeze(0),
                self.std.cpu().unsqueeze(0)
            )
        return embed_original

    def modify_embed(self, x):
        """Generate an embedding for the current latent shifted by *x*
        along the controllable directions (U.solution^T @ x)."""
        self.wgan.G.eval()
        z_new = self.z.squeeze() + torch.matmul(self.U.solution.t(), x)
        embed_modified = self.wgan.G.module.forward(z_new.unsqueeze(0).to(self.device))
        if self.normalize:
            embed_modified = inverse_normalize(
                embed_modified.cpu(),
                self.mean.cpu().unsqueeze(0),
                self.std.cpu().unsqueeze(0)
            )
        return embed_modified
def inverse_normalize(tensor, mean, std):
    """Undo a (x - mean) / std normalization: returns tensor * std + mean."""
    rescaled = tensor * std
    return rescaled + mean
|