Flux9665 commited on
Commit
a0e00eb
β€’
1 Parent(s): 4daea3f

try to figure out how ZeroGPU works

Browse files
Architectures/ControllabilityGAN/wgan/wgan_qc.py CHANGED
@@ -245,7 +245,10 @@ class WassersteinGanQuadraticCost(torch.nn.Module):
245
  latent_samples = latent_samples.to(self.device)
246
  if nograd:
247
  with torch.no_grad():
248
- generated_data = self.G(latent_samples.to("cpu"), return_intermediate=return_intermediate)
 
 
 
249
  else:
250
  generated_data = self.G(latent_samples)
251
  self.G.train()
 
245
  latent_samples = latent_samples.to(self.device)
246
  if nograd:
247
  with torch.no_grad():
248
+ if isinstance(self.G, torch.nn.parallel.DataParallel):
249
+ generated_data = self.G.module(latent_samples.to("cpu"), return_intermediate=return_intermediate)
250
+ else:
251
+ generated_data = self.G(latent_samples.to("cpu"), return_intermediate=return_intermediate)
252
  else:
253
  generated_data = self.G(latent_samples)
254
  self.G.train()