Flux9665 commited on
Commit
791a0ff
·
1 Parent(s): 339d2c6

use explicit code instead of relying on release download

Browse files
Architectures/ControllabilityGAN/wgan/wgan_qc.py CHANGED
@@ -242,6 +242,7 @@ class WassersteinGanQuadraticCost:
242
  else:
243
  latent_samples = self.G.sample_latent(num_samples, self.G.z_dim)
244
  latent_samples = latent_samples.to(self.device)
 
245
  if nograd:
246
  with torch.no_grad():
247
  generated_data = self.G(latent_samples, return_intermediate=return_intermediate)
 
242
  else:
243
  latent_samples = self.G.sample_latent(num_samples, self.G.z_dim)
244
  latent_samples = latent_samples.to(self.device)
245
+ print(self.device)
246
  if nograd:
247
  with torch.no_grad():
248
  generated_data = self.G(latent_samples, return_intermediate=return_intermediate)
app.py CHANGED
@@ -136,6 +136,8 @@ def read(prompt,
136
  if torch.cuda.is_available():
137
  controllable_ui.to("cuda")
138
  controllable_ui.device = "cuda"
 
 
139
  try:
140
  sr, wav, fig = controllable_ui.read(prompt,
141
  language.split(" ")[-1].split("(")[1].split(")")[0],
@@ -155,6 +157,8 @@ def read(prompt,
155
  finally:
156
  controllable_ui.to("cpu")
157
  controllable_ui.device = "cpu"
 
 
158
  return (sr, float2pcm(wav)), fig
159
 
160
 
 
136
  if torch.cuda.is_available():
137
  controllable_ui.to("cuda")
138
  controllable_ui.device = "cuda"
139
+ controllable_ui.model.device = "cuda"
140
+ controllable_ui.wgan.device = "cuda"
141
  try:
142
  sr, wav, fig = controllable_ui.read(prompt,
143
  language.split(" ")[-1].split("(")[1].split(")")[0],
 
157
  finally:
158
  controllable_ui.to("cpu")
159
  controllable_ui.device = "cpu"
160
+ controllable_ui.model.device = "cpu"
161
+ controllable_ui.wgan.device = "cpu"
162
  return (sr, float2pcm(wav)), fig
163
 
164