Files changed (1)
  1. app.py +24 -32

app.py CHANGED
@@ -1,4 +1,3 @@
-from functools import partial
 import os
 from PIL import Image, ImageOps
 import random
@@ -46,6 +45,7 @@ If you have uploaded one of your own images, it is very likely that you will nee
 You should verify that the preprocessed image is object-centric (i.e., clearly contains a single object) and has white background.
 '''
 
+
 def center_and_square_image(pil_image_rgba, drags):
     image = pil_image_rgba
     alpha = np.array(image)[:, :, 3] # Extract the alpha channel
@@ -70,11 +70,13 @@ def center_and_square_image(pil_image_rgba, drags):
     image = image.resize((256, 256), Image.Resampling.LANCZOS)
     return image, new_drags
 
+
 def sam_init():
     sam_checkpoint = os.path.join(os.path.dirname(__file__), "ckpts", "sam_vit_h_4b8939.pth")
     predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to("cuda"))
     return predictor
 
+
 def model_init():
     model_checkpoint = os.path.join(os.path.dirname(__file__), "ckpts", "drag-a-part-final.pt")
     model = UNet2DDragConditionModel.from_pretrained_sd(
@@ -94,13 +96,24 @@ def model_init():
     model.load_state_dict(torch.load(model_checkpoint, map_location="cpu")["model"])
     return model.to("cuda")
 
+
+sam_predictor = sam_init()
+model = model_init()
+
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to('cuda')
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to('cuda')
+clip_vit = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").to('cuda')
+image_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+diffusion = create_diffusion(timestep_respacing="", learn_sigma=False)
+
+
 @spaces.GPU(duration=10)
-def sam_segment(predictor, input_image, drags, foreground_points=None):
+def sam_segment(input_image, drags, foreground_points=None):
     image = np.asarray(input_image)
-    predictor.set_image(image)
+    sam_predictor.set_image(image)
 
     with torch.no_grad():
-        masks_bbox, _, _ = predictor.predict(
+        masks_bbox, _, _ = sam_predictor.predict(
             point_coords=foreground_points if foreground_points is not None else None,
             point_labels=np.ones(len(foreground_points)) if foreground_points is not None else None,
             multimask_output=True
@@ -114,6 +127,7 @@ def sam_segment(predictor, input_image, drags, foreground_points=None):
 
     return out_image, new_drags
 
+
 def get_point(img, sel_pix, evt: gr.SelectData):
     sel_pix.append(evt.index)
     points = []
@@ -136,10 +150,12 @@ def get_point(img, sel_pix, evt: gr.SelectData):
         points = []
     return img if isinstance(img, np.ndarray) else np.array(img)
 
+
 def clear_drag():
     return []
 
-def preprocess_image(SAM_predictor, img, chk_group, drags):
+
+def preprocess_image(img, chk_group, drags):
     if img is None:
         gr.Warning("No image is specified. Please specify an image before preprocessing.")
         return None, drags
@@ -157,7 +173,6 @@ def preprocess_image(SAM_predictor, img, chk_group, drags):
     img_np = np.array(img)
     rgb_img = img_np[..., :3]
     img, new_drags = sam_segment(
-        SAM_predictor,
         rgb_img,
         drags,
         foreground_points=foreground_points,
@@ -173,8 +188,6 @@ def preprocess_image(SAM_predictor, img, chk_group, drags):
 
 
 def single_image_sample(
-    model,
-    diffusion,
     x_cond,
     x_cond_clip,
     rel,
@@ -183,7 +196,6 @@ def single_image_sample(
     drags,
     hidden_cls,
     num_steps=50,
-    vae=None,
 ):
     z = torch.randn(2, 4, 32, 32).to("cuda")
 
@@ -231,16 +243,11 @@
 
 
 @spaces.GPU(duration=20)
-def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion, img_cond, seed, cfg_scale, drags_list):
+def generate_image(img_cond, seed, cfg_scale, drags_list):
     if img_cond is None:
         gr.Warning("Please preprocess the image first.")
         return None
 
-    model = model.to("cuda")
-    vae = vae.to("cuda")
-    clip_model = clip_model.to("cuda")
-    clip_vit = clip_vit.to("cuda")
-
     with torch.no_grad():
         torch.manual_seed(seed)
         np.random.seed(seed)
@@ -279,8 +286,6 @@ def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion,
             break
 
     return single_image_sample(
-        model.to("cuda"),
-        diffusion,
         x_cond,
         cond_clip_features,
         rel,
@@ -289,22 +294,9 @@ def generate_image(model, image_processor, vae, clip_model, clip_vit, diffusion,
         drags,
         cls_embedding,
         num_steps=50,
-        vae=vae,
     )
 
 
-sam_predictor = sam_init()
-model = model_init()
-
-vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to('cuda')
-clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to('cuda')
-clip_vit = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14").to('cuda')
-image_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
-diffusion = create_diffusion(
-    timestep_respacing="",
-    learn_sigma=False,
-)
-
 with gr.Blocks(title=TITLE) as demo:
     gr.Markdown("# " + DESCRIPTION)
 
@@ -378,7 +370,7 @@ with gr.Blocks(title=TITLE) as demo:
                 value="Preprocess Input Image",
             )
             preprocess_button.click(
-                fn=partial(preprocess_image, sam_predictor),
+                fn=preprocess_image,
                 inputs=[input_image, preprocess_chk_group, drags],
                 outputs=[processed_image, drags],
                 queue=True,
@@ -407,7 +399,7 @@ with gr.Blocks(title=TITLE) as demo:
                 value="Generate Image",
             )
            generate_button.click(
-                fn=partial(generate_image, model, image_processor, vae, clip_model, clip_vit, diffusion),
+                fn=generate_image,
                 inputs=[processed_image, seed, cfg_scale, drags],
                 outputs=[generated_image],
             )
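
For context on the shape of this refactor: the models are now constructed once at import time, the `@spaces.GPU`-decorated handlers reference those module-level globals, and the `.click()` wiring passes the handler functions directly instead of wrapping them in `functools.partial`. Below is a minimal sketch of that pattern, not code from this PR; `load_model` and `infer` are hypothetical stand-ins, while `spaces.GPU` and the Gradio wiring mirror the diff above.

import gradio as gr
import spaces  # Hugging Face ZeroGPU package providing the @spaces.GPU decorator
import torch


def load_model():
    # Hypothetical stand-in for this app's model_init(); any nn.Module works.
    return torch.nn.Identity()


# Heavy objects are created once, at module import time, and shared by all handlers.
model = load_model().to("cuda")


@spaces.GPU(duration=20)
def infer(image, seed):
    # The handler closes over the module-level `model` instead of receiving it
    # through functools.partial, so Gradio only ever passes UI component values.
    torch.manual_seed(int(seed))
    with torch.no_grad():
        x = torch.as_tensor(image, device="cuda")
        return model(x).cpu().numpy()


with gr.Blocks() as demo:
    inp = gr.Image()
    seed = gr.Number(value=0)
    out = gr.Image()
    gr.Button("Run").click(fn=infer, inputs=[inp, seed], outputs=[out])

demo.launch()

This is the same shape as the changes to `sam_segment`, `preprocess_image`, and `generate_image` above: the `partial`-bound model arguments disappear, and the globals created right after `model_init()` take their place.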