tfwang commited on
Commit
ccf29e8
·
1 Parent(s): 9d92961

Update glide_text2im/adv.py

Browse files
Files changed (1) hide show
  1. glide_text2im/adv.py +5 -5
glide_text2im/adv.py CHANGED
@@ -4,7 +4,7 @@ import torch.nn.functional as F
4
  import torch.optim as optim
5
  from torch.nn.parallel.distributed import DistributedDataParallel as DDP
6
  from .nn import mean_flat
7
- from . import dist_util
8
  import functools
9
 
10
  class AdversarialLoss(nn.Module):
@@ -16,11 +16,11 @@ class AdversarialLoss(nn.Module):
16
  self.gan_type = gan_type
17
  self.gan_k = gan_k
18
 
19
- model = NLayerDiscriminator().to(dist_util.dev())
20
  self.discriminator = DDP(
21
  model,
22
- device_ids=[dist_util.dev()],
23
- output_device=dist_util.dev(),
24
  broadcast_buffers=False,
25
  bucket_cap_mb=128,
26
  find_unused_parameters=False,
@@ -41,7 +41,7 @@ class AdversarialLoss(nn.Module):
41
  if (self.gan_type.find('WGAN') >= 0):
42
  loss_d = (d_fake - d_real).mean()
43
  if self.gan_type.find('GP') >= 0:
44
- epsilon = torch.rand(real.size(0), 1, 1, 1).to(dist_util.dev())
45
  epsilon = epsilon.expand(real.size())
46
  hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
47
  hat.requires_grad = True
 
4
  import torch.optim as optim
5
  from torch.nn.parallel.distributed import DistributedDataParallel as DDP
6
  from .nn import mean_flat
7
+ #from . import dist_util
8
  import functools
9
 
10
  class AdversarialLoss(nn.Module):
 
16
  self.gan_type = gan_type
17
  self.gan_k = gan_k
18
 
19
+ model = NLayerDiscriminator().cuda()
20
  self.discriminator = DDP(
21
  model,
22
+ device_ids=[torch.device('cuda')],
23
+ output_device=torch.device('cuda'),
24
  broadcast_buffers=False,
25
  bucket_cap_mb=128,
26
  find_unused_parameters=False,
 
41
  if (self.gan_type.find('WGAN') >= 0):
42
  loss_d = (d_fake - d_real).mean()
43
  if self.gan_type.find('GP') >= 0:
44
+ epsilon = torch.rand(real.size(0), 1, 1, 1).cuda()
45
  epsilon = epsilon.expand(real.size())
46
  hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
47
  hat.requires_grad = True