Spaces: Running on Zero
Ruining Li committed
Commit 0af5cc2 · 1 Parent(s): 23f2fb5
Adapt to HF ZeroGPU
app.py CHANGED
@@ -185,8 +185,11 @@ def single_image_sample(
     drags,
     hidden_cls,
     num_steps=50,
+    vae=None,
 ):
     z = torch.randn(2, 4, 32, 32).to("cuda")
+    if vae is not None:
+        vae = vae.to("cuda")
 
     # Prepare input for classifer-free guidance
     rel = torch.cat([rel, rel], dim=0).to("cuda")
@@ -222,7 +225,10 @@ def single_image_sample(
     )
 
     samples, _ = samples.chunk(2, dim=0)
-    return samples
+
+    with torch.no_grad():
+        images = vae.decode(samples / 0.18215).sample
+    return ((images + 1)[0].permute(1, 2, 0) * 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
 
 @spaces.GPU
 def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion, img_cond, seed, cfg_scale, drags_list):
@@ -272,7 +278,7 @@ def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion,
         if idx == 9:
             break
 
-    samples = single_image_sample(
+    images = single_image_sample(
         model.to("cuda"),
         diffusion,
         x_cond,
@@ -283,11 +289,8 @@ def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion,
         drags,
         cls_embedding,
         num_steps=50,
+        vae=vae,
     )
-
-    with torch.no_grad():
-        images = vae.decode(samples / 0.18215).sample
-        images = ((images + 1)[0].permute(1, 2, 0) * 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
     return images
 
 
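
For context, the change follows the usual ZeroGPU adaptation pattern: models are loaded on CPU at startup, and all CUDA work is kept inside a function decorated with @spaces.GPU, since ZeroGPU only attaches a GPU for the duration of such a call. Below is a minimal sketch of that pattern, assuming a diffusers AutoencoderKL; the decode_latents helper and the stabilityai/sd-vae-ft-mse checkpoint are illustrative choices, not code from this Space.

import spaces
import torch
from diffusers import AutoencoderKL

# Loaded on CPU at import time; no CUDA calls happen here.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

@spaces.GPU  # ZeroGPU allocates a GPU only while this function runs
def decode_latents(samples: torch.Tensor):
    # Move the model and inputs to the GPU inside the decorated call.
    vae_gpu = vae.to("cuda")
    with torch.no_grad():
        # 0.18215 is the Stable Diffusion latent scaling factor,
        # matching the vae.decode(samples / 0.18215) call in the diff.
        images = vae_gpu.decode(samples.to("cuda") / 0.18215).sample
    # Map the first image from [-1, 1] to a uint8 HWC array, as in the Space's return expression.
    return ((images + 1)[0].permute(1, 2, 0) * 127.5).cpu().numpy().clip(0, 255).astype("uint8")

Folding the VAE decode into single_image_sample via the optional vae= argument appears intended to keep the whole CUDA path, sampling plus decoding, inside the @spaces.GPU-decorated generate_image call.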