Spaces:
Running
on
T4
Running
on
T4
try to figure out how ZeroGPU works
Browse files
Architectures/ControllabilityGAN/wgan/wgan_qc.py
CHANGED
@@ -245,7 +245,10 @@ class WassersteinGanQuadraticCost(torch.nn.Module):
|
|
245 |
latent_samples = latent_samples.to(self.device)
|
246 |
if nograd:
|
247 |
with torch.no_grad():
|
248 |
-
|
|
|
|
|
|
|
249 |
else:
|
250 |
generated_data = self.G(latent_samples)
|
251 |
self.G.train()
|
|
|
245 |
latent_samples = latent_samples.to(self.device)
|
246 |
if nograd:
|
247 |
with torch.no_grad():
|
248 |
+
if isinstance(self.G, torch.nn.parallel.DataParallel):
|
249 |
+
generated_data = self.G.module(latent_samples.to("cpu"), return_intermediate=return_intermediate)
|
250 |
+
else:
|
251 |
+
generated_data = self.G(latent_samples.to("cpu"), return_intermediate=return_intermediate)
|
252 |
else:
|
253 |
generated_data = self.G(latent_samples)
|
254 |
self.G.train()
|