Flux9665 commited on
Commit
4daea3f
·
1 Parent(s): a6e24ad

try to figure out how ZeroGPU works

Browse files
Architectures/ControllabilityGAN/GAN.py CHANGED
@@ -21,14 +21,14 @@ class GanWrapper(torch.nn.Module):
21
 
22
  self.z_list = list()
23
  for _ in range(1100):
24
- self.z_list.append(self.wgan.G.module.sample_latent(1, 32))
25
  self.z = self.z_list[0]
26
 
27
  def set_latent(self, seed):
28
  self.z = self.z = self.z_list[seed]
29
 
30
  def reset_default_latent(self):
31
- self.z = self.wgan.G.module.sample_latent(1, 32)
32
 
33
  def load_model(self, path):
34
  gan_checkpoint = torch.load(path, map_location="cpu")
 
21
 
22
  self.z_list = list()
23
  for _ in range(1100):
24
+ self.z_list.append(self.wgan.G.module.sample_latent(1, 32).to("cpu"))
25
  self.z = self.z_list[0]
26
 
27
  def set_latent(self, seed):
28
  self.z = self.z = self.z_list[seed]
29
 
30
  def reset_default_latent(self):
31
+ self.z = self.wgan.G.module.sample_latent(1, 32).to("cpu")
32
 
33
  def load_model(self, path):
34
  gan_checkpoint = torch.load(path, map_location="cpu")
Architectures/ControllabilityGAN/wgan/wgan_qc.py CHANGED
@@ -245,7 +245,7 @@ class WassersteinGanQuadraticCost(torch.nn.Module):
245
  latent_samples = latent_samples.to(self.device)
246
  if nograd:
247
  with torch.no_grad():
248
- generated_data = self.G(latent_samples, return_intermediate=return_intermediate)
249
  else:
250
  generated_data = self.G(latent_samples)
251
  self.G.train()
 
245
  latent_samples = latent_samples.to(self.device)
246
  if nograd:
247
  with torch.no_grad():
248
+ generated_data = self.G(latent_samples.to("cpu"), return_intermediate=return_intermediate)
249
  else:
250
  generated_data = self.G(latent_samples)
251
  self.G.train()
app.py CHANGED
@@ -21,11 +21,10 @@ from Utility.storage_config import MODELS_DIR
21
 
22
  class ControllableInterface(torch.nn.Module):
23
 
24
- @spaces.GPU
25
  def __init__(self, available_artificial_voices=1000):
26
  super().__init__()
27
  self.model = ToucanTTSInterface(device="cpu", tts_model_path="Meta")
28
- self.wgan = GanWrapper(os.path.join(MODELS_DIR, "Embedding", "embedding_gan.pt"), device="cuda")
29
  self.generated_speaker_embeds = list()
30
  self.available_artificial_voices = available_artificial_voices
31
  self.current_language = ""
 
21
 
22
  class ControllableInterface(torch.nn.Module):
23
 
 
24
  def __init__(self, available_artificial_voices=1000):
25
  super().__init__()
26
  self.model = ToucanTTSInterface(device="cpu", tts_model_path="Meta")
27
+ self.wgan = GanWrapper(os.path.join(MODELS_DIR, "Embedding", "embedding_gan.pt"), device="cpu")
28
  self.generated_speaker_embeds = list()
29
  self.available_artificial_voices = available_artificial_voices
30
  self.current_language = ""