Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -59,6 +59,9 @@ def setup(args):
|
|
59 |
pipe = get_model(
|
60 |
args.model, dtype, device, args.cache_dir, args.memsave, args.cpu_offloading
|
61 |
)
|
|
|
|
|
|
|
62 |
trainer = LatentNoiseTrainer(
|
63 |
reward_losses=reward_losses,
|
64 |
model=pipe,
|
@@ -75,6 +78,8 @@ def setup(args):
|
|
75 |
imageselect=args.imageselect,
|
76 |
)
|
77 |
|
|
|
|
|
78 |
# Create latents
|
79 |
if args.model == "flux":
|
80 |
# currently only support 512x512 generation
|
@@ -97,6 +102,7 @@ def setup(args):
|
|
97 |
height // pipe.vae_scale_factor,
|
98 |
width // pipe.vae_scale_factor,
|
99 |
)
|
|
|
100 |
enable_grad = not args.no_optim
|
101 |
|
102 |
if args.enable_multi_apply:
|
@@ -111,6 +117,8 @@ def setup(args):
|
|
111 |
else:
|
112 |
multi_apply_fn = None
|
113 |
|
|
|
|
|
114 |
return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
|
115 |
|
116 |
def execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings, progress_callback=None):
|
|
|
59 |
pipe = get_model(
|
60 |
args.model, dtype, device, args.cache_dir, args.memsave, args.cpu_offloading
|
61 |
)
|
62 |
+
|
63 |
+
torch.cuda.empty_cache() # Free up cached memory
|
64 |
+
|
65 |
trainer = LatentNoiseTrainer(
|
66 |
reward_losses=reward_losses,
|
67 |
model=pipe,
|
|
|
78 |
imageselect=args.imageselect,
|
79 |
)
|
80 |
|
81 |
+
torch.cuda.empty_cache() # Free up cached memory
|
82 |
+
|
83 |
# Create latents
|
84 |
if args.model == "flux":
|
85 |
# currently only support 512x512 generation
|
|
|
102 |
height // pipe.vae_scale_factor,
|
103 |
width // pipe.vae_scale_factor,
|
104 |
)
|
105 |
+
|
106 |
enable_grad = not args.no_optim
|
107 |
|
108 |
if args.enable_multi_apply:
|
|
|
117 |
else:
|
118 |
multi_apply_fn = None
|
119 |
|
120 |
+
torch.cuda.empty_cache() # Free up cached memory
|
121 |
+
|
122 |
return args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings
|
123 |
|
124 |
def execute_task(args, trainer, device, dtype, shape, enable_grad, multi_apply_fn, settings, progress_callback=None):
|