Files changed (3)
  1. README.md +1 -1
  2. app.py +392 -362
  3. requirements.txt +278 -13
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: πŸƒ
 colorFrom: gray
 colorTo: yellow
 sdk: gradio
-sdk_version: 4.41.0
+sdk_version: 5.5.0
 python_version: 3.8.9
 app_file: app.py
 pinned: false
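The `sdk_version` bump from 4.41.0 to 5.5.0 moves the Space to Gradio 5, and that migration drives much of the app.py rewrite below: event handlers no longer mutate `tracking_points.constructor_args['value']`, but instead receive the `gr.State` value as a plain Python list and return the updated list. A minimal sketch of the new pattern, modeled on the `add_drag` handler in this diff (the button label is illustrative, not taken from the diff):

```python
import gradio as gr

def add_drag(tracking_points):
    # Gradio 5 style: the state value arrives as a plain list and the updated
    # list is returned, instead of reaching into
    # tracking_points.constructor_args['value'] as the Gradio 4 code did.
    if not tracking_points or tracking_points[-1]:
        tracking_points.append([])
    return tracking_points

with gr.Blocks() as demo:
    tracking_points = gr.State([])  # one sub-list of (x, y) clicks per drag trajectory
    add_drag_button = gr.Button("Add New Drag Trajectory")  # label assumed, not shown in this diff
    add_drag_button.click(fn=add_drag, inputs=tracking_points, outputs=tracking_points)
```

The same value-in/value-out convention is what lets `reset_states`, `delete_last_drag`, and `add_tracking_points` in the new app.py operate on ordinary lists.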
app.py CHANGED
@@ -1,94 +1,76 @@
1
- import spaces
2
  import datetime
 
 
3
  import uuid
4
- from PIL import Image
5
- import numpy as np
6
- import cv2
7
- from scipy.interpolate import interp1d, PchipInterpolator
8
- from packaging import version
9
 
 
 
 
 
10
  import torch
11
  import torchvision
12
- import gradio as gr
13
- # from moviepy.editor import *
14
- from diffusers.utils.import_utils import is_xformers_available
15
- from diffusers.utils import load_image, export_to_video, export_to_gif
16
 
17
- import os
18
- import sys
19
  sys.path.insert(0, os.getcwd())
 
 
20
  from models_diffusers.controlnet_svd import ControlNetSVDModel
21
  from models_diffusers.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
22
  from pipelines.pipeline_stable_video_diffusion_interp_control import StableVideoDiffusionInterpControlPipeline
23
- from gradio_demo.utils_drag import *
24
 
25
- import warnings
26
  print("gr file", gr.__file__)
27
 
28
- from huggingface_hub import hf_hub_download, snapshot_download
29
 
30
  os.makedirs("checkpoints", exist_ok=True)
31
 
32
  snapshot_download(
33
  "wwen1997/framer_512x320",
34
  local_dir="checkpoints/framer_512x320",
35
- token=os.environ["TOKEN"],
36
  )
37
 
38
  snapshot_download(
39
  "stabilityai/stable-video-diffusion-img2vid-xt",
40
  local_dir="checkpoints/stable-video-diffusion-img2vid-xt",
41
- token=os.environ["TOKEN"],
42
  )
43
 
44
 
45
- def get_args():
46
- import argparse
47
- parser = argparse.ArgumentParser()
48
-
49
- parser.add_argument("--min_guidance_scale", type=float, default=1.0)
50
- parser.add_argument("--max_guidance_scale", type=float, default=3.0)
51
- parser.add_argument("--middle_max_guidance", type=int, default=0, choices=[0, 1])
52
- parser.add_argument("--with_control", type=int, default=1, choices=[0, 1])
53
-
54
- parser.add_argument("--controlnet_cond_scale", type=float, default=1.0)
55
-
56
- parser.add_argument(
57
- "--dataset",
58
- type=str,
59
- default='videoswap',
60
- )
61
-
62
- parser.add_argument(
63
- "--model", type=str,
64
- default="checkpoints/framer_512x320",
65
- help="Path to model.",
66
- )
67
-
68
- parser.add_argument("--output_dir", type=str, default="gradio_demo/outputs", help="Path to the output video.")
69
-
70
- parser.add_argument("--seed", type=int, default=42, help="random seed.")
71
 
72
- parser.add_argument("--noise_aug", type=float, default=0.02)
 
 
 
 
73
 
74
- parser.add_argument("--num_frames", type=int, default=14)
75
- parser.add_argument("--frame_interval", type=int, default=2)
76
 
77
- parser.add_argument("--width", type=int, default=512)
78
- parser.add_argument("--height", type=int, default=320)
79
-
80
- parser.add_argument(
81
- "--num_workers",
82
- type=int,
83
- default=0,
84
- help=(
85
- "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
86
- ),
87
- )
88
-
89
- args = parser.parse_args()
90
 
91
- return args
92
 
93
 
94
  def interpolate_trajectory(points, n_points):
@@ -113,27 +95,34 @@ def interpolate_trajectory(points, n_points):
113
 
114
  def gen_gaussian_heatmap(imgSize=200):
115
  circle_img = np.zeros((imgSize, imgSize), np.float32)
116
- circle_mask = cv2.circle(circle_img, (imgSize//2, imgSize//2), imgSize//2, 1, -1)
117
 
118
  isotropicGrayscaleImage = np.zeros((imgSize, imgSize), np.float32)
119
 
120
  for i in range(imgSize):
121
  for j in range(imgSize):
122
- isotropicGrayscaleImage[i, j] = 1 / 2 / np.pi / (40 ** 2) * np.exp(
123
- -1 / 2 * ((i - imgSize / 2) ** 2 / (40 ** 2) + (j - imgSize / 2) ** 2 / (40 ** 2)))
 
 
 
 
 
124
 
125
  isotropicGrayscaleImage = isotropicGrayscaleImage * circle_mask
126
  isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)).astype(np.float32)
127
- isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)*255).astype(np.uint8)
128
 
129
  return isotropicGrayscaleImage
130
 
131
 
132
  def get_vis_image(
133
- target_size=(512 , 512), points=None, side=20,
134
- num_frames=14,
135
- # original_size=(512 , 512), args="", first_frame=None, is_mask = False, model_id=None,
136
- ):
 
 
137
 
138
  # images = []
139
  vis_images = []
@@ -141,13 +130,13 @@ def get_vis_image(
141
 
142
  trajectory_list = []
143
  radius_list = []
144
-
145
  for index, point in enumerate(points):
146
  trajectories = [[int(i[0]), int(i[1])] for i in point]
147
  trajectory_list.append(trajectories)
148
 
149
  radius = 20
150
- radius_list.append(radius)
151
 
152
  if len(trajectory_list) == 0:
153
  vis_images = [Image.fromarray(np.zeros(target_size, np.uint8)) for _ in range(num_frames)]
@@ -157,33 +146,39 @@ def get_vis_image(
157
  new_img = np.zeros(target_size, np.uint8)
158
  vis_img = new_img.copy()
159
  # ids_embedding = torch.zeros((target_size[0], target_size[1], 320))
160
-
161
- if idxx >= args.num_frames:
162
  break
163
 
164
  # for cc, (mask, trajectory, radius) in enumerate(zip(mask_list, trajectory_list, radius_list)):
165
  for cc, (trajectory, radius) in enumerate(zip(trajectory_list, radius_list)):
166
-
167
  center_coordinate = trajectory[idxx]
168
  trajectory_ = trajectory[:idxx]
169
  side = min(radius, 50)
170
-
171
- y1 = max(center_coordinate[1] - side,0)
172
  y2 = min(center_coordinate[1] + side, target_size[0] - 1)
173
  x1 = max(center_coordinate[0] - side, 0)
174
  x2 = min(center_coordinate[0] + side, target_size[1] - 1)
175
-
176
- if x2-x1>3 and y2-y1>3:
177
- need_map = cv2.resize(heatmap, (x2-x1, y2-y1))
178
  new_img[y1:y2, x1:x2] = need_map.copy()
179
-
180
  if cc >= 0:
181
- vis_img[y1:y2,x1:x2] = need_map.copy()
182
  if len(trajectory_) == 1:
183
  vis_img[trajectory_[0][1], trajectory_[0][0]] = 255
184
  else:
185
- for itt in range(len(trajectory_)-1):
186
- cv2.line(vis_img, (trajectory_[itt][0], trajectory_[itt][1]), (trajectory_[itt+1][0], trajectory_[itt+1][1]), (255, 255, 255), 3)
 
 
 
 
 
 
187
 
188
  img = new_img
189
 
@@ -194,7 +189,7 @@ def get_vis_image(
194
  elif len(img.shape) == 3 and img.shape[2] == 3: # Color image in BGR format
195
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
196
  vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
197
-
198
  # Convert the numpy array to a PIL image
199
  # pil_img = Image.fromarray(img)
200
  # images.append(pil_img)
@@ -215,7 +210,7 @@ def frames_to_video(frames_folder, output_video_path, fps=7):
215
  video.append(frame)
216
 
217
  video = torch.stack(video)
218
- video = rearrange(video, 'T C H W -> T H W C')
219
  torchvision.io.write_video(output_video_path, video, fps=fps)
220
 
221
 
@@ -223,11 +218,12 @@ def save_gifs_side_by_side(
223
  batch_output,
224
  validation_control_images,
225
  output_folder,
226
- target_size=(512 , 512),
227
  duration=200,
228
  point_tracks=None,
229
  ):
230
  flattened_batch_output = batch_output
 
231
  def create_gif(image_list, gif_path, duration=100):
232
  pil_images = [validate_and_convert_image(img, target_size=target_size) for img in image_list]
233
  pil_images = [img for img in pil_images if img is not None]
@@ -243,7 +239,7 @@ def save_gifs_side_by_side(
243
  tmp_frame_path = os.path.join(tmp_folder, f"{idx}.png")
244
  pil_image.save(tmp_frame_path)
245
  tmp_frame_list.append(tmp_frame_path)
246
-
247
  # also save as mp4
248
  output_video_path = gif_path.replace(".gif", ".mp4")
249
  frames_to_video(tmp_folder, output_video_path, fps=7)
@@ -286,25 +282,25 @@ def save_gifs_side_by_side(
286
  if output_path.endswith(".mp4"):
287
  video = [torchvision.transforms.functional.pil_to_tensor(frame) for frame in frames]
288
  video = torch.stack(video)
289
- video = rearrange(video, 'T C H W -> T H W C')
290
  torchvision.io.write_video(output_path, video, fps=7)
291
  print(f"Saved video to {output_path}")
292
  else:
293
  frames[0].save(output_path, save_all=True, append_images=frames[1:], loop=0, duration=duration)
294
-
295
  # Helper function to concatenate images horizontally
296
  def get_concat_h(im1, im2, gap=10):
297
  # # img first, heatmap second
298
  # im1, im2 = im2, im1
299
 
300
- dst = Image.new('RGB', (im1.width + im2.width + gap, max(im1.height, im2.height)), (255, 255, 255))
301
  dst.paste(im1, (0, 0))
302
  dst.paste(im2, (im1.width + gap, 0))
303
  return dst
304
 
305
  # Helper function to concatenate images vertically
306
  def get_concat_v(im1, im2):
307
- dst = Image.new('RGB', (max(im1.width, im2.width), im1.height + im2.height))
308
  dst.paste(im1, (0, 0))
309
  dst.paste(im2, (0, im1.height))
310
  return dst
@@ -325,7 +321,7 @@ def save_gifs_side_by_side(
325
 
326
 
327
  # Define functions
328
- def validate_and_convert_image(image, target_size=(512 , 512)):
329
  if image is None:
330
  print("Encountered a None image")
331
  return None
@@ -346,192 +342,12 @@ def validate_and_convert_image(image, target_size=(512 , 512)):
346
  else:
347
  print("Image is not a PIL Image or a PyTorch tensor")
348
  return None
349
-
350
- return image
351
-
352
-
353
- class Drag:
354
-
355
- @spaces.GPU
356
- def __init__(self, device, args, height, width, model_length, dtype=torch.float16, use_sift=False):
357
- self.device = device
358
- self.dtype = dtype
359
-
360
- unet = UNetSpatioTemporalConditionModel.from_pretrained(
361
- os.path.join(args.model, "unet"),
362
- torch_dtype=torch.float16,
363
- low_cpu_mem_usage=True,
364
- custom_resume=True,
365
- )
366
- unet = unet.to(device, dtype)
367
-
368
- controlnet = ControlNetSVDModel.from_pretrained(
369
- os.path.join(args.model, "controlnet"),
370
- )
371
- controlnet = controlnet.to(device, dtype)
372
-
373
- if is_xformers_available():
374
- import xformers
375
- xformers_version = version.parse(xformers.__version__)
376
- unet.enable_xformers_memory_efficient_attention()
377
- # controlnet.enable_xformers_memory_efficient_attention()
378
- else:
379
- raise ValueError(
380
- "xformers is not available. Make sure it is installed correctly")
381
-
382
- pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
383
- "checkpoints/stable-video-diffusion-img2vid-xt",
384
- unet=unet,
385
- controlnet=controlnet,
386
- low_cpu_mem_usage=False,
387
- torch_dtype=torch.float16, variant="fp16", local_files_only=True,
388
- )
389
- pipe.to(device)
390
-
391
- self.pipeline = pipe
392
- # self.pipeline.enable_model_cpu_offload()
393
-
394
- self.height = height
395
- self.width = width
396
- self.args = args
397
- self.model_length = model_length
398
- self.use_sift = use_sift
399
-
400
- @spaces.GPU
401
- def run(self, first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
402
- original_width, original_height = 512, 320 # TODO
403
-
404
- # load_image
405
- image = Image.open(first_frame_path).convert('RGB')
406
- width, height = image.size
407
- image = image.resize((self.width, self.height))
408
-
409
- image_end = Image.open(last_frame_path).convert('RGB')
410
- image_end = image_end.resize((self.width, self.height))
411
-
412
- input_all_points = tracking_points.constructor_args['value']
413
-
414
- sift_track_update = False
415
- anchor_points_flag = None
416
-
417
- if (len(input_all_points) == 0) and self.use_sift:
418
- sift_track_update = True
419
- controlnet_cond_scale = 0.5
420
-
421
- from models_diffusers.sift_match import sift_match
422
- from models_diffusers.sift_match import interpolate_trajectory as sift_interpolate_trajectory
423
-
424
- output_file_sift = os.path.join(args.output_dir, "sift.png")
425
-
426
- # (f, topk, 2), f=2 (before interpolation)
427
- pred_tracks = sift_match(
428
- image,
429
- image_end,
430
- thr=0.5,
431
- topk=5,
432
- method="random",
433
- output_path=output_file_sift,
434
- )
435
-
436
- if pred_tracks is not None:
437
- # interpolate the tracks, following draganything gradio demo
438
- pred_tracks = sift_interpolate_trajectory(pred_tracks, num_frames=self.model_length)
439
-
440
- anchor_points_flag = torch.zeros((self.model_length, pred_tracks.shape[1])).to(pred_tracks.device)
441
- anchor_points_flag[0] = 1
442
- anchor_points_flag[-1] = 1
443
-
444
- pred_tracks = pred_tracks.permute(1, 0, 2) # (num_points, num_frames, 2)
445
-
446
- else:
447
-
448
- resized_all_points = [
449
- tuple([
450
- tuple([int(e1[0] * self.width / original_width), int(e1[1] * self.height / original_height)])
451
- for e1 in e])
452
- for e in input_all_points
453
- ]
454
-
455
- # a list of num_tracks tuples, each tuple contains a track with several points, represented as (x, y)
456
- # in image w & h scale
457
-
458
- for idx, splited_track in enumerate(resized_all_points):
459
- if len(splited_track) == 0:
460
- warnings.warn("running without point trajectory control")
461
- continue
462
-
463
- if len(splited_track) == 1: # stationary point
464
- displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
465
- splited_track = tuple([splited_track[0], displacement_point])
466
- # interpolate the track
467
- splited_track = interpolate_trajectory(splited_track, self.model_length)
468
- splited_track = splited_track[:self.model_length]
469
- resized_all_points[idx] = splited_track
470
-
471
- pred_tracks = torch.tensor(resized_all_points) # (num_points, num_frames, 2)
472
-
473
- vis_images = get_vis_image(
474
- target_size=(self.args.height, self.args.width),
475
- points=pred_tracks,
476
- num_frames=self.model_length,
477
- )
478
-
479
- if len(pred_tracks.shape) != 3:
480
- print("pred_tracks.shape", pred_tracks.shape)
481
- with_control = False
482
- controlnet_cond_scale = 0.0
483
- else:
484
- with_control = True
485
- pred_tracks = pred_tracks.permute(1, 0, 2).to(self.device, self.dtype) # (num_frames, num_points, 2)
486
-
487
- point_embedding = None
488
- video_frames = self.pipeline(
489
- image,
490
- image_end,
491
- # trajectory control
492
- with_control=with_control,
493
- point_tracks=pred_tracks,
494
- point_embedding=point_embedding,
495
- with_id_feature=False,
496
- controlnet_cond_scale=controlnet_cond_scale,
497
- # others
498
- num_frames=14,
499
- width=width,
500
- height=height,
501
- # decode_chunk_size=8,
502
- # generator=generator,
503
- motion_bucket_id=motion_bucket_id,
504
- fps=7,
505
- num_inference_steps=30,
506
- # track
507
- sift_track_update=sift_track_update,
508
- anchor_points_flag=anchor_points_flag,
509
- ).frames[0]
510
-
511
- vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
512
- vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
513
- vis_images = [Image.fromarray(img) for img in vis_images]
514
-
515
- # video_frames = [img for sublist in video_frames for img in sublist]
516
- val_save_dir = os.path.join(args.output_dir, "vis_gif.gif")
517
- save_gifs_side_by_side(
518
- video_frames,
519
- vis_images[:self.model_length],
520
- val_save_dir,
521
- target_size=(self.width, self.height),
522
- duration=110,
523
- point_tracks=pred_tracks,
524
- )
525
-
526
- return val_save_dir
527
 
 
528
 
529
- def reset_states(first_frame_path, last_frame_path, tracking_points):
530
- first_frame_path = gr.State()
531
- last_frame_path = gr.State()
532
- tracking_points = gr.State([])
533
 
534
- return first_frame_path, last_frame_path, tracking_points
 
535
 
536
 
537
  def preprocess_image(image):
@@ -544,11 +360,11 @@ def preprocess_image(image):
544
  # image_pil = transforms.CenterCrop((320, 512))(image_pil.convert('RGB'))
545
  image_pil = image_pil.resize((512, 320), Image.BILINEAR)
546
 
547
- first_frame_path = os.path.join(args.output_dir, f"first_frame_{str(uuid.uuid4())[:4]}.png")
548
-
549
  image_pil.save(first_frame_path)
550
 
551
- return first_frame_path, first_frame_path, gr.State([])
552
 
553
 
554
  def preprocess_image_end(image_end):
@@ -561,37 +377,52 @@ def preprocess_image_end(image_end):
561
  # image_end_pil = transforms.CenterCrop((320, 512))(image_end_pil.convert('RGB'))
562
  image_end_pil = image_end_pil.resize((512, 320), Image.BILINEAR)
563
 
564
- last_frame_path = os.path.join(args.output_dir, f"last_frame_{str(uuid.uuid4())[:4]}.png")
565
 
566
  image_end_pil.save(last_frame_path)
567
 
568
- return last_frame_path, last_frame_path, gr.State([])
569
 
570
 
571
  def add_drag(tracking_points):
572
- tracking_points.constructor_args['value'].append([])
 
573
  return tracking_points
574
 
575
 
576
  def delete_last_drag(tracking_points, first_frame_path, last_frame_path):
577
- tracking_points.constructor_args['value'].pop()
578
- transparent_background = Image.open(first_frame_path).convert('RGBA')
579
- transparent_background_end = Image.open(last_frame_path).convert('RGBA')
 
580
  w, h = transparent_background.size
581
  transparent_layer = np.zeros((h, w, 4))
582
 
583
- for track in tracking_points.constructor_args['value']:
584
  if len(track) > 1:
585
- for i in range(len(track)-1):
586
  start_point = track[i]
587
- end_point = track[i+1]
588
  vx = end_point[0] - start_point[0]
589
  vy = end_point[1] - start_point[1]
590
  arrow_length = np.sqrt(vx**2 + vy**2)
591
- if i == len(track)-2:
592
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
 
 
 
 
 
 
 
593
  else:
594
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
 
 
 
 
 
 
595
  else:
596
  cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
597
 
@@ -603,24 +434,40 @@ def delete_last_drag(tracking_points, first_frame_path, last_frame_path):
603
 
604
 
605
  def delete_last_step(tracking_points, first_frame_path, last_frame_path):
606
- tracking_points.constructor_args['value'][-1].pop()
607
- transparent_background = Image.open(first_frame_path).convert('RGBA')
608
- transparent_background_end = Image.open(last_frame_path).convert('RGBA')
 
609
  w, h = transparent_background.size
610
  transparent_layer = np.zeros((h, w, 4))
611
 
612
- for track in tracking_points.constructor_args['value']:
 
 
613
  if len(track) > 1:
614
- for i in range(len(track)-1):
615
  start_point = track[i]
616
- end_point = track[i+1]
617
  vx = end_point[0] - start_point[0]
618
  vy = end_point[1] - start_point[1]
619
  arrow_length = np.sqrt(vx**2 + vy**2)
620
- if i == len(track)-2:
621
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
 
 
 
 
 
 
 
622
  else:
623
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
 
 
 
 
 
 
624
  else:
625
  cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
626
 
@@ -631,34 +478,51 @@ def delete_last_step(tracking_points, first_frame_path, last_frame_path):
631
  return tracking_points, trajectory_map, trajectory_map_end
632
 
633
 
634
- def add_tracking_points(tracking_points, first_frame_path, last_frame_path, evt: gr.SelectData): # SelectData is a subclass of EventData
 
 
635
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
636
- tracking_points.constructor_args['value'][-1].append(evt.index)
 
 
637
 
638
- transparent_background = Image.open(first_frame_path).convert('RGBA')
639
- transparent_background_end = Image.open(last_frame_path).convert('RGBA')
640
 
641
  w, h = transparent_background.size
642
  transparent_layer = 0
643
- for idx, track in enumerate(tracking_points.constructor_args['value']):
644
  # mask = cv2.imread(
645
- # os.path.join(args.output_dir, f"mask_{idx+1}.jpg")
646
  # )
647
  mask = np.zeros((320, 512, 3))
648
- color = color_list[idx+1]
649
  transparent_layer = mask[:, :, 0].reshape(h, w, 1) * color.reshape(1, 1, -1) + transparent_layer
650
 
651
  if len(track) > 1:
652
- for i in range(len(track)-1):
653
  start_point = track[i]
654
- end_point = track[i+1]
655
  vx = end_point[0] - start_point[0]
656
  vy = end_point[1] - start_point[1]
657
  arrow_length = np.sqrt(vx**2 + vy**2)
658
- if i == len(track)-2:
659
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
 
 
 
 
 
 
 
660
  else:
661
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
 
 
 
 
 
 
662
  else:
663
  cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
664
 
@@ -674,26 +538,162 @@ def add_tracking_points(tracking_points, first_frame_path, last_frame_path, evt:
674
  return tracking_points, trajectory_map, trajectory_map_end
675
 
676
677
  if __name__ == "__main__":
678
 
679
- args = get_args()
680
- ensure_dirname(args.output_dir)
681
-
682
  color_list = []
683
  for i in range(20):
684
- color = np.concatenate([np.random.random(4)*255], axis=0)
685
  color_list.append(color)
686
 
687
  with gr.Blocks() as demo:
688
  gr.Markdown("""<h1 align="center">Framer: Interactive Frame Interpolation</h1><br>""")
689
-
690
- gr.Markdown("""Gradio Demo for <a href='https://arxiv.org/abs/2410.18978'><b>Framer: Interactive Frame Interpolation</b></a>.<br>
 
691
  Github Repo can be found at https://github.com/aim-uofa/Framer<br>
692
- The template is inspired by DragAnything.""")
693
-
 
694
  gr.Image(label="Framer: Interactive Frame Interpolation", value="assets/demos.gif", height=432, width=768)
695
-
696
- gr.Markdown("""## Usage: <br>
 
697
  1. Upload images<br>
698
  &ensp; 1.1 Upload the start image via the "Upload Start Image" button.<br>
699
  &ensp; 1.2. Upload the end image via the "Upload End Image" button.<br>
@@ -702,14 +702,13 @@ if __name__ == "__main__":
702
  &ensp; 2.2. You can click several points on either the start or end image to form a path.<br>
703
  &ensp; 2.3. Click "Delete last drag" to delete the whole latest path.<br>
704
  &ensp; 2.4. Click "Delete last step" to delete the latest clicked control point.<br>
705
- 3. Interpolate the images (according to the path) with a click on the "Run" button. <br>""")
706
-
707
- # device, args, height, width, model_length
708
- Framer = Drag("cuda", args, 320, 512, 14)
709
  first_frame_path = gr.State()
710
  last_frame_path = gr.State()
711
  tracking_points = gr.State([])
712
-
713
  with gr.Row():
714
  with gr.Column(scale=1):
715
  image_upload_button = gr.UploadButton(label="Upload Start Image", file_types=["image"])
@@ -720,7 +719,7 @@ if __name__ == "__main__":
720
  run_button = gr.Button(value="Run")
721
  delete_last_drag_button = gr.Button(value="Delete last drag")
722
  delete_last_step_button = gr.Button(value="Delete last step")
723
-
724
  with gr.Column(scale=7):
725
  with gr.Row():
726
  with gr.Column(scale=6):
@@ -731,7 +730,7 @@ if __name__ == "__main__":
731
  width=512,
732
  sources=[],
733
  )
734
-
735
  with gr.Column(scale=6):
736
  input_image_end = gr.Image(
737
  label="end frame",
@@ -740,36 +739,36 @@ if __name__ == "__main__":
740
  width=512,
741
  sources=[],
742
  )
743
-
744
  with gr.Row():
745
  with gr.Column(scale=1):
746
-
747
  controlnet_cond_scale = gr.Slider(
748
- label='Control Scale',
749
- minimum=0.0,
750
- maximum=10,
751
- step=0.1,
752
  value=1.0,
753
  )
754
-
755
  motion_bucket_id = gr.Slider(
756
- label='Motion Bucket',
757
- minimum=1,
758
- maximum=180,
759
- step=1,
760
  value=100,
761
  )
762
-
763
  with gr.Column(scale=5):
764
  output_video = gr.Image(
765
  label="Output Video",
766
  height=320,
767
  width=1152,
768
  )
769
-
770
-
771
  with gr.Row():
772
- gr.Markdown("""
 
773
  ## Citation
774
  ```bibtex
775
  @article{wang2024framer,
@@ -779,24 +778,55 @@ if __name__ == "__main__":
779
  year={2024}
780
  }
781
  ```
782
- """)
783
-
784
- image_upload_button.upload(preprocess_image, image_upload_button, [input_image, first_frame_path, tracking_points])
785
-
786
- image_end_upload_button.upload(preprocess_image_end, image_end_upload_button, [input_image_end, last_frame_path, tracking_points])
787
-
788
- add_drag_button.click(add_drag, tracking_points, [tracking_points, ])
789
-
790
- delete_last_drag_button.click(delete_last_drag, [tracking_points, first_frame_path, last_frame_path], [tracking_points, input_image, input_image_end])
791
-
792
- delete_last_step_button.click(delete_last_step, [tracking_points, first_frame_path, last_frame_path], [tracking_points, input_image, input_image_end])
793
-
794
- reset_button.click(reset_states, [first_frame_path, last_frame_path, tracking_points], [first_frame_path, last_frame_path, tracking_points])
795
-
796
- input_image.select(add_tracking_points, [tracking_points, first_frame_path, last_frame_path], [tracking_points, input_image, input_image_end])
797
-
798
- input_image_end.select(add_tracking_points, [tracking_points, first_frame_path, last_frame_path], [tracking_points, input_image, input_image_end])
799
-
800
- run_button.click(Framer.run, [first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id], output_video)
801
-
802
  demo.launch()
 
 
1
  import datetime
2
+ import os
3
+ import sys
4
  import uuid
5
+ import warnings
 
 
 
 
6
 
7
+ import cv2
8
+ import gradio as gr
9
+ import numpy as np
10
+ import spaces
11
  import torch
12
  import torchvision
13
+ from huggingface_hub import snapshot_download
14
+ from PIL import Image
15
+ from scipy.interpolate import PchipInterpolator
 
16
 
 
 
17
  sys.path.insert(0, os.getcwd())
18
+
19
+ from gradio_demo.utils_drag import *
20
  from models_diffusers.controlnet_svd import ControlNetSVDModel
21
  from models_diffusers.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
22
  from pipelines.pipeline_stable_video_diffusion_interp_control import StableVideoDiffusionInterpControlPipeline
 
23
 
 
24
  print("gr file", gr.__file__)
25
 
 
26
 
27
  os.makedirs("checkpoints", exist_ok=True)
28
 
29
  snapshot_download(
30
  "wwen1997/framer_512x320",
31
  local_dir="checkpoints/framer_512x320",
 
32
  )
33
 
34
  snapshot_download(
35
  "stabilityai/stable-video-diffusion-img2vid-xt",
36
  local_dir="checkpoints/stable-video-diffusion-img2vid-xt",
 
37
  )
38
 
39
 
40
+ model_id = "checkpoints/framer_512x320"
41
+ device = "cuda"
42
+ dtype = torch.float16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ OUTPUT_DIR = "gradio_demo/outputs"
45
+ HEIGHT = 320
46
+ WIDTH = 512
47
+ MODEL_LENGTH = 14
48
+ USE_SIFT = False
49
 
 
 
50
 
51
+ unet = UNetSpatioTemporalConditionModel.from_pretrained(
52
+ os.path.join(model_id, "unet"),
53
+ torch_dtype=torch.float16,
54
+ low_cpu_mem_usage=True,
55
+ custom_resume=True,
56
+ )
57
+ unet = unet.to(device, dtype)
 
 
 
 
 
 
58
 
59
+ controlnet = ControlNetSVDModel.from_pretrained(
60
+ os.path.join(model_id, "controlnet"),
61
+ )
62
+ controlnet = controlnet.to(device, dtype)
63
+
64
+ pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
65
+ "checkpoints/stable-video-diffusion-img2vid-xt",
66
+ unet=unet,
67
+ controlnet=controlnet,
68
+ low_cpu_mem_usage=False,
69
+ torch_dtype=torch.float16,
70
+ variant="fp16",
71
+ local_files_only=True,
72
+ )
73
+ pipe.to(device)
74
 
75
 
76
  def interpolate_trajectory(points, n_points):
 
95
 
96
  def gen_gaussian_heatmap(imgSize=200):
97
  circle_img = np.zeros((imgSize, imgSize), np.float32)
98
+ circle_mask = cv2.circle(circle_img, (imgSize // 2, imgSize // 2), imgSize // 2, 1, -1)
99
 
100
  isotropicGrayscaleImage = np.zeros((imgSize, imgSize), np.float32)
101
 
102
  for i in range(imgSize):
103
  for j in range(imgSize):
104
+ isotropicGrayscaleImage[i, j] = (
105
+ 1
106
+ / 2
107
+ / np.pi
108
+ / (40**2)
109
+ * np.exp(-1 / 2 * ((i - imgSize / 2) ** 2 / (40**2) + (j - imgSize / 2) ** 2 / (40**2)))
110
+ )
111
 
112
  isotropicGrayscaleImage = isotropicGrayscaleImage * circle_mask
113
  isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)).astype(np.float32)
114
+ isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage) * 255).astype(np.uint8)
115
 
116
  return isotropicGrayscaleImage
117
 
118
 
119
  def get_vis_image(
120
+ target_size=(512, 512),
121
+ points=None,
122
+ side=20,
123
+ num_frames=14,
124
+ # original_size=(512 , 512), args="", first_frame=None, is_mask = False, model_id=None,
125
+ ):
126
 
127
  # images = []
128
  vis_images = []
 
130
 
131
  trajectory_list = []
132
  radius_list = []
133
+
134
  for index, point in enumerate(points):
135
  trajectories = [[int(i[0]), int(i[1])] for i in point]
136
  trajectory_list.append(trajectories)
137
 
138
  radius = 20
139
+ radius_list.append(radius)
140
 
141
  if len(trajectory_list) == 0:
142
  vis_images = [Image.fromarray(np.zeros(target_size, np.uint8)) for _ in range(num_frames)]
 
146
  new_img = np.zeros(target_size, np.uint8)
147
  vis_img = new_img.copy()
148
  # ids_embedding = torch.zeros((target_size[0], target_size[1], 320))
149
+
150
+ if idxx >= num_frames:
151
  break
152
 
153
  # for cc, (mask, trajectory, radius) in enumerate(zip(mask_list, trajectory_list, radius_list)):
154
  for cc, (trajectory, radius) in enumerate(zip(trajectory_list, radius_list)):
155
+
156
  center_coordinate = trajectory[idxx]
157
  trajectory_ = trajectory[:idxx]
158
  side = min(radius, 50)
159
+
160
+ y1 = max(center_coordinate[1] - side, 0)
161
  y2 = min(center_coordinate[1] + side, target_size[0] - 1)
162
  x1 = max(center_coordinate[0] - side, 0)
163
  x2 = min(center_coordinate[0] + side, target_size[1] - 1)
164
+
165
+ if x2 - x1 > 3 and y2 - y1 > 3:
166
+ need_map = cv2.resize(heatmap, (x2 - x1, y2 - y1))
167
  new_img[y1:y2, x1:x2] = need_map.copy()
168
+
169
  if cc >= 0:
170
+ vis_img[y1:y2, x1:x2] = need_map.copy()
171
  if len(trajectory_) == 1:
172
  vis_img[trajectory_[0][1], trajectory_[0][0]] = 255
173
  else:
174
+ for itt in range(len(trajectory_) - 1):
175
+ cv2.line(
176
+ vis_img,
177
+ (trajectory_[itt][0], trajectory_[itt][1]),
178
+ (trajectory_[itt + 1][0], trajectory_[itt + 1][1]),
179
+ (255, 255, 255),
180
+ 3,
181
+ )
182
 
183
  img = new_img
184
 
 
189
  elif len(img.shape) == 3 and img.shape[2] == 3: # Color image in BGR format
190
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
191
  vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
192
+
193
  # Convert the numpy array to a PIL image
194
  # pil_img = Image.fromarray(img)
195
  # images.append(pil_img)
 
210
  video.append(frame)
211
 
212
  video = torch.stack(video)
213
+ video = rearrange(video, "T C H W -> T H W C")
214
  torchvision.io.write_video(output_video_path, video, fps=fps)
215
 
216
 
 
218
  batch_output,
219
  validation_control_images,
220
  output_folder,
221
+ target_size=(512, 512),
222
  duration=200,
223
  point_tracks=None,
224
  ):
225
  flattened_batch_output = batch_output
226
+
227
  def create_gif(image_list, gif_path, duration=100):
228
  pil_images = [validate_and_convert_image(img, target_size=target_size) for img in image_list]
229
  pil_images = [img for img in pil_images if img is not None]
 
239
  tmp_frame_path = os.path.join(tmp_folder, f"{idx}.png")
240
  pil_image.save(tmp_frame_path)
241
  tmp_frame_list.append(tmp_frame_path)
242
+
243
  # also save as mp4
244
  output_video_path = gif_path.replace(".gif", ".mp4")
245
  frames_to_video(tmp_folder, output_video_path, fps=7)
 
282
  if output_path.endswith(".mp4"):
283
  video = [torchvision.transforms.functional.pil_to_tensor(frame) for frame in frames]
284
  video = torch.stack(video)
285
+ video = rearrange(video, "T C H W -> T H W C")
286
  torchvision.io.write_video(output_path, video, fps=7)
287
  print(f"Saved video to {output_path}")
288
  else:
289
  frames[0].save(output_path, save_all=True, append_images=frames[1:], loop=0, duration=duration)
290
+
291
  # Helper function to concatenate images horizontally
292
  def get_concat_h(im1, im2, gap=10):
293
  # # img first, heatmap second
294
  # im1, im2 = im2, im1
295
 
296
+ dst = Image.new("RGB", (im1.width + im2.width + gap, max(im1.height, im2.height)), (255, 255, 255))
297
  dst.paste(im1, (0, 0))
298
  dst.paste(im2, (im1.width + gap, 0))
299
  return dst
300
 
301
  # Helper function to concatenate images vertically
302
  def get_concat_v(im1, im2):
303
+ dst = Image.new("RGB", (max(im1.width, im2.width), im1.height + im2.height))
304
  dst.paste(im1, (0, 0))
305
  dst.paste(im2, (0, im1.height))
306
  return dst
 
321
 
322
 
323
  # Define functions
324
+ def validate_and_convert_image(image, target_size=(512, 512)):
325
  if image is None:
326
  print("Encountered a None image")
327
  return None
 
342
  else:
343
  print("Image is not a PIL Image or a PyTorch tensor")
344
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
+ return image
347
 
 
 
 
 
348
 
349
+ def reset_states():
350
+ return None, None, None, None, None, []
351
 
352
 
353
  def preprocess_image(image):
 
360
  # image_pil = transforms.CenterCrop((320, 512))(image_pil.convert('RGB'))
361
  image_pil = image_pil.resize((512, 320), Image.BILINEAR)
362
 
363
+ first_frame_path = os.path.join(OUTPUT_DIR, f"first_frame_{str(uuid.uuid4())[:4]}.png")
364
+
365
  image_pil.save(first_frame_path)
366
 
367
+ return first_frame_path, first_frame_path, []
368
 
369
 
370
  def preprocess_image_end(image_end):
 
377
  # image_end_pil = transforms.CenterCrop((320, 512))(image_end_pil.convert('RGB'))
378
  image_end_pil = image_end_pil.resize((512, 320), Image.BILINEAR)
379
 
380
+ last_frame_path = os.path.join(OUTPUT_DIR, f"last_frame_{str(uuid.uuid4())[:4]}.png")
381
 
382
  image_end_pil.save(last_frame_path)
383
 
384
+ return last_frame_path, last_frame_path, []
385
 
386
 
387
  def add_drag(tracking_points):
388
+ if not tracking_points or tracking_points[-1]:
389
+ tracking_points.append([])
390
  return tracking_points
391
 
392
 
393
  def delete_last_drag(tracking_points, first_frame_path, last_frame_path):
394
+ if tracking_points:
395
+ tracking_points.pop()
396
+ transparent_background = Image.open(first_frame_path).convert("RGBA")
397
+ transparent_background_end = Image.open(last_frame_path).convert("RGBA")
398
  w, h = transparent_background.size
399
  transparent_layer = np.zeros((h, w, 4))
400
 
401
+ for track in tracking_points:
402
  if len(track) > 1:
403
+ for i in range(len(track) - 1):
404
  start_point = track[i]
405
+ end_point = track[i + 1]
406
  vx = end_point[0] - start_point[0]
407
  vy = end_point[1] - start_point[1]
408
  arrow_length = np.sqrt(vx**2 + vy**2)
409
+ if i == len(track) - 2:
410
+ cv2.arrowedLine(
411
+ transparent_layer,
412
+ tuple(start_point),
413
+ tuple(end_point),
414
+ (255, 0, 0, 255),
415
+ 2,
416
+ tipLength=8 / arrow_length,
417
+ )
418
  else:
419
+ cv2.line(
420
+ transparent_layer,
421
+ tuple(start_point),
422
+ tuple(end_point),
423
+ (255, 0, 0, 255),
424
+ 2,
425
+ )
426
  else:
427
  cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
428
 
 
434
 
435
 
436
  def delete_last_step(tracking_points, first_frame_path, last_frame_path):
437
+ if tracking_points and tracking_points[-1]:
438
+ tracking_points[-1].pop()
439
+ transparent_background = Image.open(first_frame_path).convert("RGBA")
440
+ transparent_background_end = Image.open(last_frame_path).convert("RGBA")
441
  w, h = transparent_background.size
442
  transparent_layer = np.zeros((h, w, 4))
443
 
444
+ for track in tracking_points:
445
+ if not track:
446
+ continue
447
  if len(track) > 1:
448
+ for i in range(len(track) - 1):
449
  start_point = track[i]
450
+ end_point = track[i + 1]
451
  vx = end_point[0] - start_point[0]
452
  vy = end_point[1] - start_point[1]
453
  arrow_length = np.sqrt(vx**2 + vy**2)
454
+ if i == len(track) - 2:
455
+ cv2.arrowedLine(
456
+ transparent_layer,
457
+ tuple(start_point),
458
+ tuple(end_point),
459
+ (255, 0, 0, 255),
460
+ 2,
461
+ tipLength=8 / arrow_length,
462
+ )
463
  else:
464
+ cv2.line(
465
+ transparent_layer,
466
+ tuple(start_point),
467
+ tuple(end_point),
468
+ (255, 0, 0, 255),
469
+ 2,
470
+ )
471
  else:
472
  cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
473
 
 
478
  return tracking_points, trajectory_map, trajectory_map_end
479
 
480
 
481
+ def add_tracking_points(
482
+ tracking_points, first_frame_path, last_frame_path, evt: gr.SelectData
483
+ ): # SelectData is a subclass of EventData
484
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
485
+ if not tracking_points:
486
+ tracking_points = [[]]
487
+ tracking_points[-1].append(evt.index)
488
 
489
+ transparent_background = Image.open(first_frame_path).convert("RGBA")
490
+ transparent_background_end = Image.open(last_frame_path).convert("RGBA")
491
 
492
  w, h = transparent_background.size
493
  transparent_layer = 0
494
+ for idx, track in enumerate(tracking_points):
495
  # mask = cv2.imread(
496
+ # os.path.join(OUTPUT_DIR, f"mask_{idx+1}.jpg")
497
  # )
498
  mask = np.zeros((320, 512, 3))
499
+ color = color_list[idx + 1]
500
  transparent_layer = mask[:, :, 0].reshape(h, w, 1) * color.reshape(1, 1, -1) + transparent_layer
501
 
502
  if len(track) > 1:
503
+ for i in range(len(track) - 1):
504
  start_point = track[i]
505
+ end_point = track[i + 1]
506
  vx = end_point[0] - start_point[0]
507
  vy = end_point[1] - start_point[1]
508
  arrow_length = np.sqrt(vx**2 + vy**2)
509
+ if i == len(track) - 2:
510
+ cv2.arrowedLine(
511
+ transparent_layer,
512
+ tuple(start_point),
513
+ tuple(end_point),
514
+ (255, 0, 0, 255),
515
+ 2,
516
+ tipLength=8 / arrow_length,
517
+ )
518
  else:
519
+ cv2.line(
520
+ transparent_layer,
521
+ tuple(start_point),
522
+ tuple(end_point),
523
+ (255, 0, 0, 255),
524
+ 2,
525
+ )
526
  else:
527
  cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
528
 
 
538
  return tracking_points, trajectory_map, trajectory_map_end
539
 
540
 
541
+ @spaces.GPU
542
+ def run(
543
+ first_frame_path,
544
+ last_frame_path,
545
+ tracking_points,
546
+ controlnet_cond_scale,
547
+ motion_bucket_id,
548
+ progress=gr.Progress(track_tqdm=True),
549
+ ):
550
+ original_width, original_height = 512, 320 # TODO
551
+
552
+ # load_image
553
+ image = Image.open(first_frame_path).convert("RGB")
554
+ width, height = image.size
555
+ image = image.resize((WIDTH, HEIGHT))
556
+
557
+ image_end = Image.open(last_frame_path).convert("RGB")
558
+ image_end = image_end.resize((WIDTH, HEIGHT))
559
+
560
+ input_all_points = tracking_points
561
+
562
+ sift_track_update = False
563
+ anchor_points_flag = None
564
+
565
+ if (len(input_all_points) == 0) and USE_SIFT:
566
+ sift_track_update = True
567
+ controlnet_cond_scale = 0.5
568
+
569
+ from models_diffusers.sift_match import interpolate_trajectory as sift_interpolate_trajectory
570
+ from models_diffusers.sift_match import sift_match
571
+
572
+ output_file_sift = os.path.join(OUTPUT_DIR, "sift.png")
573
+
574
+ # (f, topk, 2), f=2 (before interpolation)
575
+ pred_tracks = sift_match(
576
+ image,
577
+ image_end,
578
+ thr=0.5,
579
+ topk=5,
580
+ method="random",
581
+ output_path=output_file_sift,
582
+ )
583
+
584
+ if pred_tracks is not None:
585
+ # interpolate the tracks, following draganything gradio demo
586
+ pred_tracks = sift_interpolate_trajectory(pred_tracks, num_frames=MODEL_LENGTH)
587
+
588
+ anchor_points_flag = torch.zeros((MODEL_LENGTH, pred_tracks.shape[1])).to(pred_tracks.device)
589
+ anchor_points_flag[0] = 1
590
+ anchor_points_flag[-1] = 1
591
+
592
+ pred_tracks = pred_tracks.permute(1, 0, 2) # (num_points, num_frames, 2)
593
+
594
+ else:
595
+
596
+ resized_all_points = [
597
+ tuple([tuple([int(e1[0] * WIDTH / original_width), int(e1[1] * HEIGHT / original_height)]) for e1 in e])
598
+ for e in input_all_points
599
+ ]
600
+
601
+ # a list of num_tracks tuples, each tuple contains a track with several points, represented as (x, y)
602
+ # in image w & h scale
603
+
604
+ for idx, splited_track in enumerate(resized_all_points):
605
+ if len(splited_track) == 0:
606
+ warnings.warn("running without point trajectory control")
607
+ continue
608
+
609
+ if len(splited_track) == 1: # stationary point
610
+ displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
611
+ splited_track = tuple([splited_track[0], displacement_point])
612
+ # interpolate the track
613
+ splited_track = interpolate_trajectory(splited_track, MODEL_LENGTH)
614
+ splited_track = splited_track[:MODEL_LENGTH]
615
+ resized_all_points[idx] = splited_track
616
+
617
+ pred_tracks = torch.tensor(resized_all_points) # (num_points, num_frames, 2)
618
+
619
+ vis_images = get_vis_image(
620
+ target_size=(HEIGHT, WIDTH),
621
+ points=pred_tracks,
622
+ num_frames=MODEL_LENGTH,
623
+ )
624
+
625
+ if len(pred_tracks.shape) != 3:
626
+ print("pred_tracks.shape", pred_tracks.shape)
627
+ with_control = False
628
+ controlnet_cond_scale = 0.0
629
+ else:
630
+ with_control = True
631
+ pred_tracks = pred_tracks.permute(1, 0, 2).to(device, dtype) # (num_frames, num_points, 2)
632
+
633
+ point_embedding = None
634
+ video_frames = pipe(
635
+ image,
636
+ image_end,
637
+ # trajectory control
638
+ with_control=with_control,
639
+ point_tracks=pred_tracks,
640
+ point_embedding=point_embedding,
641
+ with_id_feature=False,
642
+ controlnet_cond_scale=controlnet_cond_scale,
643
+ # others
644
+ num_frames=14,
645
+ width=width,
646
+ height=height,
647
+ # decode_chunk_size=8,
648
+ # generator=generator,
649
+ motion_bucket_id=motion_bucket_id,
650
+ fps=7,
651
+ num_inference_steps=30,
652
+ # track
653
+ sift_track_update=sift_track_update,
654
+ anchor_points_flag=anchor_points_flag,
655
+ ).frames[0]
656
+
657
+ vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
658
+ vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
659
+ vis_images = [Image.fromarray(img) for img in vis_images]
660
+
661
+ # video_frames = [img for sublist in video_frames for img in sublist]
662
+ val_save_dir = os.path.join(OUTPUT_DIR, "vis_gif.gif")
663
+ save_gifs_side_by_side(
664
+ video_frames,
665
+ vis_images[:MODEL_LENGTH],
666
+ val_save_dir,
667
+ target_size=(WIDTH, HEIGHT),
668
+ duration=110,
669
+ point_tracks=pred_tracks,
670
+ )
671
+
672
+ return val_save_dir
673
+
674
+
675
  if __name__ == "__main__":
676
 
677
+ ensure_dirname(OUTPUT_DIR)
678
+
 
679
  color_list = []
680
  for i in range(20):
681
+ color = np.concatenate([np.random.random(4) * 255], axis=0)
682
  color_list.append(color)
683
 
684
  with gr.Blocks() as demo:
685
  gr.Markdown("""<h1 align="center">Framer: Interactive Frame Interpolation</h1><br>""")
686
+
687
+ gr.Markdown(
688
+ """Gradio Demo for <a href='https://arxiv.org/abs/2410.18978'><b>Framer: Interactive Frame Interpolation</b></a>.<br>
689
  Github Repo can be found at https://github.com/aim-uofa/Framer<br>
690
+ The template is inspired by DragAnything."""
691
+ )
692
+
693
  gr.Image(label="Framer: Interactive Frame Interpolation", value="assets/demos.gif", height=432, width=768)
694
+
695
+ gr.Markdown(
696
+ """## Usage: <br>
697
  1. Upload images<br>
698
  &ensp; 1.1 Upload the start image via the "Upload Start Image" button.<br>
699
  &ensp; 1.2. Upload the end image via the "Upload End Image" button.<br>
 
702
  &ensp; 2.2. You can click several points on either the start or end image to form a path.<br>
703
  &ensp; 2.3. Click "Delete last drag" to delete the whole latest path.<br>
704
  &ensp; 2.4. Click "Delete last step" to delete the latest clicked control point.<br>
705
+ 3. Interpolate the images (according to the path) with a click on the "Run" button. <br>"""
706
+ )
707
+
 
708
  first_frame_path = gr.State()
709
  last_frame_path = gr.State()
710
  tracking_points = gr.State([])
711
+
712
  with gr.Row():
713
  with gr.Column(scale=1):
714
  image_upload_button = gr.UploadButton(label="Upload Start Image", file_types=["image"])
 
719
  run_button = gr.Button(value="Run")
720
  delete_last_drag_button = gr.Button(value="Delete last drag")
721
  delete_last_step_button = gr.Button(value="Delete last step")
722
+
723
  with gr.Column(scale=7):
724
  with gr.Row():
725
  with gr.Column(scale=6):
 
730
  width=512,
731
  sources=[],
732
  )
733
+
734
  with gr.Column(scale=6):
735
  input_image_end = gr.Image(
736
  label="end frame",
 
739
  width=512,
740
  sources=[],
741
  )
742
+
743
  with gr.Row():
744
  with gr.Column(scale=1):
745
+
746
  controlnet_cond_scale = gr.Slider(
747
+ label="Control Scale",
748
+ minimum=0.0,
749
+ maximum=10,
750
+ step=0.1,
751
  value=1.0,
752
  )
753
+
754
  motion_bucket_id = gr.Slider(
755
+ label="Motion Bucket",
756
+ minimum=1,
757
+ maximum=180,
758
+ step=1,
759
  value=100,
760
  )
761
+
762
  with gr.Column(scale=5):
763
  output_video = gr.Image(
764
  label="Output Video",
765
  height=320,
766
  width=1152,
767
  )
768
+
 
769
  with gr.Row():
770
+ gr.Markdown(
771
+ """
772
  ## Citation
773
  ```bibtex
774
  @article{wang2024framer,
 
778
  year={2024}
779
  }
780
  ```
781
+ """
782
+ )
783
+
784
+ image_upload_button.upload(
785
+ fn=preprocess_image,
786
+ inputs=image_upload_button,
787
+ outputs=[input_image, first_frame_path, tracking_points],
788
+ )
789
+
790
+ image_end_upload_button.upload(
791
+ fn=preprocess_image_end,
792
+ inputs=image_end_upload_button,
793
+ outputs=[input_image_end, last_frame_path, tracking_points],
794
+ )
795
+
796
+ add_drag_button.click(
797
+ fn=add_drag,
798
+ inputs=tracking_points,
799
+ outputs=tracking_points,
800
+ )
801
+
802
+ delete_last_drag_button.click(
803
+ fn=delete_last_drag,
804
+ inputs=[tracking_points, first_frame_path, last_frame_path],
805
+ outputs=[tracking_points, input_image, input_image_end],
806
+ )
807
+
808
+ delete_last_step_button.click(
809
+ fn=delete_last_step,
810
+ inputs=[tracking_points, first_frame_path, last_frame_path],
811
+ outputs=[tracking_points, input_image, input_image_end],
812
+ )
813
+
814
+ reset_button.click(
815
+ fn=reset_states,
816
+ outputs=[input_image, input_image_end, first_frame_path, last_frame_path, output_video, tracking_points],
817
+ )
818
+
819
+ gr.on(
820
+ triggers=[input_image.select, input_image_end.select],
821
+ fn=add_tracking_points,
822
+ inputs=[tracking_points, first_frame_path, last_frame_path],
823
+ outputs=[tracking_points, input_image, input_image_end],
824
+ )
825
+
826
+ run_button.click(
827
+ fn=run,
828
+ inputs=[first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id],
829
+ outputs=output_video,
830
+ )
831
+
832
  demo.launch()
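Beyond the Gradio-5 state handling, the other structural change in app.py is that the `Drag` class is gone: the UNet, ControlNet, and SVD pipeline are built once at module import time, and only the module-level `run()` carries the `@spaces.GPU` decorator, so ZeroGPU hardware is requested per call rather than held for the whole session. A toy sketch of that split, with a linear layer standing in for the pipeline (assumption: on a ZeroGPU Space the startup-time `.to("cuda")` is handled by the `spaces` runtime, which is why app.py can call `pipe.to(device)` before any GPU is attached):

```python
import spaces
import torch

# Stand-in for the UNet / ControlNet / SVD pipeline that app.py builds at
# import time; loading happens once, when the Space process starts.
model = torch.nn.Linear(4, 4).half()
model.to("cuda")

@spaces.GPU  # hardware is attached only for the duration of each call
def run(x: torch.Tensor) -> torch.Tensor:
    # Inference happens inside the decorated function, mirroring app.py's run().
    with torch.inference_mode():
        return model(x.to("cuda").half()).cpu()
```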
requirements.txt CHANGED
@@ -1,14 +1,279 @@
1
- torch==2.0.0
2
- torchvision
3
  diffusers==0.24.0
4
- transformers==4.27.0
5
- xformers==0.0.18
6
- imageio==2.27.0
7
- decord==0.6.0
8
- einops
9
- opencv-python
10
- av
11
- accelerate==0.27.2
12
- scipy
13
- colorlog
14
- numpy==1.24.3
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile pyproject.toml -o requirements.txt
3
+ accelerate==1.1.1
4
+ # via framer (pyproject.toml)
5
+ aiofiles==23.2.1
6
+ # via gradio
7
+ annotated-types==0.7.0
8
+ # via pydantic
9
+ anyio==4.6.2.post1
10
+ # via
11
+ # gradio
12
+ # httpx
13
+ # starlette
14
+ av==13.1.0
15
+ # via framer (pyproject.toml)
16
+ certifi==2024.8.30
17
+ # via
18
+ # httpcore
19
+ # httpx
20
+ # requests
21
+ charset-normalizer==3.4.0
22
+ # via requests
23
+ click==8.1.7
24
+ # via
25
+ # typer
26
+ # uvicorn
27
+ colorlog==6.9.0
28
+ # via framer (pyproject.toml)
29
  diffusers==0.24.0
30
+ # via framer (pyproject.toml)
31
+ einops==0.8.0
32
+ # via framer (pyproject.toml)
33
+ exceptiongroup==1.2.2
34
+ # via anyio
35
+ fastapi==0.115.4
36
+ # via gradio
37
+ ffmpy==0.4.0
38
+ # via gradio
39
+ filelock==3.16.1
40
+ # via
41
+ # diffusers
42
+ # huggingface-hub
43
+ # torch
44
+ # transformers
45
+ # triton
46
+ fsspec==2024.10.0
47
+ # via
48
+ # gradio-client
49
+ # huggingface-hub
50
+ # torch
51
+ gradio==5.5.0
52
+ # via
53
+ # framer (pyproject.toml)
54
+ # spaces
55
+ gradio-client==1.4.2
56
+ # via gradio
57
+ h11==0.14.0
58
+ # via
59
+ # httpcore
60
+ # uvicorn
61
+ hf-transfer==0.1.8
62
+ # via framer (pyproject.toml)
63
+ httpcore==1.0.6
64
+ # via httpx
65
+ httpx==0.27.2
66
+ # via
67
+ # gradio
68
+ # gradio-client
69
+ # safehttpx
70
+ # spaces
71
+ huggingface-hub==0.25.2
72
+ # via
73
+ # framer (pyproject.toml)
74
+ # accelerate
75
+ # diffusers
76
+ # gradio
77
+ # gradio-client
78
+ # tokenizers
79
+ # transformers
80
+ idna==3.10
81
+ # via
82
+ # anyio
83
+ # httpx
84
+ # requests
85
+ imageio==2.36.0
86
+ # via framer (pyproject.toml)
87
+ importlib-metadata==8.5.0
88
+ # via diffusers
89
+ jinja2==3.1.4
90
+ # via
91
+ # gradio
92
+ # torch
93
+ markdown-it-py==3.0.0
94
+ # via rich
95
+ markupsafe==2.1.5
96
+ # via
97
+ # gradio
98
+ # jinja2
99
+ mdurl==0.1.2
100
+ # via markdown-it-py
101
+ mpmath==1.3.0
102
+ # via sympy
103
+ networkx==3.4.2
104
+ # via torch
105
+ numpy==1.24.3
106
+ # via
107
+ # accelerate
108
+ # diffusers
109
+ # gradio
110
+ # imageio
111
+ # opencv-python
112
+ # pandas
113
+ # scipy
114
+ # torchvision
115
+ # transformers
116
+ nvidia-cublas-cu12==12.1.3.1
117
+ # via
118
+ # nvidia-cudnn-cu12
119
+ # nvidia-cusolver-cu12
120
+ # torch
121
+ nvidia-cuda-cupti-cu12==12.1.105
122
+ # via torch
123
+ nvidia-cuda-nvrtc-cu12==12.1.105
124
+ # via torch
125
+ nvidia-cuda-runtime-cu12==12.1.105
126
+ # via torch
127
+ nvidia-cudnn-cu12==9.1.0.70
128
+ # via torch
129
+ nvidia-cufft-cu12==11.0.2.54
130
+ # via torch
131
+ nvidia-curand-cu12==10.3.2.106
132
+ # via torch
133
+ nvidia-cusolver-cu12==11.4.5.107
134
+ # via torch
135
+ nvidia-cusparse-cu12==12.1.0.106
136
+ # via
137
+ # nvidia-cusolver-cu12
138
+ # torch
139
+ nvidia-nccl-cu12==2.20.5
140
+ # via torch
141
+ nvidia-nvjitlink-cu12==12.6.77
142
+ # via
143
+ # nvidia-cusolver-cu12
144
+ # nvidia-cusparse-cu12
145
+ nvidia-nvtx-cu12==12.1.105
146
+ # via torch
147
+ opencv-python==4.10.0.84
148
+ # via framer (pyproject.toml)
149
+ orjson==3.10.11
150
+ # via gradio
151
+ packaging==24.2
152
+ # via
153
+ # accelerate
154
+ # gradio
155
+ # gradio-client
156
+ # huggingface-hub
157
+ # spaces
158
+ # transformers
159
+ pandas==2.2.3
160
+ # via gradio
161
+ pillow==11.0.0
162
+ # via
163
+ # diffusers
164
+ # gradio
165
+ # imageio
166
+ # torchvision
167
+ psutil==5.9.8
168
+ # via
169
+ # accelerate
170
+ # spaces
171
+ pydantic==2.9.2
172
+ # via
173
+ # fastapi
174
+ # gradio
175
+ # spaces
176
+ pydantic-core==2.23.4
177
+ # via pydantic
178
+ pydub==0.25.1
179
+ # via gradio
180
+ pygments==2.18.0
181
+ # via rich
182
+ python-dateutil==2.9.0.post0
183
+ # via pandas
184
+ python-multipart==0.0.12
185
+ # via gradio
186
+ pytz==2024.2
187
+ # via pandas
188
+ pyyaml==6.0.2
189
+ # via
190
+ # accelerate
191
+ # gradio
192
+ # huggingface-hub
193
+ # transformers
194
+ regex==2024.11.6
195
+ # via
196
+ # diffusers
197
+ # transformers
198
+ requests==2.32.3
199
+ # via
200
+ # diffusers
201
+ # huggingface-hub
202
+ # spaces
203
+ # transformers
204
+ rich==13.9.4
205
+ # via typer
206
+ ruff==0.7.3
207
+ # via gradio
208
+ safehttpx==0.1.1
209
+ # via gradio
210
+ safetensors==0.4.5
211
+ # via
212
+ # accelerate
213
+ # diffusers
214
+ # transformers
215
+ scipy==1.14.1
216
+ # via framer (pyproject.toml)
217
+ semantic-version==2.10.0
218
+ # via gradio
219
+ shellingham==1.5.4
220
+ # via typer
221
+ six==1.16.0
222
+ # via python-dateutil
223
+ sniffio==1.3.1
224
+ # via
225
+ # anyio
226
+ # httpx
227
+ spaces==0.30.4
228
+ # via framer (pyproject.toml)
229
+ starlette==0.41.2
230
+ # via
231
+ # fastapi
232
+ # gradio
233
+ sympy==1.13.3
234
+ # via torch
235
+ tokenizers==0.20.3
236
+ # via transformers
237
+ tomlkit==0.12.0
238
+ # via gradio
239
+ torch==2.4.0
240
+ # via
241
+ # framer (pyproject.toml)
242
+ # accelerate
243
+ # torchvision
244
+ torchvision==0.19.0
245
+ # via framer (pyproject.toml)
246
+ tqdm==4.67.0
247
+ # via
248
+ # huggingface-hub
249
+ # transformers
250
+ transformers==4.46.2
251
+ # via framer (pyproject.toml)
252
+ triton==3.0.0
253
+ # via torch
254
+ typer==0.13.0
255
+ # via gradio
256
+ typing-extensions==4.12.2
257
+ # via
258
+ # anyio
259
+ # fastapi
260
+ # gradio
261
+ # gradio-client
262
+ # huggingface-hub
263
+ # pydantic
264
+ # pydantic-core
265
+ # rich
266
+ # spaces
267
+ # torch
268
+ # typer
269
+ # uvicorn
270
+ tzdata==2024.2
271
+ # via pandas
272
+ urllib3==2.2.3
273
+ # via requests
274
+ uvicorn==0.32.0
275
+ # via gradio
276
+ websockets==12.0
277
+ # via gradio-client
278
+ zipp==3.21.0
279
+ # via importlib-metadata
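requirements.txt goes from a short hand-written list to a full `uv pip compile` lockfile. The notable pin changes are torch 2.0.0 → 2.4.0, transformers 4.27.0 → 4.46.2, gradio pinned to 5.5.0, and xformers dropped entirely (the new app.py no longer calls `enable_xformers_memory_efficient_attention`). A small, purely illustrative check that a running environment matches the new pins:

```python
import importlib

# Pins taken from the compiled requirements.txt above.
expected = {
    "diffusers": "0.24.0",
    "gradio": "5.5.0",
    "torch": "2.4.0",
    "transformers": "4.46.2",
}

for name, pin in expected.items():
    module = importlib.import_module(name)
    status = "OK" if module.__version__.startswith(pin) else "MISMATCH"
    print(f"{name}: installed {module.__version__}, pinned {pin} -> {status}")
```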