hysts HF staff committed on
Commit 54d2095
1 Parent(s): befa34f

Apply black

Files changed (1)
  1. app.py +233 -135
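The diff below is a mechanical reformat produced by the black code formatter, as the commit message says; app.py does not change behaviorally. A minimal sketch of reproducing such a pass with black's Python API is shown below, assuming black is installed; the line-length value is only inferred from the width of the reformatted lines (black's default of 88 would have wrapped them sooner) and is not recorded in the commit.

```python
# Hedged sketch: re-run black over app.py programmatically.
# Assumption: line_length=119 is inferred from the reformatted output,
# not stated anywhere in this commit.
import black

with open("app.py", "r", encoding="utf-8") as f:
    source = f.read()

# format_str applies black's normalizations: line wrapping, double quotes,
# spacing around operators, and blank-line cleanup.
formatted = black.format_str(source, mode=black.Mode(line_length=119))

with open("app.py", "w", encoding="utf-8") as f:
    f.write(formatted)
```

The command-line equivalent would be `black --line-length 119 app.py`, which accounts for the kinds of changes visible in the hunks: long calls split across multiple lines, single quotes normalized to double quotes, and stray blank lines and trailing whitespace removed.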
app.py CHANGED
@@ -19,10 +19,8 @@ import warnings
19
 
20
  from gradio_demo.utils_drag import *
21
  from models_diffusers.controlnet_svd import ControlNetSVDModel
22
- from models_diffusers.unet_spatio_temporal_condition import \
23
- UNetSpatioTemporalConditionModel
24
- from pipelines.pipeline_stable_video_diffusion_interp_control import \
25
- StableVideoDiffusionInterpControlPipeline
26
 
27
  print("gr file", gr.__file__)
28
 
@@ -43,6 +41,7 @@ snapshot_download(
43
 
44
  def get_args():
45
  import argparse
 
46
  parser = argparse.ArgumentParser()
47
 
48
  parser.add_argument("--min_guidance_scale", type=float, default=1.0)
@@ -55,11 +54,12 @@ def get_args():
55
  parser.add_argument(
56
  "--dataset",
57
  type=str,
58
- default='videoswap',
59
  )
60
 
61
  parser.add_argument(
62
- "--model", type=str,
 
63
  default="checkpoints/framer_512x320",
64
  help="Path to model.",
65
  )
@@ -112,27 +112,34 @@ def interpolate_trajectory(points, n_points):
112
 
113
  def gen_gaussian_heatmap(imgSize=200):
114
  circle_img = np.zeros((imgSize, imgSize), np.float32)
115
- circle_mask = cv2.circle(circle_img, (imgSize//2, imgSize//2), imgSize//2, 1, -1)
116
 
117
  isotropicGrayscaleImage = np.zeros((imgSize, imgSize), np.float32)
118
 
119
  for i in range(imgSize):
120
  for j in range(imgSize):
121
- isotropicGrayscaleImage[i, j] = 1 / 2 / np.pi / (40 ** 2) * np.exp(
122
- -1 / 2 * ((i - imgSize / 2) ** 2 / (40 ** 2) + (j - imgSize / 2) ** 2 / (40 ** 2)))
123
 
124
  isotropicGrayscaleImage = isotropicGrayscaleImage * circle_mask
125
  isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)).astype(np.float32)
126
- isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)*255).astype(np.uint8)
127
 
128
  return isotropicGrayscaleImage
129
 
130
 
131
  def get_vis_image(
132
- target_size=(512 , 512), points=None, side=20,
133
- num_frames=14,
134
- # original_size=(512 , 512), args="", first_frame=None, is_mask = False, model_id=None,
135
- ):
136
 
137
  # images = []
138
  vis_images = []
@@ -140,13 +147,13 @@ def get_vis_image(
140
 
141
  trajectory_list = []
142
  radius_list = []
143
-
144
  for index, point in enumerate(points):
145
  trajectories = [[int(i[0]), int(i[1])] for i in point]
146
  trajectory_list.append(trajectories)
147
 
148
  radius = 20
149
- radius_list.append(radius)
150
 
151
  if len(trajectory_list) == 0:
152
  vis_images = [Image.fromarray(np.zeros(target_size, np.uint8)) for _ in range(num_frames)]
@@ -156,33 +163,39 @@ def get_vis_image(
156
  new_img = np.zeros(target_size, np.uint8)
157
  vis_img = new_img.copy()
158
  # ids_embedding = torch.zeros((target_size[0], target_size[1], 320))
159
-
160
  if idxx >= args.num_frames:
161
  break
162
 
163
  # for cc, (mask, trajectory, radius) in enumerate(zip(mask_list, trajectory_list, radius_list)):
164
  for cc, (trajectory, radius) in enumerate(zip(trajectory_list, radius_list)):
165
-
166
  center_coordinate = trajectory[idxx]
167
  trajectory_ = trajectory[:idxx]
168
  side = min(radius, 50)
169
-
170
- y1 = max(center_coordinate[1] - side,0)
171
  y2 = min(center_coordinate[1] + side, target_size[0] - 1)
172
  x1 = max(center_coordinate[0] - side, 0)
173
  x2 = min(center_coordinate[0] + side, target_size[1] - 1)
174
-
175
- if x2-x1>3 and y2-y1>3:
176
- need_map = cv2.resize(heatmap, (x2-x1, y2-y1))
177
  new_img[y1:y2, x1:x2] = need_map.copy()
178
-
179
  if cc >= 0:
180
- vis_img[y1:y2,x1:x2] = need_map.copy()
181
  if len(trajectory_) == 1:
182
  vis_img[trajectory_[0][1], trajectory_[0][0]] = 255
183
  else:
184
- for itt in range(len(trajectory_)-1):
185
- cv2.line(vis_img, (trajectory_[itt][0], trajectory_[itt][1]), (trajectory_[itt+1][0], trajectory_[itt+1][1]), (255, 255, 255), 3)
186
 
187
  img = new_img
188
 
@@ -193,7 +206,7 @@ def get_vis_image(
193
  elif len(img.shape) == 3 and img.shape[2] == 3: # Color image in BGR format
194
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
195
  vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
196
-
197
  # Convert the numpy array to a PIL image
198
  # pil_img = Image.fromarray(img)
199
  # images.append(pil_img)
@@ -214,7 +227,7 @@ def frames_to_video(frames_folder, output_video_path, fps=7):
214
  video.append(frame)
215
 
216
  video = torch.stack(video)
217
- video = rearrange(video, 'T C H W -> T H W C')
218
  torchvision.io.write_video(output_video_path, video, fps=fps)
219
 
220
 
@@ -222,11 +235,12 @@ def save_gifs_side_by_side(
222
  batch_output,
223
  validation_control_images,
224
  output_folder,
225
- target_size=(512 , 512),
226
  duration=200,
227
  point_tracks=None,
228
  ):
229
  flattened_batch_output = batch_output
 
230
  def create_gif(image_list, gif_path, duration=100):
231
  pil_images = [validate_and_convert_image(img, target_size=target_size) for img in image_list]
232
  pil_images = [img for img in pil_images if img is not None]
@@ -242,7 +256,7 @@ def save_gifs_side_by_side(
242
  tmp_frame_path = os.path.join(tmp_folder, f"{idx}.png")
243
  pil_image.save(tmp_frame_path)
244
  tmp_frame_list.append(tmp_frame_path)
245
-
246
  # also save as mp4
247
  output_video_path = gif_path.replace(".gif", ".mp4")
248
  frames_to_video(tmp_folder, output_video_path, fps=7)
@@ -285,25 +299,25 @@ def save_gifs_side_by_side(
285
  if output_path.endswith(".mp4"):
286
  video = [torchvision.transforms.functional.pil_to_tensor(frame) for frame in frames]
287
  video = torch.stack(video)
288
- video = rearrange(video, 'T C H W -> T H W C')
289
  torchvision.io.write_video(output_path, video, fps=7)
290
  print(f"Saved video to {output_path}")
291
  else:
292
  frames[0].save(output_path, save_all=True, append_images=frames[1:], loop=0, duration=duration)
293
-
294
  # Helper function to concatenate images horizontally
295
  def get_concat_h(im1, im2, gap=10):
296
  # # img first, heatmap second
297
  # im1, im2 = im2, im1
298
 
299
- dst = Image.new('RGB', (im1.width + im2.width + gap, max(im1.height, im2.height)), (255, 255, 255))
300
  dst.paste(im1, (0, 0))
301
  dst.paste(im2, (im1.width + gap, 0))
302
  return dst
303
 
304
  # Helper function to concatenate images vertically
305
  def get_concat_v(im1, im2):
306
- dst = Image.new('RGB', (max(im1.width, im2.width), im1.height + im2.height))
307
  dst.paste(im1, (0, 0))
308
  dst.paste(im2, (0, im1.height))
309
  return dst
@@ -324,7 +338,7 @@ def save_gifs_side_by_side(
324
 
325
 
326
  # Define functions
327
- def validate_and_convert_image(image, target_size=(512 , 512)):
328
  if image is None:
329
  print("Encountered a None image")
330
  return None
@@ -345,7 +359,7 @@ def validate_and_convert_image(image, target_size=(512 , 512)):
345
  else:
346
  print("Image is not a PIL Image or a PyTorch tensor")
347
  return None
348
-
349
  return image
350
 
351
 
@@ -371,19 +385,21 @@ class Drag:
371
 
372
  if is_xformers_available():
373
  import xformers
 
374
  xformers_version = version.parse(xformers.__version__)
375
  unet.enable_xformers_memory_efficient_attention()
376
  # controlnet.enable_xformers_memory_efficient_attention()
377
  else:
378
- raise ValueError(
379
- "xformers is not available. Make sure it is installed correctly")
380
 
381
  pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
382
  "checkpoints/stable-video-diffusion-img2vid-xt",
383
  unet=unet,
384
  controlnet=controlnet,
385
  low_cpu_mem_usage=False,
386
- torch_dtype=torch.float16, variant="fp16", local_files_only=True,
387
  )
388
  pipe.to(device)
389
 
@@ -397,18 +413,18 @@ class Drag:
397
  self.use_sift = use_sift
398
 
399
  @spaces.GPU
400
- def run(self, first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
401
  original_width, original_height = 512, 320 # TODO
402
 
403
  # load_image
404
- image = Image.open(first_frame_path).convert('RGB')
405
  width, height = image.size
406
  image = image.resize((self.width, self.height))
407
 
408
- image_end = Image.open(last_frame_path).convert('RGB')
409
  image_end = image_end.resize((self.width, self.height))
410
 
411
- input_all_points = tracking_points.constructor_args['value']
412
 
413
  sift_track_update = False
414
  anchor_points_flag = None
@@ -417,11 +433,10 @@ class Drag:
417
  sift_track_update = True
418
  controlnet_cond_scale = 0.5
419
 
420
- from models_diffusers.sift_match import \
421
- interpolate_trajectory as sift_interpolate_trajectory
422
  from models_diffusers.sift_match import sift_match
423
 
424
- output_file_sift = os.path.join(args.output_dir, "sift.png")
425
 
426
  # (f, topk, 2), f=2 (before interpolation)
427
  pred_tracks = sift_match(
@@ -446,9 +461,12 @@ class Drag:
446
  else:
447
 
448
  resized_all_points = [
449
- tuple([
450
- tuple([int(e1[0] * self.width / original_width), int(e1[1] * self.height / original_height)])
451
- for e1 in e])
452
  for e in input_all_points
453
  ]
454
 
@@ -460,12 +478,12 @@ class Drag:
460
  warnings.warn("running without point trajectory control")
461
  continue
462
 
463
- if len(splited_track) == 1: # stationary point
464
  displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
465
  splited_track = tuple([splited_track[0], displacement_point])
466
  # interpolate the track
467
  splited_track = interpolate_trajectory(splited_track, self.model_length)
468
- splited_track = splited_track[:self.model_length]
469
  resized_all_points[idx] = splited_track
470
 
471
  pred_tracks = torch.tensor(resized_all_points) # (num_points, num_frames, 2)
@@ -498,7 +516,7 @@ class Drag:
498
  num_frames=14,
499
  width=width,
500
  height=height,
501
- # decode_chunk_size=8,
502
  # generator=generator,
503
  motion_bucket_id=motion_bucket_id,
504
  fps=7,
@@ -511,12 +529,12 @@ class Drag:
511
  vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
512
  vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
513
  vis_images = [Image.fromarray(img) for img in vis_images]
514
-
515
  # video_frames = [img for sublist in video_frames for img in sublist]
516
  val_save_dir = os.path.join(args.output_dir, "vis_gif.gif")
517
  save_gifs_side_by_side(
518
- video_frames,
519
- vis_images[:self.model_length],
520
  val_save_dir,
521
  target_size=(self.width, self.height),
522
  duration=110,
@@ -545,7 +563,7 @@ def preprocess_image(image):
545
  image_pil = image_pil.resize((512, 320), Image.BILINEAR)
546
 
547
  first_frame_path = os.path.join(args.output_dir, f"first_frame_{str(uuid.uuid4())[:4]}.png")
548
-
549
  image_pil.save(first_frame_path)
550
 
551
  return first_frame_path, first_frame_path, gr.State([])
@@ -569,29 +587,42 @@ def preprocess_image_end(image_end):
569
 
570
 
571
  def add_drag(tracking_points):
572
- tracking_points.constructor_args['value'].append([])
573
  return tracking_points
574
 
575
 
576
  def delete_last_drag(tracking_points, first_frame_path, last_frame_path):
577
- tracking_points.constructor_args['value'].pop()
578
- transparent_background = Image.open(first_frame_path).convert('RGBA')
579
- transparent_background_end = Image.open(last_frame_path).convert('RGBA')
580
  w, h = transparent_background.size
581
  transparent_layer = np.zeros((h, w, 4))
582
 
583
- for track in tracking_points.constructor_args['value']:
584
  if len(track) > 1:
585
- for i in range(len(track)-1):
586
  start_point = track[i]
587
- end_point = track[i+1]
588
  vx = end_point[0] - start_point[0]
589
  vy = end_point[1] - start_point[1]
590
  arrow_length = np.sqrt(vx**2 + vy**2)
591
- if i == len(track)-2:
592
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
593
  else:
594
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
595
  else:
596
  cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
597
 
@@ -603,24 +634,37 @@ def delete_last_drag(tracking_points, first_frame_path, last_frame_path):
603
 
604
 
605
  def delete_last_step(tracking_points, first_frame_path, last_frame_path):
606
- tracking_points.constructor_args['value'][-1].pop()
607
- transparent_background = Image.open(first_frame_path).convert('RGBA')
608
- transparent_background_end = Image.open(last_frame_path).convert('RGBA')
609
  w, h = transparent_background.size
610
  transparent_layer = np.zeros((h, w, 4))
611
 
612
- for track in tracking_points.constructor_args['value']:
613
  if len(track) > 1:
614
- for i in range(len(track)-1):
615
  start_point = track[i]
616
- end_point = track[i+1]
617
  vx = end_point[0] - start_point[0]
618
  vy = end_point[1] - start_point[1]
619
  arrow_length = np.sqrt(vx**2 + vy**2)
620
- if i == len(track)-2:
621
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
622
  else:
623
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
624
  else:
625
  cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
626
 
@@ -631,34 +675,49 @@ def delete_last_step(tracking_points, first_frame_path, last_frame_path):
631
  return tracking_points, trajectory_map, trajectory_map_end
632
 
633
 
634
- def add_tracking_points(tracking_points, first_frame_path, last_frame_path, evt: gr.SelectData): # SelectData is a subclass of EventData
635
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
636
- tracking_points.constructor_args['value'][-1].append(evt.index)
637
 
638
- transparent_background = Image.open(first_frame_path).convert('RGBA')
639
- transparent_background_end = Image.open(last_frame_path).convert('RGBA')
640
 
641
  w, h = transparent_background.size
642
  transparent_layer = 0
643
- for idx, track in enumerate(tracking_points.constructor_args['value']):
644
  # mask = cv2.imread(
645
  # os.path.join(args.output_dir, f"mask_{idx+1}.jpg")
646
  # )
647
  mask = np.zeros((320, 512, 3))
648
- color = color_list[idx+1]
649
  transparent_layer = mask[:, :, 0].reshape(h, w, 1) * color.reshape(1, 1, -1) + transparent_layer
650
 
651
  if len(track) > 1:
652
- for i in range(len(track)-1):
653
  start_point = track[i]
654
- end_point = track[i+1]
655
  vx = end_point[0] - start_point[0]
656
  vy = end_point[1] - start_point[1]
657
  arrow_length = np.sqrt(vx**2 + vy**2)
658
- if i == len(track)-2:
659
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
660
  else:
661
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
662
  else:
663
  cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
664
 
@@ -678,22 +737,25 @@ if __name__ == "__main__":
678
 
679
  args = get_args()
680
  ensure_dirname(args.output_dir)
681
-
682
  color_list = []
683
  for i in range(20):
684
- color = np.concatenate([np.random.random(4)*255], axis=0)
685
  color_list.append(color)
686
 
687
  with gr.Blocks() as demo:
688
  gr.Markdown("""<h1 align="center">Framer: Interactive Frame Interpolation</h1><br>""")
689
-
690
- gr.Markdown("""Gradio Demo for <a href='https://arxiv.org/abs/2410.18978'><b>Framer: Interactive Frame Interpolation</b></a>.<br>
 
691
  Github Repo can be found at https://github.com/aim-uofa/Framer<br>
692
- The template is inspired by DragAnything.""")
693
-
 
694
  gr.Image(label="Framer: Interactive Frame Interpolation", value="assets/demos.gif", height=432, width=768)
695
-
696
- gr.Markdown("""## Usage: <br>
 
697
  1. Upload images<br>
698
  &ensp; 1.1 Upload the start image via the "Upload Start Image" button.<br>
699
  &ensp; 1.2. Upload the end image via the "Upload End Image" button.<br>
@@ -702,14 +764,15 @@ if __name__ == "__main__":
702
  &ensp; 2.2. You can click several points on either start or end image to forms a path.<br>
703
  &ensp; 2.3. Click "Delete last drag" to delete the whole lastest path.<br>
704
  &ensp; 2.4. Click "Delete last step" to delete the lastest clicked control point.<br>
705
- 3. Interpolate the images (according the path) with a click on "Run" button. <br>""")
706
-
 
707
  # device, args, height, width, model_length
708
  Framer = Drag("cuda", args, 320, 512, 14)
709
  first_frame_path = gr.State()
710
  last_frame_path = gr.State()
711
  tracking_points = gr.State([])
712
-
713
  with gr.Row():
714
  with gr.Column(scale=1):
715
  image_upload_button = gr.UploadButton(label="Upload Start Image", file_types=["image"])
@@ -720,7 +783,7 @@ if __name__ == "__main__":
720
  run_button = gr.Button(value="Run")
721
  delete_last_drag_button = gr.Button(value="Delete last drag")
722
  delete_last_step_button = gr.Button(value="Delete last step")
723
-
724
  with gr.Column(scale=7):
725
  with gr.Row():
726
  with gr.Column(scale=6):
@@ -731,7 +794,7 @@ if __name__ == "__main__":
731
  width=512,
732
  sources=[],
733
  )
734
-
735
  with gr.Column(scale=6):
736
  input_image_end = gr.Image(
737
  label="end frame",
@@ -740,36 +803,36 @@ if __name__ == "__main__":
740
  width=512,
741
  sources=[],
742
  )
743
-
744
  with gr.Row():
745
  with gr.Column(scale=1):
746
-
747
  controlnet_cond_scale = gr.Slider(
748
- label='Control Scale',
749
- minimum=0.0,
750
- maximum=10,
751
- step=0.1,
752
  value=1.0,
753
  )
754
-
755
  motion_bucket_id = gr.Slider(
756
- label='Motion Bucket',
757
- minimum=1,
758
- maximum=180,
759
- step=1,
760
  value=100,
761
  )
762
-
763
  with gr.Column(scale=5):
764
  output_video = gr.Image(
765
  label="Output Video",
766
  height=320,
767
  width=1152,
768
  )
769
-
770
-
771
  with gr.Row():
772
- gr.Markdown("""
 
773
  ## Citation
774
  ```bibtex
775
  @article{wang2024framer,
@@ -779,24 +842,59 @@ if __name__ == "__main__":
779
  year={2024}
780
  }
781
  ```
782
- """)
783
-
784
- image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path, tracking_points])
785
-
786
- image_end_upload_button.upload(preprocess_image_end, image_end_upload_button, [input_image_end, last_frame_path, tracking_points])
787
-
788
- add_drag_button.click(add_drag, tracking_points, [tracking_points, ])
789
-
790
- delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, last_frame_path], [tracking_points, input_image, input_image_end])
791
-
792
- delete_last_step_button.click(delete_last_step, [tracking_points, first_frame_path, last_frame_path], [tracking_points, input_image, input_image_end])
793
-
794
- reset_button.click(reset_states, [first_frame_path, last_frame_path, tracking_points], [first_frame_path, last_frame_path, tracking_points])
795
-
796
- input_image.select(add_tracking_points, [tracking_points, first_frame_path, last_frame_path], [tracking_points, input_image, input_image_end])
797
-
798
- input_image_end.select(add_tracking_points, [tracking_points, first_frame_path, last_frame_path], [tracking_points, input_image, input_image_end])
799
-
800
- run_button.click(Framer.run, [first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id], output_video)
801
-
802
  demo.launch()
 
19
 
20
  from gradio_demo.utils_drag import *
21
  from models_diffusers.controlnet_svd import ControlNetSVDModel
22
+ from models_diffusers.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
23
+ from pipelines.pipeline_stable_video_diffusion_interp_control import StableVideoDiffusionInterpControlPipeline
24
 
25
  print("gr file", gr.__file__)
26
 
 
41
 
42
  def get_args():
43
  import argparse
44
+
45
  parser = argparse.ArgumentParser()
46
 
47
  parser.add_argument("--min_guidance_scale", type=float, default=1.0)
 
54
  parser.add_argument(
55
  "--dataset",
56
  type=str,
57
+ default="videoswap",
58
  )
59
 
60
  parser.add_argument(
61
+ "--model",
62
+ type=str,
63
  default="checkpoints/framer_512x320",
64
  help="Path to model.",
65
  )
 
112
 
113
  def gen_gaussian_heatmap(imgSize=200):
114
  circle_img = np.zeros((imgSize, imgSize), np.float32)
115
+ circle_mask = cv2.circle(circle_img, (imgSize // 2, imgSize // 2), imgSize // 2, 1, -1)
116
 
117
  isotropicGrayscaleImage = np.zeros((imgSize, imgSize), np.float32)
118
 
119
  for i in range(imgSize):
120
  for j in range(imgSize):
121
+ isotropicGrayscaleImage[i, j] = (
122
+ 1
123
+ / 2
124
+ / np.pi
125
+ / (40**2)
126
+ * np.exp(-1 / 2 * ((i - imgSize / 2) ** 2 / (40**2) + (j - imgSize / 2) ** 2 / (40**2)))
127
+ )
128
 
129
  isotropicGrayscaleImage = isotropicGrayscaleImage * circle_mask
130
  isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)).astype(np.float32)
131
+ isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage) * 255).astype(np.uint8)
132
 
133
  return isotropicGrayscaleImage
134
 
135
 
136
  def get_vis_image(
137
+ target_size=(512, 512),
138
+ points=None,
139
+ side=20,
140
+ num_frames=14,
141
+ # original_size=(512 , 512), args="", first_frame=None, is_mask = False, model_id=None,
142
+ ):
143
 
144
  # images = []
145
  vis_images = []
 
147
 
148
  trajectory_list = []
149
  radius_list = []
150
+
151
  for index, point in enumerate(points):
152
  trajectories = [[int(i[0]), int(i[1])] for i in point]
153
  trajectory_list.append(trajectories)
154
 
155
  radius = 20
156
+ radius_list.append(radius)
157
 
158
  if len(trajectory_list) == 0:
159
  vis_images = [Image.fromarray(np.zeros(target_size, np.uint8)) for _ in range(num_frames)]
 
163
  new_img = np.zeros(target_size, np.uint8)
164
  vis_img = new_img.copy()
165
  # ids_embedding = torch.zeros((target_size[0], target_size[1], 320))
166
+
167
  if idxx >= args.num_frames:
168
  break
169
 
170
  # for cc, (mask, trajectory, radius) in enumerate(zip(mask_list, trajectory_list, radius_list)):
171
  for cc, (trajectory, radius) in enumerate(zip(trajectory_list, radius_list)):
172
+
173
  center_coordinate = trajectory[idxx]
174
  trajectory_ = trajectory[:idxx]
175
  side = min(radius, 50)
176
+
177
+ y1 = max(center_coordinate[1] - side, 0)
178
  y2 = min(center_coordinate[1] + side, target_size[0] - 1)
179
  x1 = max(center_coordinate[0] - side, 0)
180
  x2 = min(center_coordinate[0] + side, target_size[1] - 1)
181
+
182
+ if x2 - x1 > 3 and y2 - y1 > 3:
183
+ need_map = cv2.resize(heatmap, (x2 - x1, y2 - y1))
184
  new_img[y1:y2, x1:x2] = need_map.copy()
185
+
186
  if cc >= 0:
187
+ vis_img[y1:y2, x1:x2] = need_map.copy()
188
  if len(trajectory_) == 1:
189
  vis_img[trajectory_[0][1], trajectory_[0][0]] = 255
190
  else:
191
+ for itt in range(len(trajectory_) - 1):
192
+ cv2.line(
193
+ vis_img,
194
+ (trajectory_[itt][0], trajectory_[itt][1]),
195
+ (trajectory_[itt + 1][0], trajectory_[itt + 1][1]),
196
+ (255, 255, 255),
197
+ 3,
198
+ )
199
 
200
  img = new_img
201
 
 
206
  elif len(img.shape) == 3 and img.shape[2] == 3: # Color image in BGR format
207
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
208
  vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
209
+
210
  # Convert the numpy array to a PIL image
211
  # pil_img = Image.fromarray(img)
212
  # images.append(pil_img)
 
227
  video.append(frame)
228
 
229
  video = torch.stack(video)
230
+ video = rearrange(video, "T C H W -> T H W C")
231
  torchvision.io.write_video(output_video_path, video, fps=fps)
232
 
233
 
 
235
  batch_output,
236
  validation_control_images,
237
  output_folder,
238
+ target_size=(512, 512),
239
  duration=200,
240
  point_tracks=None,
241
  ):
242
  flattened_batch_output = batch_output
243
+
244
  def create_gif(image_list, gif_path, duration=100):
245
  pil_images = [validate_and_convert_image(img, target_size=target_size) for img in image_list]
246
  pil_images = [img for img in pil_images if img is not None]
 
256
  tmp_frame_path = os.path.join(tmp_folder, f"{idx}.png")
257
  pil_image.save(tmp_frame_path)
258
  tmp_frame_list.append(tmp_frame_path)
259
+
260
  # also save as mp4
261
  output_video_path = gif_path.replace(".gif", ".mp4")
262
  frames_to_video(tmp_folder, output_video_path, fps=7)
 
299
  if output_path.endswith(".mp4"):
300
  video = [torchvision.transforms.functional.pil_to_tensor(frame) for frame in frames]
301
  video = torch.stack(video)
302
+ video = rearrange(video, "T C H W -> T H W C")
303
  torchvision.io.write_video(output_path, video, fps=7)
304
  print(f"Saved video to {output_path}")
305
  else:
306
  frames[0].save(output_path, save_all=True, append_images=frames[1:], loop=0, duration=duration)
307
+
308
  # Helper function to concatenate images horizontally
309
  def get_concat_h(im1, im2, gap=10):
310
  # # img first, heatmap second
311
  # im1, im2 = im2, im1
312
 
313
+ dst = Image.new("RGB", (im1.width + im2.width + gap, max(im1.height, im2.height)), (255, 255, 255))
314
  dst.paste(im1, (0, 0))
315
  dst.paste(im2, (im1.width + gap, 0))
316
  return dst
317
 
318
  # Helper function to concatenate images vertically
319
  def get_concat_v(im1, im2):
320
+ dst = Image.new("RGB", (max(im1.width, im2.width), im1.height + im2.height))
321
  dst.paste(im1, (0, 0))
322
  dst.paste(im2, (0, im1.height))
323
  return dst
 
338
 
339
 
340
  # Define functions
341
+ def validate_and_convert_image(image, target_size=(512, 512)):
342
  if image is None:
343
  print("Encountered a None image")
344
  return None
 
359
  else:
360
  print("Image is not a PIL Image or a PyTorch tensor")
361
  return None
362
+
363
  return image
364
 
365
 
 
385
 
386
  if is_xformers_available():
387
  import xformers
388
+
389
  xformers_version = version.parse(xformers.__version__)
390
  unet.enable_xformers_memory_efficient_attention()
391
  # controlnet.enable_xformers_memory_efficient_attention()
392
  else:
393
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
 
394
 
395
  pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
396
  "checkpoints/stable-video-diffusion-img2vid-xt",
397
  unet=unet,
398
  controlnet=controlnet,
399
  low_cpu_mem_usage=False,
400
+ torch_dtype=torch.float16,
401
+ variant="fp16",
402
+ local_files_only=True,
403
  )
404
  pipe.to(device)
405
 
 
413
  self.use_sift = use_sift
414
 
415
  @spaces.GPU
416
+ def run(self, first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
417
  original_width, original_height = 512, 320 # TODO
418
 
419
  # load_image
420
+ image = Image.open(first_frame_path).convert("RGB")
421
  width, height = image.size
422
  image = image.resize((self.width, self.height))
423
 
424
+ image_end = Image.open(last_frame_path).convert("RGB")
425
  image_end = image_end.resize((self.width, self.height))
426
 
427
+ input_all_points = tracking_points.constructor_args["value"]
428
 
429
  sift_track_update = False
430
  anchor_points_flag = None
 
433
  sift_track_update = True
434
  controlnet_cond_scale = 0.5
435
 
436
+ from models_diffusers.sift_match import interpolate_trajectory as sift_interpolate_trajectory
437
  from models_diffusers.sift_match import sift_match
438
 
439
+ output_file_sift = os.path.join(args.output_dir, "sift.png")
440
 
441
  # (f, topk, 2), f=2 (before interpolation)
442
  pred_tracks = sift_match(
 
461
  else:
462
 
463
  resized_all_points = [
464
+ tuple(
465
+ [
466
+ tuple([int(e1[0] * self.width / original_width), int(e1[1] * self.height / original_height)])
467
+ for e1 in e
468
+ ]
469
+ )
470
  for e in input_all_points
471
  ]
472
 
 
478
  warnings.warn("running without point trajectory control")
479
  continue
480
 
481
+ if len(splited_track) == 1: # stationary point
482
  displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
483
  splited_track = tuple([splited_track[0], displacement_point])
484
  # interpolate the track
485
  splited_track = interpolate_trajectory(splited_track, self.model_length)
486
+ splited_track = splited_track[: self.model_length]
487
  resized_all_points[idx] = splited_track
488
 
489
  pred_tracks = torch.tensor(resized_all_points) # (num_points, num_frames, 2)
 
516
  num_frames=14,
517
  width=width,
518
  height=height,
519
+ # decode_chunk_size=8,
520
  # generator=generator,
521
  motion_bucket_id=motion_bucket_id,
522
  fps=7,
 
529
  vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
530
  vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
531
  vis_images = [Image.fromarray(img) for img in vis_images]
532
+
533
  # video_frames = [img for sublist in video_frames for img in sublist]
534
  val_save_dir = os.path.join(args.output_dir, "vis_gif.gif")
535
  save_gifs_side_by_side(
536
+ video_frames,
537
+ vis_images[: self.model_length],
538
  val_save_dir,
539
  target_size=(self.width, self.height),
540
  duration=110,
 
563
  image_pil = image_pil.resize((512, 320), Image.BILINEAR)
564
 
565
  first_frame_path = os.path.join(args.output_dir, f"first_frame_{str(uuid.uuid4())[:4]}.png")
566
+
567
  image_pil.save(first_frame_path)
568
 
569
  return first_frame_path, first_frame_path, gr.State([])
 
587
 
588
 
589
  def add_drag(tracking_points):
590
+ tracking_points.constructor_args["value"].append([])
591
  return tracking_points
592
 
593
 
594
  def delete_last_drag(tracking_points, first_frame_path, last_frame_path):
595
+ tracking_points.constructor_args["value"].pop()
596
+ transparent_background = Image.open(first_frame_path).convert("RGBA")
597
+ transparent_background_end = Image.open(last_frame_path).convert("RGBA")
598
  w, h = transparent_background.size
599
  transparent_layer = np.zeros((h, w, 4))
600
 
601
+ for track in tracking_points.constructor_args["value"]:
602
  if len(track) > 1:
603
+ for i in range(len(track) - 1):
604
  start_point = track[i]
605
+ end_point = track[i + 1]
606
  vx = end_point[0] - start_point[0]
607
  vy = end_point[1] - start_point[1]
608
  arrow_length = np.sqrt(vx**2 + vy**2)
609
+ if i == len(track) - 2:
610
+ cv2.arrowedLine(
611
+ transparent_layer,
612
+ tuple(start_point),
613
+ tuple(end_point),
614
+ (255, 0, 0, 255),
615
+ 2,
616
+ tipLength=8 / arrow_length,
617
+ )
618
  else:
619
+ cv2.line(
620
+ transparent_layer,
621
+ tuple(start_point),
622
+ tuple(end_point),
623
+ (255, 0, 0, 255),
624
+ 2,
625
+ )
626
  else:
627
  cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
628
 
 
634
 
635
 
636
  def delete_last_step(tracking_points, first_frame_path, last_frame_path):
637
+ tracking_points.constructor_args["value"][-1].pop()
638
+ transparent_background = Image.open(first_frame_path).convert("RGBA")
639
+ transparent_background_end = Image.open(last_frame_path).convert("RGBA")
640
  w, h = transparent_background.size
641
  transparent_layer = np.zeros((h, w, 4))
642
 
643
+ for track in tracking_points.constructor_args["value"]:
644
  if len(track) > 1:
645
+ for i in range(len(track) - 1):
646
  start_point = track[i]
647
+ end_point = track[i + 1]
648
  vx = end_point[0] - start_point[0]
649
  vy = end_point[1] - start_point[1]
650
  arrow_length = np.sqrt(vx**2 + vy**2)
651
+ if i == len(track) - 2:
652
+ cv2.arrowedLine(
653
+ transparent_layer,
654
+ tuple(start_point),
655
+ tuple(end_point),
656
+ (255, 0, 0, 255),
657
+ 2,
658
+ tipLength=8 / arrow_length,
659
+ )
660
  else:
661
+ cv2.line(
662
+ transparent_layer,
663
+ tuple(start_point),
664
+ tuple(end_point),
665
+ (255, 0, 0, 255),
666
+ 2,
667
+ )
668
  else:
669
  cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
670
 
 
675
  return tracking_points, trajectory_map, trajectory_map_end
676
 
677
 
678
+ def add_tracking_points(
679
+ tracking_points, first_frame_path, last_frame_path, evt: gr.SelectData
680
+ ): # SelectData is a subclass of EventData
681
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
682
+ tracking_points.constructor_args["value"][-1].append(evt.index)
683
 
684
+ transparent_background = Image.open(first_frame_path).convert("RGBA")
685
+ transparent_background_end = Image.open(last_frame_path).convert("RGBA")
686
 
687
  w, h = transparent_background.size
688
  transparent_layer = 0
689
+ for idx, track in enumerate(tracking_points.constructor_args["value"]):
690
  # mask = cv2.imread(
691
  # os.path.join(args.output_dir, f"mask_{idx+1}.jpg")
692
  # )
693
  mask = np.zeros((320, 512, 3))
694
+ color = color_list[idx + 1]
695
  transparent_layer = mask[:, :, 0].reshape(h, w, 1) * color.reshape(1, 1, -1) + transparent_layer
696
 
697
  if len(track) > 1:
698
+ for i in range(len(track) - 1):
699
  start_point = track[i]
700
+ end_point = track[i + 1]
701
  vx = end_point[0] - start_point[0]
702
  vy = end_point[1] - start_point[1]
703
  arrow_length = np.sqrt(vx**2 + vy**2)
704
+ if i == len(track) - 2:
705
+ cv2.arrowedLine(
706
+ transparent_layer,
707
+ tuple(start_point),
708
+ tuple(end_point),
709
+ (255, 0, 0, 255),
710
+ 2,
711
+ tipLength=8 / arrow_length,
712
+ )
713
  else:
714
+ cv2.line(
715
+ transparent_layer,
716
+ tuple(start_point),
717
+ tuple(end_point),
718
+ (255, 0, 0, 255),
719
+ 2,
720
+ )
721
  else:
722
  cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
723
 
 
737
 
738
  args = get_args()
739
  ensure_dirname(args.output_dir)
740
+
741
  color_list = []
742
  for i in range(20):
743
+ color = np.concatenate([np.random.random(4) * 255], axis=0)
744
  color_list.append(color)
745
 
746
  with gr.Blocks() as demo:
747
  gr.Markdown("""<h1 align="center">Framer: Interactive Frame Interpolation</h1><br>""")
748
+
749
+ gr.Markdown(
750
+ """Gradio Demo for <a href='https://arxiv.org/abs/2410.18978'><b>Framer: Interactive Frame Interpolation</b></a>.<br>
751
  Github Repo can be found at https://github.com/aim-uofa/Framer<br>
752
+ The template is inspired by DragAnything."""
753
+ )
754
+
755
  gr.Image(label="Framer: Interactive Frame Interpolation", value="assets/demos.gif", height=432, width=768)
756
+
757
+ gr.Markdown(
758
+ """## Usage: <br>
759
  1. Upload images<br>
760
  &ensp; 1.1 Upload the start image via the "Upload Start Image" button.<br>
761
  &ensp; 1.2. Upload the end image via the "Upload End Image" button.<br>
 
764
  &ensp; 2.2. You can click several points on either start or end image to forms a path.<br>
765
  &ensp; 2.3. Click "Delete last drag" to delete the whole lastest path.<br>
766
  &ensp; 2.4. Click "Delete last step" to delete the lastest clicked control point.<br>
767
+ 3. Interpolate the images (according the path) with a click on "Run" button. <br>"""
768
+ )
769
+
770
  # device, args, height, width, model_length
771
  Framer = Drag("cuda", args, 320, 512, 14)
772
  first_frame_path = gr.State()
773
  last_frame_path = gr.State()
774
  tracking_points = gr.State([])
775
+
776
  with gr.Row():
777
  with gr.Column(scale=1):
778
  image_upload_button = gr.UploadButton(label="Upload Start Image", file_types=["image"])
 
783
  run_button = gr.Button(value="Run")
784
  delete_last_drag_button = gr.Button(value="Delete last drag")
785
  delete_last_step_button = gr.Button(value="Delete last step")
786
+
787
  with gr.Column(scale=7):
788
  with gr.Row():
789
  with gr.Column(scale=6):
 
794
  width=512,
795
  sources=[],
796
  )
797
+
798
  with gr.Column(scale=6):
799
  input_image_end = gr.Image(
800
  label="end frame",
 
803
  width=512,
804
  sources=[],
805
  )
806
+
807
  with gr.Row():
808
  with gr.Column(scale=1):
809
+
810
  controlnet_cond_scale = gr.Slider(
811
+ label="Control Scale",
812
+ minimum=0.0,
813
+ maximum=10,
814
+ step=0.1,
815
  value=1.0,
816
  )
817
+
818
  motion_bucket_id = gr.Slider(
819
+ label="Motion Bucket",
820
+ minimum=1,
821
+ maximum=180,
822
+ step=1,
823
  value=100,
824
  )
825
+
826
  with gr.Column(scale=5):
827
  output_video = gr.Image(
828
  label="Output Video",
829
  height=320,
830
  width=1152,
831
  )
832
+
 
833
  with gr.Row():
834
+ gr.Markdown(
835
+ """
836
  ## Citation
837
  ```bibtex
838
  @article{wang2024framer,
 
842
  year={2024}
843
  }
844
  ```
845
+ """
846
+ )
847
+
848
+ image_upload_button.upload(
849
+ preprocess_image, image_upload_button, [input_image, first_frame_path, tracking_points]
850
+ )
851
+
852
+ image_end_upload_button.upload(
853
+ preprocess_image_end, image_end_upload_button, [input_image_end, last_frame_path, tracking_points]
854
+ )
855
+
856
+ add_drag_button.click(
857
+ add_drag,
858
+ tracking_points,
859
+ [
860
+ tracking_points,
861
+ ],
862
+ )
863
+
864
+ delete_last_drag_button.click(
865
+ delete_last_drag,
866
+ [tracking_points, first_frame_path, last_frame_path],
867
+ [tracking_points, input_image, input_image_end],
868
+ )
869
+
870
+ delete_last_step_button.click(
871
+ delete_last_step,
872
+ [tracking_points, first_frame_path, last_frame_path],
873
+ [tracking_points, input_image, input_image_end],
874
+ )
875
+
876
+ reset_button.click(
877
+ reset_states,
878
+ [first_frame_path, last_frame_path, tracking_points],
879
+ [first_frame_path, last_frame_path, tracking_points],
880
+ )
881
+
882
+ input_image.select(
883
+ add_tracking_points,
884
+ [tracking_points, first_frame_path, last_frame_path],
885
+ [tracking_points, input_image, input_image_end],
886
+ )
887
+
888
+ input_image_end.select(
889
+ add_tracking_points,
890
+ [tracking_points, first_frame_path, last_frame_path],
891
+ [tracking_points, input_image, input_image_end],
892
+ )
893
+
894
+ run_button.click(
895
+ Framer.run,
896
+ [first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id],
897
+ output_video,
898
+ )
899
+
900
  demo.launch()