Spaces:
Running
on
T4
Running
on
T4
try to figure out how ZeroGPU works
Browse files
Architectures/ControllabilityGAN/GAN.py
CHANGED
@@ -21,14 +21,14 @@ class GanWrapper(torch.nn.Module):
|
|
21 |
|
22 |
self.z_list = list()
|
23 |
for _ in range(1100):
|
24 |
-
self.z_list.append(self.wgan.G.module.sample_latent(1, 32))
|
25 |
self.z = self.z_list[0]
|
26 |
|
27 |
def set_latent(self, seed):
|
28 |
self.z = self.z = self.z_list[seed]
|
29 |
|
30 |
def reset_default_latent(self):
|
31 |
-
self.z = self.wgan.G.module.sample_latent(1, 32)
|
32 |
|
33 |
def load_model(self, path):
|
34 |
gan_checkpoint = torch.load(path, map_location="cpu")
|
|
|
21 |
|
22 |
self.z_list = list()
|
23 |
for _ in range(1100):
|
24 |
+
self.z_list.append(self.wgan.G.module.sample_latent(1, 32).to("cpu"))
|
25 |
self.z = self.z_list[0]
|
26 |
|
27 |
def set_latent(self, seed):
|
28 |
self.z = self.z = self.z_list[seed]
|
29 |
|
30 |
def reset_default_latent(self):
|
31 |
+
self.z = self.wgan.G.module.sample_latent(1, 32).to("cpu")
|
32 |
|
33 |
def load_model(self, path):
|
34 |
gan_checkpoint = torch.load(path, map_location="cpu")
|
Architectures/ControllabilityGAN/wgan/wgan_qc.py
CHANGED
@@ -245,7 +245,7 @@ class WassersteinGanQuadraticCost(torch.nn.Module):
|
|
245 |
latent_samples = latent_samples.to(self.device)
|
246 |
if nograd:
|
247 |
with torch.no_grad():
|
248 |
-
generated_data = self.G(latent_samples, return_intermediate=return_intermediate)
|
249 |
else:
|
250 |
generated_data = self.G(latent_samples)
|
251 |
self.G.train()
|
|
|
245 |
latent_samples = latent_samples.to(self.device)
|
246 |
if nograd:
|
247 |
with torch.no_grad():
|
248 |
+
generated_data = self.G(latent_samples.to("cpu"), return_intermediate=return_intermediate)
|
249 |
else:
|
250 |
generated_data = self.G(latent_samples)
|
251 |
self.G.train()
|
app.py
CHANGED
@@ -21,11 +21,10 @@ from Utility.storage_config import MODELS_DIR
|
|
21 |
|
22 |
class ControllableInterface(torch.nn.Module):
|
23 |
|
24 |
-
@spaces.GPU
|
25 |
def __init__(self, available_artificial_voices=1000):
|
26 |
super().__init__()
|
27 |
self.model = ToucanTTSInterface(device="cpu", tts_model_path="Meta")
|
28 |
-
self.wgan = GanWrapper(os.path.join(MODELS_DIR, "Embedding", "embedding_gan.pt"), device="
|
29 |
self.generated_speaker_embeds = list()
|
30 |
self.available_artificial_voices = available_artificial_voices
|
31 |
self.current_language = ""
|
|
|
21 |
|
22 |
class ControllableInterface(torch.nn.Module):
|
23 |
|
|
|
24 |
def __init__(self, available_artificial_voices=1000):
|
25 |
super().__init__()
|
26 |
self.model = ToucanTTSInterface(device="cpu", tts_model_path="Meta")
|
27 |
+
self.wgan = GanWrapper(os.path.join(MODELS_DIR, "Embedding", "embedding_gan.pt"), device="cpu")
|
28 |
self.generated_speaker_embeds = list()
|
29 |
self.available_artificial_voices = available_artificial_voices
|
30 |
self.current_language = ""
|