fffiloni committed
Commit a86116e · verified
Parent: fefc151

Update main.py

Files changed (1): main.py (+8, -0)
main.py CHANGED

@@ -59,6 +59,9 @@ def setup(args):
     pipe = get_model(
         args.model, dtype, device, args.cache_dir, args.memsave, args.cpu_offloading
     )
+
+    torch.cuda.empty_cache() # Free up cached memory
+
     trainer = LatentNoiseTrainer(
         reward_losses=reward_losses,
         model=pipe,
@@ -75,6 +78,8 @@ def setup(args):
         imageselect=args.imageselect,
     )
 
+    torch.cuda.empty_cache() # Free up cached memory
+
     # Create latents
     if args.model == "flux":
         # currently only support 512x512 generation
@@ -97,6 +102,7 @@ def setup(args):
         height // pipe.vae_scale_factor,
         width // pipe.vae_scale_factor,
     )
+
     enable_grad = not args.no_optim
 
     if args.enable_multi_apply:
@@ -111,6 +117,8 @@ def setup(args):
     else:
         multi_apply_fn = None
 
+    torch.cuda.empty_cache() # Free up cached memory
+
     return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
 
 def execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings, progress_callback=None):
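
For reference, torch.cuda.empty_cache() releases memory blocks held by PyTorch's caching allocator back to the CUDA driver; it does not free tensors that are still referenced. A minimal sketch of the behavior these added calls rely on, runnable on any CUDA machine (the tensor size and the report helper here are illustrative, not part of main.py):

import torch

def report(tag):
    # allocated = memory used by live tensors; reserved = blocks the caching allocator holds
    alloc = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"{tag}: allocated={alloc:.1f} MiB, reserved={reserved:.1f} MiB")

if torch.cuda.is_available():
    x = torch.empty(1024, 1024, 256, device="cuda")  # ~1 GiB of float32
    report("after alloc")
    del x                     # tensor is freed, but its blocks stay cached
    report("after del")
    torch.cuda.empty_cache()  # hand the cached blocks back to the driver
    report("after empty_cache")

Placed between get_model(), the LatentNoiseTrainer construction, and the latent setup, these calls shrink the reserved pool before the next allocation-heavy step, which can help avoid out-of-memory errors when several large components are loaded in sequence on a memory-constrained GPU.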