# Author: Flux9665
# Last commit 4daea3f: "try to figure out how ZeroGPU works"
# (file-viewer header: raw | history | blame; file size 2.76 kB)
import torch
from Architectures.ControllabilityGAN.wgan.init_wgan import create_wgan
class GanWrapper(torch.nn.Module):
    """Wrap a pretrained WGAN so its generator output can be steered.

    On construction the wrapper loads a WGAN checkpoint, fits a linear map
    (``self.U``) from PCA coordinates of the generator's intermediate
    activations back to latent space, and pre-samples a bank of latent
    vectors (``self.z_list``, kept on CPU) that can be selected by index
    with :meth:`set_latent`.

    Args:
        path_wgan: Path to a checkpoint loadable with ``torch.load`` that
            contains the keys ``model_parameters``, ``generator_state_dict``,
            ``critic_state_dict``, ``dataset_mean`` and ``dataset_std``.
        device: Passed to ``create_wgan`` and used as the device for
            generator inputs.
        *args, **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    # Size of the pre-sampled latent bank built in __init__.
    NUM_PRESAMPLED_LATENTS = 1100
    # Second argument to sample_latent; presumably the latent dimensionality
    # — TODO confirm against the generator definition.
    LATENT_DIM = 32

    def __init__(self, path_wgan, device, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = device
        self.path_wgan = path_wgan
        self.mean = None
        self.std = None
        self.wgan = None
        # When True, generated embeddings are mapped back with the
        # checkpoint's dataset mean/std (see _generate_embed).
        self.normalize = True
        self.load_model(path_wgan)
        self.U = self.compute_controllability()
        # Keep this after compute_controllability: both may draw from the
        # global RNG, so reordering would change which latents a given global
        # seed produces.
        self.z_list = [
            self.wgan.G.module.sample_latent(1, self.LATENT_DIM).to("cpu")
            for _ in range(self.NUM_PRESAMPLED_LATENTS)
        ]
        self.z = self.z_list[0]

    def set_latent(self, seed):
        """Make pre-sampled latent number ``seed`` the current latent.

        ``seed`` is an index into ``self.z_list`` (normal list indexing
        rules apply; an out-of-range index raises ``IndexError``).
        """
        self.z = self.z_list[seed]

    def reset_default_latent(self):
        """Replace the current latent with a freshly sampled one (on CPU).

        Note: this draws a new random latent; it does not restore
        ``self.z_list[0]``.
        """
        self.z = self.wgan.G.module.sample_latent(1, self.LATENT_DIM).to("cpu")

    def load_model(self, path):
        """Build the WGAN and load its weights and dataset statistics from ``path``."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        gan_checkpoint = torch.load(path, map_location="cpu")
        self.wgan = create_wgan(parameters=gan_checkpoint['model_parameters'], device=self.device)
        self.wgan.G.load_state_dict(gan_checkpoint['generator_state_dict'])
        self.wgan.D.load_state_dict(gan_checkpoint['critic_state_dict'])
        self.mean = gan_checkpoint["dataset_mean"]
        self.std = gan_checkpoint["dataset_std"]

    def compute_controllability(self, n_samples=50000):
        """Fit the PCA-to-latent map used by :meth:`modify_embed`.

        Samples ``n_samples`` outputs from the generator (with ``nograd=True``),
        moves the intermediate activations and latents to CPU and regresses
        the latents on the activations' PCA coordinates.

        Returns:
            The ``torch.linalg.lstsq`` result; ``.solution`` holds the map.
        """
        _, intermediate, z = self.wgan.sample_generator(num_samples=n_samples, nograd=True, return_intermediate=True)
        return self.controllable_speakers(intermediate.cpu(), z.cpu())

    def controllable_speakers(self, intermediate, z):
        """Least-squares map from PCA coordinates of ``intermediate`` to ``z``.

        Args:
            intermediate: Activations, one row per sample (assumed
                (n_samples, features) — TODO confirm).
            z: Latents corresponding row-for-row to ``intermediate``.

        Returns:
            ``torch.linalg.lstsq(X, z)`` where ``X`` is ``intermediate``
            centred and projected onto its principal directions.
        """
        # pca_lowrank returns (U, S, V); the columns of V are the directions.
        _, _, principal_dirs = torch.pca_lowrank(intermediate)
        # NOTE(review): centres with the scalar mean of *all* entries, whereas
        # PCA usually centres per column. Kept as-is because changing it would
        # change the learned control directions.
        mu = intermediate.mean()
        X = torch.matmul(intermediate - mu, principal_dirs)
        return torch.linalg.lstsq(X, z)

    def _generate_embed(self, z):
        """Run the generator on latent ``z``; undo dataset normalization if enabled."""
        self.wgan.G.eval()
        # Inference only: without no_grad every call would build (and the
        # returned tensor would keep alive) an autograd graph through the
        # whole generator.
        with torch.no_grad():
            embed = self.wgan.G.module.forward(z.to(self.device))
        if self.normalize:
            embed = inverse_normalize(
                embed.cpu(),
                self.mean.cpu().unsqueeze(0),
                self.std.cpu().unsqueeze(0),
            )
        return embed

    def get_original_embed(self):
        """Return the generator output for the current latent ``self.z``."""
        return self._generate_embed(self.z)

    def modify_embed(self, x):
        """Return the generator output for ``self.z`` shifted along the fitted directions.

        Args:
            x: Coefficients, one per column of ``self.U.solution.t()`` (i.e.
               one per PCA component used in :meth:`controllable_speakers`).
        """
        with torch.no_grad():
            z_new = self.z.squeeze() + torch.matmul(self.U.solution.t(), x)
        return self._generate_embed(z_new.unsqueeze(0))
def inverse_normalize(tensor, mean, std):
    """Map ``tensor`` from normalized space back to the original scale.

    Computes ``tensor * std + mean`` element-wise; tensor arguments follow
    the usual broadcasting rules.
    """
    rescaled = tensor * std
    return rescaled + mean