Yw22 commited on
Commit
67fa832
1 Parent(s): dc4889e
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -304,7 +304,7 @@ class ImageConductor:
304
  else:
305
  input_all_points = tracking_points.value
306
 
307
-
308
  resized_all_points = [tuple([tuple([float(e1[0]*self.width/original_width), float(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
309
 
310
  dir, base, ext = split_filename(first_frame_path)
@@ -392,7 +392,7 @@ def reset_states(first_frame_path, tracking_points):
392
  return {input_image:None, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
393
 
394
 
395
- def preprocess_image(image):
396
  image_pil = image2pil(image.name)
397
  raw_w, raw_h = image_pil.size
398
  resize_ratio = max(384/raw_w, 256/raw_h)
@@ -401,7 +401,8 @@ def preprocess_image(image):
401
  id = str(uuid.uuid4())[:4]
402
  first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
403
  image_pil.save(first_frame_path, quality=95)
404
- return {input_image: first_frame_path, first_frame_path_var: first_frame_path, tracking_points_var: gr.State([])}
 
405
 
406
 
407
  def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData): # SelectData is a subclass of EventData
@@ -438,7 +439,10 @@ def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.Se
438
 
439
 
440
  def add_drag(tracking_points):
441
- tracking_points.value.append([])
 
 
 
442
  print(tracking_points.value)
443
  return {tracking_points_var: tracking_points}
444
 
@@ -615,7 +619,6 @@ with block:
615
  inputs=[input_image, prompt, drag_mode, seed, personalized, examples_type],
616
  outputs=[input_image, prompt, drag_mode, seed, personalized, examples_type],
617
  fn=process_example,
618
- run_on_click=True,
619
  examples_per_page=10,
620
  cache_examples=False,
621
  )
@@ -625,9 +628,9 @@ with block:
625
  gr.Markdown(citation)
626
 
627
 
628
- image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path_var, tracking_points_var])
629
 
630
- add_drag_button.click(add_drag, [tracking_points_var], tracking_points_var)
631
 
632
  delete_last_drag_button.click(delete_last_drag, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
633
 
 
304
  else:
305
  input_all_points = tracking_points.value
306
 
307
+ print("input_all_points", input_all_points)
308
  resized_all_points = [tuple([tuple([float(e1[0]*self.width/original_width), float(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
309
 
310
  dir, base, ext = split_filename(first_frame_path)
 
392
  return {input_image:None, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
393
 
394
 
395
+ def preprocess_image(image, tracking_points):
396
  image_pil = image2pil(image.name)
397
  raw_w, raw_h = image_pil.size
398
  resize_ratio = max(384/raw_w, 256/raw_h)
 
401
  id = str(uuid.uuid4())[:4]
402
  first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
403
  image_pil.save(first_frame_path, quality=95)
404
+ tracking_points = gr.State([])
405
+ return {input_image: first_frame_path, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
406
 
407
 
408
  def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData): # SelectData is a subclass of EventData
 
439
 
440
 
441
  def add_drag(tracking_points):
442
+ print("before", tracking_points.value)
443
+
444
+ if tracking_points.value != []:
445
+ tracking_points.value.append([])
446
  print(tracking_points.value)
447
  return {tracking_points_var: tracking_points}
448
 
 
619
  inputs=[input_image, prompt, drag_mode, seed, personalized, examples_type],
620
  outputs=[input_image, prompt, drag_mode, seed, personalized, examples_type],
621
  fn=process_example,
 
622
  examples_per_page=10,
623
  cache_examples=False,
624
  )
 
628
  gr.Markdown(citation)
629
 
630
 
631
+ image_upload_button.upload(preprocess_image, [image_upload_button, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var])
632
 
633
+ add_drag_button.click(add_drag, tracking_points_var, tracking_points_var)
634
 
635
  delete_last_drag_button.click(delete_last_drag, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
636