hysts HF staff committed on
Commit
33cb6e3
1 Parent(s): ecda160

Apply formatter to app.py and requirements.txt

Browse files
Files changed (2) hide show
  1. app.py +397 -275
  2. requirements.txt +22 -23
app.py CHANGED
@@ -1,35 +1,43 @@
1
  import os
2
  import sys
3
 
4
-
5
  print("Installing correct gradio version...")
6
  os.system("pip uninstall -y gradio")
7
  os.system("pip install gradio==4.38.1")
8
  print("Installing Finished!")
9
 
10
 
 
 
 
 
11
  import gradio as gr
12
  import numpy as np
13
- import cv2
14
- import uuid
15
  import torch
16
  import torchvision
17
- import json
18
- import spaces
19
-
20
- from PIL import Image
21
- from omegaconf import OmegaConf
22
  from einops import rearrange, repeat
23
- from torchvision import transforms,utils
 
 
24
  from transformers import CLIPTextModel, CLIPTokenizer
25
- from diffusers import AutoencoderKL, DDIMScheduler
26
 
27
- from pipelines.pipeline_imagecoductor import ImageConductorPipeline
28
  from modules.unet import UNet3DConditionFlowModel
29
- from utils.gradio_utils import ensure_dirname, split_filename, visualize_drag, image2pil
30
- from utils.utils import create_image_controlnet, create_flow_controlnet, interpolate_trajectory, load_weights, load_model, bivariate_Gaussian, save_videos_grid
31
  from utils.lora_utils import add_LoRA_to_controlnet
 
 
 
 
 
 
 
 
 
32
  from utils.visualizer import Visualizer, vis_flow_to_video
 
33
  #### Description ####
34
  title = r"""<h1 align="center">CustomNet: Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models</h1>"""
35
 
@@ -41,7 +49,7 @@ head = r"""
41
  <a href='https://liyaowei-stu.github.io/project/ImageConductor/'><img src='https://img.shields.io/badge/Project_Page-ImgaeConductor-green' alt='Project Page'></a>
42
  <a href='https://arxiv.org/pdf/2406.15339'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
43
  <a href='https://github.com/liyaowei-stu/ImageConductor'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
44
-
45
 
46
  </div>
47
  </br>
@@ -49,7 +57,6 @@ head = r"""
49
  """
50
 
51
 
52
-
53
  descriptions = r"""
54
  Official Gradio Demo for <a href='https://github.com/liyaowei-stu/ImageConductor'><b>Image Conductor: Precision Control for Interactive Video Synthesis</b></a>.<br>
55
  🧙Image Conductor enables precise, fine-grained control for generating motion-controllable videos from images, advancing the practical application of interactive video synthesis.<br>
@@ -66,7 +73,7 @@ instructions = r"""
66
  """
67
 
68
  citation = r"""
69
- If Image Conductor is helpful, please help to ⭐ the <a href='https://github.com/liyaowei-stu/ImageConductor' target='_blank'>Github Repo</a>. Thanks!
70
  [![GitHub Stars](https://img.shields.io/github/stars/liyaowei-stu%2FImageConductor)](https://github.com/liyaowei-stu/ImageConductor)
71
  ---
72
 
@@ -75,7 +82,7 @@ If Image Conductor is helpful, please help to ⭐ the <a href='https://github.co
75
  If our work is useful for your research, please consider citing:
76
  ```bibtex
77
  @misc{li2024imageconductor,
78
- title={Image Conductor: Precision Control for Interactive Video Synthesis},
79
  author={Li, Yaowei and Wang, Xintao and Zhang, Zhaoyang and Wang, Zhouxia and Yuan, Ziyang and Xie, Liangbin and Zou, Yuexian and Shan, Ying},
80
  year={2024},
81
  eprint={2406.15339},
@@ -94,42 +101,68 @@ os.makedirs("models/personalized")
94
  os.makedirs("models/sd1-5")
95
 
96
  if not os.path.exists("models/flow_controlnet.ckpt"):
97
- os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/flow_controlnet.ckpt?download=true -P models/')
98
- os.system(f'mv models/flow_controlnet.ckpt?download=true models/flow_controlnet.ckpt')
99
- print("flow_controlnet Download!", )
 
 
 
 
100
 
101
  if not os.path.exists("models/image_controlnet.ckpt"):
102
- os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/image_controlnet.ckpt?download=true -P models/')
103
- os.system(f'mv models/image_controlnet.ckpt?download=true models/image_controlnet.ckpt')
104
- print("image_controlnet Download!", )
 
 
 
 
105
 
106
  if not os.path.exists("models/unet.ckpt"):
107
- os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/unet.ckpt?download=true -P models/')
108
- os.system(f'mv models/unet.ckpt?download=true models/unet.ckpt')
109
- print("unet Download!", )
 
 
 
 
 
110
 
111
-
112
  if not os.path.exists("models/sd1-5/config.json"):
113
- os.system(f'wget -q https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/config.json?download=true -P models/sd1-5/')
114
- os.system(f'mv models/sd1-5/config.json?download=true models/sd1-5/config.json')
115
- print("config Download!", )
 
 
 
 
116
 
117
 
118
  if not os.path.exists("models/sd1-5/unet.ckpt"):
119
- os.system(f'cp -r models/unet.ckpt models/sd1-5/unet.ckpt')
120
 
121
  # os.system(f'wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/diffusion_pytorch_model.bin?download=true -P models/sd1-5/')
122
 
123
  if not os.path.exists("models/personalized/helloobjects_V12c.safetensors"):
124
- os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/helloobjects_V12c.safetensors?download=true -P models/personalized')
125
- os.system(f'mv models/personalized/helloobjects_V12c.safetensors?download=true models/personalized/helloobjects_V12c.safetensors')
126
- print("helloobjects_V12c Download!", )
 
 
 
 
 
 
127
 
128
 
129
  if not os.path.exists("models/personalized/TUSUN.safetensors"):
130
- os.system(f'wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/TUSUN.safetensors?download=true -P models/personalized')
131
- os.system(f'mv models/personalized/TUSUN.safetensors?download=true models/personalized/TUSUN.safetensors')
132
- print("TUSUN Download!", )
 
 
 
 
133
 
134
  # mv1 = os.system(f'mv /usr/local/lib/python3.10/site-packages/gradio/helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers_bkp.py')
135
  # mv2 = os.system(f'mv helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers.py')
@@ -145,128 +178,135 @@ if not os.path.exists("models/personalized/TUSUN.safetensors"):
145
  # - - - - - examples - - - - - #
146
 
147
  image_examples = [
148
- ["__asset__/images/object/turtle-1.jpg",
149
- "a sea turtle gracefully swimming over a coral reef in the clear blue ocean.",
150
- "object",
151
- 11318446767408804497,
152
- "",
153
- "turtle",
154
- "__asset__/turtle.mp4"
155
- ],
156
-
157
- ["__asset__/images/object/rose-1.jpg",
158
- "a red rose engulfed in flames.",
159
- "object",
160
- 6854275249656120509,
161
- "",
162
- "rose",
163
- "__asset__/rose.mp4"
164
- ],
165
-
166
- ["__asset__/images/object/jellyfish-1.jpg",
167
- "intricate detailing,photorealism,hyperrealistic, glowing jellyfish mushroom, flying, starry sky, bokeh, golden ratio composition.",
168
- "object",
169
- 17966188172968903484,
170
- "HelloObject",
171
- "jellyfish",
172
- "__asset__/jellyfish.mp4"
173
- ],
174
-
175
-
176
- ["__asset__/images/camera/lush-1.jpg",
177
- "detailed craftsmanship, photorealism, hyperrealistic, roaring waterfall, misty spray, lush greenery, vibrant rainbow, golden ratio composition.",
178
- "camera",
179
- 7970487946960948963,
180
- "HelloObject",
181
- "lush",
182
- "__asset__/lush.mp4",
183
- ],
184
-
185
- ["__asset__/images/camera/tusun-1.jpg",
186
- "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing.",
187
- "camera",
188
- 996953226890228361,
189
- "TUSUN",
190
- "tusun",
191
- "__asset__/tusun.mp4"
192
- ],
193
-
194
- ["__asset__/images/camera/painting-1.jpg",
195
- "A oil painting.",
196
- "camera",
197
- 16867854766769816385,
198
- "",
199
- "painting",
200
- "__asset__/painting.mp4"
201
- ],
202
  ]
203
 
204
 
205
  POINTS = {
206
- 'turtle': "__asset__/trajs/object/turtle-1.json",
207
- 'rose': "__asset__/trajs/object/rose-1.json",
208
- 'jellyfish': "__asset__/trajs/object/jellyfish-1.json",
209
- 'lush': "__asset__/trajs/camera/lush-1.json",
210
- 'tusun': "__asset__/trajs/camera/tusun-1.json",
211
- 'painting': "__asset__/trajs/camera/painting-1.json",
212
  }
213
 
214
  IMAGE_PATH = {
215
- 'turtle': "__asset__/images/object/turtle-1.jpg",
216
- 'rose': "__asset__/images/object/rose-1.jpg",
217
- 'jellyfish': "__asset__/images/object/jellyfish-1.jpg",
218
- 'lush': "__asset__/images/camera/lush-1.jpg",
219
- 'tusun': "__asset__/images/camera/tusun-1.jpg",
220
- 'painting': "__asset__/images/camera/painting-1.jpg",
221
  }
222
 
223
 
224
-
225
  DREAM_BOOTH = {
226
- 'HelloObject': 'models/personalized/helloobjects_V12c.safetensors',
227
  }
228
 
229
  LORA = {
230
- 'TUSUN': 'models/personalized/TUSUN.safetensors',
231
  }
232
 
233
  LORA_ALPHA = {
234
- 'TUSUN': 0.6,
235
  }
236
 
237
  NPROMPT = {
238
- "HelloObject": 'FastNegativeV2,(bad-artist:1),(worst quality, low quality:1.4),(bad_prompt_version2:0.8),bad-hands-5,lowres,bad anatomy,bad hands,((text)),(watermark),error,missing fingers,extra digit,fewer digits,cropped,worst quality,low quality,normal quality,((username)),blurry,(extra limbs),bad-artist-anime,badhandv4,EasyNegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3,BadDream,(three hands:1.6),(three legs:1.2),(more than two hands:1.4),(more than two legs,:1.2)'
239
  }
240
 
241
  output_dir = "outputs"
242
  ensure_dirname(output_dir)
243
 
 
244
  def points_to_flows(track_points, model_length, height, width):
245
  input_drag = np.zeros((model_length - 1, height, width, 2))
246
  for splited_track in track_points:
247
- if len(splited_track) == 1: # stationary point
248
  displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
249
  splited_track = tuple([splited_track[0], displacement_point])
250
  # interpolate the track
251
  splited_track = interpolate_trajectory(splited_track, model_length)
252
  splited_track = splited_track[:model_length]
253
  if len(splited_track) < model_length:
254
- splited_track = splited_track + [splited_track[-1]] * (model_length -len(splited_track))
255
  for i in range(model_length - 1):
256
  start_point = splited_track[i]
257
- end_point = splited_track[i+1]
258
  input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]
259
  input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]
260
  return input_drag
261
 
 
262
  class ImageConductor:
263
- def __init__(self, device, unet_path, image_controlnet_path, flow_controlnet_path, height, width, model_length, lora_rank=64):
 
 
264
  self.device = device
265
- tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
266
- text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").to(device)
267
- vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(device)
 
 
268
  inference_config = OmegaConf.load("configs/inference/inference.yaml")
269
- unet = UNet3DConditionFlowModel.from_pretrained_2d("models/sd1-5/", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
 
 
270
 
271
  self.vae = vae
272
 
@@ -287,15 +327,14 @@ class ImageConductor:
287
 
288
  self.pipeline = ImageConductorPipeline(
289
  unet=unet,
290
- vae=vae,
291
- tokenizer=tokenizer,
292
- text_encoder=text_encoder,
293
  scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
294
  image_controlnet=image_controlnet,
295
  flow_controlnet=flow_controlnet,
296
  ).to(device)
297
 
298
-
299
  self.height = height
300
  self.width = width
301
  # _, model_step, _ = split_filename(model_path)
@@ -307,7 +346,20 @@ class ImageConductor:
307
  self.blur_kernel = blur_kernel
308
 
309
  @spaces.GPU(duration=180)
310
- def run(self, first_frame_path, tracking_points, prompt, drag_mode, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized, examples_type):
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  print("Run!")
312
  if examples_type != "":
313
  ### for adapting high version gradio
@@ -317,30 +369,40 @@ class ImageConductor:
317
  tracking_points.value.extend(points)
318
  print("example first_frame_path", first_frame_path)
319
  print("example tracking_points", tracking_points.value)
320
-
321
- original_width, original_height=384, 256
322
  if isinstance(tracking_points, list):
323
  input_all_points = tracking_points
324
  else:
325
  input_all_points = tracking_points.value
326
-
327
  print("input_all_points", input_all_points)
328
- resized_all_points = [tuple([tuple([float(e1[0]*self.width/original_width), float(e1[1]*self.height/original_height)]) for e1 in e]) for e in input_all_points]
 
 
 
 
 
 
 
 
329
 
330
  dir, base, ext = split_filename(first_frame_path)
331
- id = base.split('_')[-1]
332
-
333
-
334
- visualized_drag, _ = visualize_drag(first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length)
 
335
 
336
- ## image condition
337
- image_transforms = transforms.Compose([
 
338
  transforms.RandomResizedCrop(
339
- (self.height, self.width), (1.0, 1.0),
340
- ratio=(self.width/self.height, self.width/self.height)
341
  ),
342
  transforms.ToTensor(),
343
- ])
 
344
 
345
  image_paths = [first_frame_path]
346
  controlnet_images = [(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]
@@ -349,60 +411,68 @@ class ImageConductor:
349
  num_controlnet_images = controlnet_images.shape[2]
350
  controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w")
351
  self.vae.to(device)
352
- controlnet_images = self.vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * 0.18215
353
  controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images)
354
 
355
  # flow condition
356
  controlnet_flows = points_to_flows(resized_all_points, self.model_length, self.height, self.width)
357
- for i in range(0, self.model_length-1):
358
  controlnet_flows[i] = cv2.filter2D(controlnet_flows[i], -1, self.blur_kernel)
359
- controlnet_flows = np.concatenate([np.zeros_like(controlnet_flows[0])[np.newaxis, ...], controlnet_flows], axis=0) # pad the first frame with zero flow
 
 
360
  os.makedirs(os.path.join(output_dir, "control_flows"), exist_ok=True)
361
- trajs_video = vis_flow_to_video(controlnet_flows, num_frames=self.model_length) # T-1 x H x W x 3
362
- torchvision.io.write_video(f'{output_dir}/control_flows/sample-{id}-train_flow.mp4', trajs_video, fps=8, video_codec='h264', options={'crf': '10'})
363
- controlnet_flows = torch.from_numpy(controlnet_flows)[None][:, :self.model_length, ...]
364
- controlnet_flows = rearrange(controlnet_flows, "b f h w c-> b c f h w").float().to(device)
 
 
 
 
 
 
365
 
366
- dreambooth_model_path = DREAM_BOOTH.get(personalized, '')
367
- lora_model_path = LORA.get(personalized, '')
368
  lora_alpha = LORA_ALPHA.get(personalized, 0.6)
369
  self.pipeline = load_weights(
370
  self.pipeline,
371
- dreambooth_model_path = dreambooth_model_path,
372
- lora_model_path = lora_model_path,
373
- lora_alpha = lora_alpha,
374
  ).to(device)
375
-
376
- if NPROMPT.get(personalized, '') != '':
377
- negative_prompt = NPROMPT.get(personalized)
378
-
379
  if randomize_seed:
380
  random_seed = torch.seed()
381
  else:
382
  seed = int(seed)
383
  random_seed = seed
384
  torch.manual_seed(random_seed)
385
- torch.cuda.manual_seed_all(random_seed)
386
  print(f"current seed: {torch.initial_seed()}")
387
  sample = self.pipeline(
388
- prompt,
389
- negative_prompt = negative_prompt,
390
- num_inference_steps = num_inference_steps,
391
- guidance_scale = guidance_scale,
392
- width = self.width,
393
- height = self.height,
394
- video_length = self.model_length,
395
- controlnet_images = controlnet_images, # 1 4 1 32 48
396
- controlnet_image_index = [0],
397
- controlnet_flows = controlnet_flows,# [1, 2, 16, 256, 384]
398
- control_mode = drag_mode,
399
- eval_mode = True,
400
- ).videos
401
-
402
- outputs_path = os.path.join(output_dir, f'output_{i}_{id}.mp4')
403
- vis_video = (rearrange(sample[0], 'c t h w -> t h w c') * 255.).clip(0, 255)
404
- torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec='h264', options={'crf': '10'})
405
-
406
  # outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
407
  # save_videos_grid(sample[0][None], outputs_path)
408
  print("Done!")
@@ -412,33 +482,40 @@ class ImageConductor:
412
  def reset_states(first_frame_path, tracking_points):
413
  first_frame_path = gr.State()
414
  tracking_points = gr.State([])
415
- return {input_image:None, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
416
 
417
 
418
  def preprocess_image(image, tracking_points):
419
  image_pil = image2pil(image.name)
420
  raw_w, raw_h = image_pil.size
421
- resize_ratio = max(384/raw_w, 256/raw_h)
422
  image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
423
- image_pil = transforms.CenterCrop((256, 384))(image_pil.convert('RGB'))
424
  id = str(uuid.uuid4())[:4]
425
  first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
426
  image_pil.save(first_frame_path, quality=95)
427
- tracking_points = gr.State([])
428
- return {input_image: first_frame_path, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points, personalized:""}
429
-
430
-
431
- def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData): # SelectData is a subclass of EventData
432
- if drag_mode=='object':
 
 
 
 
 
 
 
433
  color = (255, 0, 0, 255)
434
- elif drag_mode=='camera':
435
  color = (0, 0, 255, 255)
436
 
437
- if not isinstance(tracking_points ,list):
438
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
439
  tracking_points.value[-1].append(evt.index)
440
  print(tracking_points.value)
441
- tracking_points_values = tracking_points.value
442
  else:
443
  try:
444
  tracking_points[-1].append(evt.index)
@@ -446,26 +523,33 @@ def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.Se
446
  tracking_points.append([])
447
  tracking_points[-1].append(evt.index)
448
  print(f"Solved Error: {e}")
449
-
450
  tracking_points_values = tracking_points
451
-
452
-
453
- transparent_background = Image.open(first_frame_path).convert('RGBA')
454
  w, h = transparent_background.size
455
  transparent_layer = np.zeros((h, w, 4))
456
-
457
  for track in tracking_points_values:
458
  if len(track) > 1:
459
- for i in range(len(track)-1):
460
  start_point = track[i]
461
- end_point = track[i+1]
462
  vx = end_point[0] - start_point[0]
463
  vy = end_point[1] - start_point[1]
464
  arrow_length = np.sqrt(vx**2 + vy**2)
465
- if i == len(track)-2:
466
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
 
 
467
  else:
468
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
 
 
 
 
 
 
469
  else:
470
  cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
471
 
@@ -475,79 +559,90 @@ def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.Se
475
 
476
 
477
  def add_drag(tracking_points):
478
- if not isinstance(tracking_points ,list):
479
  # print("before", tracking_points.value)
480
  tracking_points.value.append([])
481
  # print(tracking_points.value)
482
  else:
483
  tracking_points.append([])
484
  return {tracking_points_var: tracking_points}
485
-
486
 
487
  def delete_last_drag(tracking_points, first_frame_path, drag_mode):
488
- if drag_mode=='object':
489
  color = (255, 0, 0, 255)
490
- elif drag_mode=='camera':
491
  color = (0, 0, 255, 255)
492
  tracking_points.value.pop()
493
- transparent_background = Image.open(first_frame_path).convert('RGBA')
494
  w, h = transparent_background.size
495
  transparent_layer = np.zeros((h, w, 4))
496
  for track in tracking_points.value:
497
  if len(track) > 1:
498
- for i in range(len(track)-1):
499
  start_point = track[i]
500
- end_point = track[i+1]
501
  vx = end_point[0] - start_point[0]
502
  vy = end_point[1] - start_point[1]
503
  arrow_length = np.sqrt(vx**2 + vy**2)
504
- if i == len(track)-2:
505
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
 
 
506
  else:
507
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
 
 
 
 
 
 
508
  else:
509
  cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
510
 
511
  transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
512
  trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
513
  return {tracking_points_var: tracking_points, input_image: trajectory_map}
514
-
515
 
516
  def delete_last_step(tracking_points, first_frame_path, drag_mode):
517
- if drag_mode=='object':
518
  color = (255, 0, 0, 255)
519
- elif drag_mode=='camera':
520
  color = (0, 0, 255, 255)
521
  tracking_points.value[-1].pop()
522
- transparent_background = Image.open(first_frame_path).convert('RGBA')
523
  w, h = transparent_background.size
524
  transparent_layer = np.zeros((h, w, 4))
525
  for track in tracking_points.value:
526
  if len(track) > 1:
527
- for i in range(len(track)-1):
528
  start_point = track[i]
529
- end_point = track[i+1]
530
  vx = end_point[0] - start_point[0]
531
  vy = end_point[1] - start_point[1]
532
  arrow_length = np.sqrt(vx**2 + vy**2)
533
- if i == len(track)-2:
534
- cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
 
 
535
  else:
536
- cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2,)
 
 
 
 
 
 
537
  else:
538
- cv2.circle(transparent_layer, tuple(track[0]), 5,color, -1)
539
 
540
  transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
541
  trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
542
  return {tracking_points_var: tracking_points, input_image: trajectory_map}
543
 
544
 
545
- block = gr.Blocks(
546
- theme=gr.themes.Soft(
547
- radius_size=gr.themes.sizes.radius_none,
548
- text_size=gr.themes.sizes.text_md
549
- )
550
- )
551
  with block:
552
  with gr.Row():
553
  with gr.Column():
@@ -557,66 +652,72 @@ with block:
557
 
558
  with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
559
  with gr.Row(equal_height=True):
560
- gr.Markdown(instructions)
561
-
562
 
563
  # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
564
  device = torch.device("cuda")
565
- unet_path = 'models/unet.ckpt'
566
- image_controlnet_path = 'models/image_controlnet.ckpt'
567
- flow_controlnet_path = 'models/flow_controlnet.ckpt'
568
- ImageConductor_net = ImageConductor(device=device,
569
- unet_path=unet_path,
570
- image_controlnet_path=image_controlnet_path,
571
- flow_controlnet_path=flow_controlnet_path,
572
- height=256,
573
- width=384,
574
- model_length=16
575
- )
 
576
  first_frame_path_var = gr.State(value=None)
577
  tracking_points_var = gr.State([])
578
 
579
  with gr.Row():
580
  with gr.Column(scale=1):
581
- image_upload_button = gr.UploadButton(label="Upload Image",file_types=["image"])
582
  add_drag_button = gr.Button(value="Add Drag")
583
  reset_button = gr.Button(value="Reset")
584
  delete_last_drag_button = gr.Button(value="Delete last drag")
585
  delete_last_step_button = gr.Button(value="Delete last step")
586
-
587
-
588
 
589
  with gr.Column(scale=7):
590
  with gr.Row():
591
  with gr.Column(scale=6):
592
- input_image = gr.Image(label="Input Image",
593
- interactive=True,
594
- height=300,
595
- width=384,)
 
 
596
  with gr.Column(scale=6):
597
- output_image = gr.Image(label="Motion Path",
598
- interactive=False,
599
- height=256,
600
- width=384,)
 
 
601
  with gr.Row():
602
  with gr.Column(scale=1):
603
- prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True)
 
 
604
  negative_prompt = gr.Text(
605
- label="Negative Prompt",
606
- max_lines=5,
607
- placeholder="Please input your negative prompt",
608
- value='worst quality, low quality, letterboxed',lines=1
609
- )
610
- drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2)
 
611
  run_button = gr.Button(value="Run")
612
 
613
  with gr.Accordion("More input params", open=False, elem_id="accordion1"):
614
  with gr.Group():
615
  seed = gr.Textbox(
616
- label="Seed: ", value=561793204,
 
617
  )
618
  randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
619
-
620
  with gr.Group():
621
  with gr.Row():
622
  guidance_scale = gr.Slider(
@@ -633,23 +734,18 @@ with block:
633
  step=1,
634
  value=25,
635
  )
636
-
637
  with gr.Group():
638
- personalized = gr.Dropdown(label="Personalized", choices=["", 'HelloObject', 'TUSUN'], value="")
639
- examples_type = gr.Textbox(label="Examples Type (Ignore) ", value="", visible=False)
640
 
641
  with gr.Column(scale=7):
642
- output_video = gr.Video(
643
- label="Output Video",
644
- width=384,
645
- height=256)
646
  # output_video = gr.Image(label="Output Video",
647
  # height=256,
648
  # width=384,)
649
-
650
-
651
  with gr.Row():
652
-
653
 
654
  example = gr.Examples(
655
  label="Input Example",
@@ -658,26 +754,52 @@ with block:
658
  examples_per_page=10,
659
  cache_examples=False,
660
  )
661
-
662
-
663
  with gr.Row():
664
  gr.Markdown(citation)
665
 
666
-
667
- image_upload_button.upload(preprocess_image, [image_upload_button, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var, personalized])
 
 
 
668
 
669
  add_drag_button.click(add_drag, tracking_points_var, tracking_points_var)
670
 
671
- delete_last_drag_button.click(delete_last_drag, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
672
-
673
- delete_last_step_button.click(delete_last_step, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
674
-
675
- reset_button.click(reset_states, [first_frame_path_var, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var])
676
-
677
- input_image.select(add_tracking_points, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])
678
-
679
- run_button.click(ImageConductor_net.run, [first_frame_path_var, tracking_points_var, prompt, drag_mode,
680
- negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized, examples_type],
681
- [output_image, output_video])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
682
 
683
  block.queue().launch()
 
1
  import os
2
  import sys
3
 
 
4
  print("Installing correct gradio version...")
5
  os.system("pip uninstall -y gradio")
6
  os.system("pip install gradio==4.38.1")
7
  print("Installing Finished!")
8
 
9
 
10
+ import json
11
+ import uuid
12
+
13
+ import cv2
14
  import gradio as gr
15
  import numpy as np
16
+ import spaces
 
17
  import torch
18
  import torchvision
19
+ from diffusers import AutoencoderKL, DDIMScheduler
 
 
 
 
20
  from einops import rearrange, repeat
21
+ from omegaconf import OmegaConf
22
+ from PIL import Image
23
+ from torchvision import transforms, utils
24
  from transformers import CLIPTextModel, CLIPTokenizer
 
25
 
 
26
  from modules.unet import UNet3DConditionFlowModel
27
+ from pipelines.pipeline_imagecoductor import ImageConductorPipeline
28
+ from utils.gradio_utils import ensure_dirname, image2pil, split_filename, visualize_drag
29
  from utils.lora_utils import add_LoRA_to_controlnet
30
+ from utils.utils import (
31
+ bivariate_Gaussian,
32
+ create_flow_controlnet,
33
+ create_image_controlnet,
34
+ interpolate_trajectory,
35
+ load_model,
36
+ load_weights,
37
+ save_videos_grid,
38
+ )
39
  from utils.visualizer import Visualizer, vis_flow_to_video
40
+
41
  #### Description ####
42
  title = r"""<h1 align="center">CustomNet: Object Customization with Variable-Viewpoints in Text-to-Image Diffusion Models</h1>"""
43
 
 
49
  <a href='https://liyaowei-stu.github.io/project/ImageConductor/'><img src='https://img.shields.io/badge/Project_Page-ImgaeConductor-green' alt='Project Page'></a>
50
  <a href='https://arxiv.org/pdf/2406.15339'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
51
  <a href='https://github.com/liyaowei-stu/ImageConductor'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
52
+
53
 
54
  </div>
55
  </br>
 
57
  """
58
 
59
 
 
60
  descriptions = r"""
61
  Official Gradio Demo for <a href='https://github.com/liyaowei-stu/ImageConductor'><b>Image Conductor: Precision Control for Interactive Video Synthesis</b></a>.<br>
62
  🧙Image Conductor enables precise, fine-grained control for generating motion-controllable videos from images, advancing the practical application of interactive video synthesis.<br>
 
73
  """
74
 
75
  citation = r"""
76
+ If Image Conductor is helpful, please help to ⭐ the <a href='https://github.com/liyaowei-stu/ImageConductor' target='_blank'>Github Repo</a>. Thanks!
77
  [![GitHub Stars](https://img.shields.io/github/stars/liyaowei-stu%2FImageConductor)](https://github.com/liyaowei-stu/ImageConductor)
78
  ---
79
 
 
82
  If our work is useful for your research, please consider citing:
83
  ```bibtex
84
  @misc{li2024imageconductor,
85
+ title={Image Conductor: Precision Control for Interactive Video Synthesis},
86
  author={Li, Yaowei and Wang, Xintao and Zhang, Zhaoyang and Wang, Zhouxia and Yuan, Ziyang and Xie, Liangbin and Zou, Yuexian and Shan, Ying},
87
  year={2024},
88
  eprint={2406.15339},
 
101
  os.makedirs("models/sd1-5")
102
 
103
  if not os.path.exists("models/flow_controlnet.ckpt"):
104
+ os.system(
105
+ f"wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/flow_controlnet.ckpt?download=true -P models/"
106
+ )
107
+ os.system(f"mv models/flow_controlnet.ckpt?download=true models/flow_controlnet.ckpt")
108
+ print(
109
+ "flow_controlnet Download!",
110
+ )
111
 
112
  if not os.path.exists("models/image_controlnet.ckpt"):
113
+ os.system(
114
+ f"wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/image_controlnet.ckpt?download=true -P models/"
115
+ )
116
+ os.system(f"mv models/image_controlnet.ckpt?download=true models/image_controlnet.ckpt")
117
+ print(
118
+ "image_controlnet Download!",
119
+ )
120
 
121
  if not os.path.exists("models/unet.ckpt"):
122
+ os.system(
123
+ f"wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/unet.ckpt?download=true -P models/"
124
+ )
125
+ os.system(f"mv models/unet.ckpt?download=true models/unet.ckpt")
126
+ print(
127
+ "unet Download!",
128
+ )
129
+
130
 
 
131
  if not os.path.exists("models/sd1-5/config.json"):
132
+ os.system(
133
+ f"wget -q https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/config.json?download=true -P models/sd1-5/"
134
+ )
135
+ os.system(f"mv models/sd1-5/config.json?download=true models/sd1-5/config.json")
136
+ print(
137
+ "config Download!",
138
+ )
139
 
140
 
141
  if not os.path.exists("models/sd1-5/unet.ckpt"):
142
+ os.system(f"cp -r models/unet.ckpt models/sd1-5/unet.ckpt")
143
 
144
  # os.system(f'wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/diffusion_pytorch_model.bin?download=true -P models/sd1-5/')
145
 
146
  if not os.path.exists("models/personalized/helloobjects_V12c.safetensors"):
147
+ os.system(
148
+ f"wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/helloobjects_V12c.safetensors?download=true -P models/personalized"
149
+ )
150
+ os.system(
151
+ f"mv models/personalized/helloobjects_V12c.safetensors?download=true models/personalized/helloobjects_V12c.safetensors"
152
+ )
153
+ print(
154
+ "helloobjects_V12c Download!",
155
+ )
156
 
157
 
158
  if not os.path.exists("models/personalized/TUSUN.safetensors"):
159
+ os.system(
160
+ f"wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/TUSUN.safetensors?download=true -P models/personalized"
161
+ )
162
+ os.system(f"mv models/personalized/TUSUN.safetensors?download=true models/personalized/TUSUN.safetensors")
163
+ print(
164
+ "TUSUN Download!",
165
+ )
166
 
167
  # mv1 = os.system(f'mv /usr/local/lib/python3.10/site-packages/gradio/helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers_bkp.py')
168
  # mv2 = os.system(f'mv helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers.py')
 
178
  # - - - - - examples - - - - - #
179
 
180
  image_examples = [
181
+ [
182
+ "__asset__/images/object/turtle-1.jpg",
183
+ "a sea turtle gracefully swimming over a coral reef in the clear blue ocean.",
184
+ "object",
185
+ 11318446767408804497,
186
+ "",
187
+ "turtle",
188
+ "__asset__/turtle.mp4",
189
+ ],
190
+ [
191
+ "__asset__/images/object/rose-1.jpg",
192
+ "a red rose engulfed in flames.",
193
+ "object",
194
+ 6854275249656120509,
195
+ "",
196
+ "rose",
197
+ "__asset__/rose.mp4",
198
+ ],
199
+ [
200
+ "__asset__/images/object/jellyfish-1.jpg",
201
+ "intricate detailing,photorealism,hyperrealistic, glowing jellyfish mushroom, flying, starry sky, bokeh, golden ratio composition.",
202
+ "object",
203
+ 17966188172968903484,
204
+ "HelloObject",
205
+ "jellyfish",
206
+ "__asset__/jellyfish.mp4",
207
+ ],
208
+ [
209
+ "__asset__/images/camera/lush-1.jpg",
210
+ "detailed craftsmanship, photorealism, hyperrealistic, roaring waterfall, misty spray, lush greenery, vibrant rainbow, golden ratio composition.",
211
+ "camera",
212
+ 7970487946960948963,
213
+ "HelloObject",
214
+ "lush",
215
+ "__asset__/lush.mp4",
216
+ ],
217
+ [
218
+ "__asset__/images/camera/tusun-1.jpg",
219
+ "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing.",
220
+ "camera",
221
+ 996953226890228361,
222
+ "TUSUN",
223
+ "tusun",
224
+ "__asset__/tusun.mp4",
225
+ ],
226
+ [
227
+ "__asset__/images/camera/painting-1.jpg",
228
+ "A oil painting.",
229
+ "camera",
230
+ 16867854766769816385,
231
+ "",
232
+ "painting",
233
+ "__asset__/painting.mp4",
234
+ ],
235
  ]
236
 
237
 
238
  POINTS = {
239
+ "turtle": "__asset__/trajs/object/turtle-1.json",
240
+ "rose": "__asset__/trajs/object/rose-1.json",
241
+ "jellyfish": "__asset__/trajs/object/jellyfish-1.json",
242
+ "lush": "__asset__/trajs/camera/lush-1.json",
243
+ "tusun": "__asset__/trajs/camera/tusun-1.json",
244
+ "painting": "__asset__/trajs/camera/painting-1.json",
245
  }
246
 
247
  IMAGE_PATH = {
248
+ "turtle": "__asset__/images/object/turtle-1.jpg",
249
+ "rose": "__asset__/images/object/rose-1.jpg",
250
+ "jellyfish": "__asset__/images/object/jellyfish-1.jpg",
251
+ "lush": "__asset__/images/camera/lush-1.jpg",
252
+ "tusun": "__asset__/images/camera/tusun-1.jpg",
253
+ "painting": "__asset__/images/camera/painting-1.jpg",
254
  }
255
 
256
 
 
257
  DREAM_BOOTH = {
258
+ "HelloObject": "models/personalized/helloobjects_V12c.safetensors",
259
  }
260
 
261
  LORA = {
262
+ "TUSUN": "models/personalized/TUSUN.safetensors",
263
  }
264
 
265
  LORA_ALPHA = {
266
+ "TUSUN": 0.6,
267
  }
268
 
269
  NPROMPT = {
270
+ "HelloObject": "FastNegativeV2,(bad-artist:1),(worst quality, low quality:1.4),(bad_prompt_version2:0.8),bad-hands-5,lowres,bad anatomy,bad hands,((text)),(watermark),error,missing fingers,extra digit,fewer digits,cropped,worst quality,low quality,normal quality,((username)),blurry,(extra limbs),bad-artist-anime,badhandv4,EasyNegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3,BadDream,(three hands:1.6),(three legs:1.2),(more than two hands:1.4),(more than two legs,:1.2)"
271
  }
272
 
273
  output_dir = "outputs"
274
  ensure_dirname(output_dir)
275
 
276
+
277
  def points_to_flows(track_points, model_length, height, width):
278
  input_drag = np.zeros((model_length - 1, height, width, 2))
279
  for splited_track in track_points:
280
+ if len(splited_track) == 1: # stationary point
281
  displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
282
  splited_track = tuple([splited_track[0], displacement_point])
283
  # interpolate the track
284
  splited_track = interpolate_trajectory(splited_track, model_length)
285
  splited_track = splited_track[:model_length]
286
  if len(splited_track) < model_length:
287
+ splited_track = splited_track + [splited_track[-1]] * (model_length - len(splited_track))
288
  for i in range(model_length - 1):
289
  start_point = splited_track[i]
290
+ end_point = splited_track[i + 1]
291
  input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]
292
  input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]
293
  return input_drag
294
 
295
+
296
  class ImageConductor:
297
+ def __init__(
298
+ self, device, unet_path, image_controlnet_path, flow_controlnet_path, height, width, model_length, lora_rank=64
299
+ ):
300
  self.device = device
301
+ tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
302
+ text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").to(
303
+ device
304
+ )
305
+ vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(device)
306
  inference_config = OmegaConf.load("configs/inference/inference.yaml")
307
+ unet = UNet3DConditionFlowModel.from_pretrained_2d(
308
+ "models/sd1-5/", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)
309
+ )
310
 
311
  self.vae = vae
312
 
 
327
 
328
  self.pipeline = ImageConductorPipeline(
329
  unet=unet,
330
+ vae=vae,
331
+ tokenizer=tokenizer,
332
+ text_encoder=text_encoder,
333
  scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
334
  image_controlnet=image_controlnet,
335
  flow_controlnet=flow_controlnet,
336
  ).to(device)
337
 
 
338
  self.height = height
339
  self.width = width
340
  # _, model_step, _ = split_filename(model_path)
 
346
  self.blur_kernel = blur_kernel
347
 
348
  @spaces.GPU(duration=180)
349
+ def run(
350
+ self,
351
+ first_frame_path,
352
+ tracking_points,
353
+ prompt,
354
+ drag_mode,
355
+ negative_prompt,
356
+ seed,
357
+ randomize_seed,
358
+ guidance_scale,
359
+ num_inference_steps,
360
+ personalized,
361
+ examples_type,
362
+ ):
363
  print("Run!")
364
  if examples_type != "":
365
  ### for adapting high version gradio
 
369
  tracking_points.value.extend(points)
370
  print("example first_frame_path", first_frame_path)
371
  print("example tracking_points", tracking_points.value)
372
+
373
+ original_width, original_height = 384, 256
374
  if isinstance(tracking_points, list):
375
  input_all_points = tracking_points
376
  else:
377
  input_all_points = tracking_points.value
378
+
379
  print("input_all_points", input_all_points)
380
+ resized_all_points = [
381
+ tuple(
382
+ [
383
+ tuple([float(e1[0] * self.width / original_width), float(e1[1] * self.height / original_height)])
384
+ for e1 in e
385
+ ]
386
+ )
387
+ for e in input_all_points
388
+ ]
389
 
390
  dir, base, ext = split_filename(first_frame_path)
391
+ id = base.split("_")[-1]
392
+
393
+ visualized_drag, _ = visualize_drag(
394
+ first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length
395
+ )
396
 
397
+ ## image condition
398
+ image_transforms = transforms.Compose(
399
+ [
400
  transforms.RandomResizedCrop(
401
+ (self.height, self.width), (1.0, 1.0), ratio=(self.width / self.height, self.width / self.height)
 
402
  ),
403
  transforms.ToTensor(),
404
+ ]
405
+ )
406
 
407
  image_paths = [first_frame_path]
408
  controlnet_images = [(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]
 
411
  num_controlnet_images = controlnet_images.shape[2]
412
  controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w")
413
  self.vae.to(device)
414
+ controlnet_images = self.vae.encode(controlnet_images * 2.0 - 1.0).latent_dist.sample() * 0.18215
415
  controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images)
416
 
417
  # flow condition
418
  controlnet_flows = points_to_flows(resized_all_points, self.model_length, self.height, self.width)
419
+ for i in range(0, self.model_length - 1):
420
  controlnet_flows[i] = cv2.filter2D(controlnet_flows[i], -1, self.blur_kernel)
421
+ controlnet_flows = np.concatenate(
422
+ [np.zeros_like(controlnet_flows[0])[np.newaxis, ...], controlnet_flows], axis=0
423
+ ) # pad the first frame with zero flow
424
  os.makedirs(os.path.join(output_dir, "control_flows"), exist_ok=True)
425
+ trajs_video = vis_flow_to_video(controlnet_flows, num_frames=self.model_length) # T-1 x H x W x 3
426
+ torchvision.io.write_video(
427
+ f"{output_dir}/control_flows/sample-{id}-train_flow.mp4",
428
+ trajs_video,
429
+ fps=8,
430
+ video_codec="h264",
431
+ options={"crf": "10"},
432
+ )
433
+ controlnet_flows = torch.from_numpy(controlnet_flows)[None][:, : self.model_length, ...]
434
+ controlnet_flows = rearrange(controlnet_flows, "b f h w c-> b c f h w").float().to(device)
435
 
436
+ dreambooth_model_path = DREAM_BOOTH.get(personalized, "")
437
+ lora_model_path = LORA.get(personalized, "")
438
  lora_alpha = LORA_ALPHA.get(personalized, 0.6)
439
  self.pipeline = load_weights(
440
  self.pipeline,
441
+ dreambooth_model_path=dreambooth_model_path,
442
+ lora_model_path=lora_model_path,
443
+ lora_alpha=lora_alpha,
444
  ).to(device)
445
+
446
+ if NPROMPT.get(personalized, "") != "":
447
+ negative_prompt = NPROMPT.get(personalized)
448
+
449
  if randomize_seed:
450
  random_seed = torch.seed()
451
  else:
452
  seed = int(seed)
453
  random_seed = seed
454
  torch.manual_seed(random_seed)
455
+ torch.cuda.manual_seed_all(random_seed)
456
  print(f"current seed: {torch.initial_seed()}")
457
  sample = self.pipeline(
458
+ prompt,
459
+ negative_prompt=negative_prompt,
460
+ num_inference_steps=num_inference_steps,
461
+ guidance_scale=guidance_scale,
462
+ width=self.width,
463
+ height=self.height,
464
+ video_length=self.model_length,
465
+ controlnet_images=controlnet_images, # 1 4 1 32 48
466
+ controlnet_image_index=[0],
467
+ controlnet_flows=controlnet_flows, # [1, 2, 16, 256, 384]
468
+ control_mode=drag_mode,
469
+ eval_mode=True,
470
+ ).videos
471
+
472
+ outputs_path = os.path.join(output_dir, f"output_{i}_{id}.mp4")
473
+ vis_video = (rearrange(sample[0], "c t h w -> t h w c") * 255.0).clip(0, 255)
474
+ torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec="h264", options={"crf": "10"})
475
+
476
  # outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
477
  # save_videos_grid(sample[0][None], outputs_path)
478
  print("Done!")
 
482
  def reset_states(first_frame_path, tracking_points):
483
  first_frame_path = gr.State()
484
  tracking_points = gr.State([])
485
+ return {input_image: None, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}
486
 
487
 
488
  def preprocess_image(image, tracking_points):
489
  image_pil = image2pil(image.name)
490
  raw_w, raw_h = image_pil.size
491
+ resize_ratio = max(384 / raw_w, 256 / raw_h)
492
  image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
493
+ image_pil = transforms.CenterCrop((256, 384))(image_pil.convert("RGB"))
494
  id = str(uuid.uuid4())[:4]
495
  first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
496
  image_pil.save(first_frame_path, quality=95)
497
+ tracking_points = gr.State([])
498
+ return {
499
+ input_image: first_frame_path,
500
+ first_frame_path_var: first_frame_path,
501
+ tracking_points_var: tracking_points,
502
+ personalized: "",
503
+ }
504
+
505
+
506
+ def add_tracking_points(
507
+ tracking_points, first_frame_path, drag_mode, evt: gr.SelectData
508
+ ): # SelectData is a subclass of EventData
509
+ if drag_mode == "object":
510
  color = (255, 0, 0, 255)
511
+ elif drag_mode == "camera":
512
  color = (0, 0, 255, 255)
513
 
514
+ if not isinstance(tracking_points, list):
515
  print(f"You selected {evt.value} at {evt.index} from {evt.target}")
516
  tracking_points.value[-1].append(evt.index)
517
  print(tracking_points.value)
518
+ tracking_points_values = tracking_points.value
519
  else:
520
  try:
521
  tracking_points[-1].append(evt.index)
 
523
  tracking_points.append([])
524
  tracking_points[-1].append(evt.index)
525
  print(f"Solved Error: {e}")
526
+
527
  tracking_points_values = tracking_points
528
+
529
+ transparent_background = Image.open(first_frame_path).convert("RGBA")
 
530
  w, h = transparent_background.size
531
  transparent_layer = np.zeros((h, w, 4))
532
+
533
  for track in tracking_points_values:
534
  if len(track) > 1:
535
+ for i in range(len(track) - 1):
536
  start_point = track[i]
537
+ end_point = track[i + 1]
538
  vx = end_point[0] - start_point[0]
539
  vy = end_point[1] - start_point[1]
540
  arrow_length = np.sqrt(vx**2 + vy**2)
541
+ if i == len(track) - 2:
542
+ cv2.arrowedLine(
543
+ transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length
544
+ )
545
  else:
546
+ cv2.line(
547
+ transparent_layer,
548
+ tuple(start_point),
549
+ tuple(end_point),
550
+ color,
551
+ 2,
552
+ )
553
  else:
554
  cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
555
 
 
559
 
560
 
561
  def add_drag(tracking_points):
562
+ if not isinstance(tracking_points, list):
563
  # print("before", tracking_points.value)
564
  tracking_points.value.append([])
565
  # print(tracking_points.value)
566
  else:
567
  tracking_points.append([])
568
  return {tracking_points_var: tracking_points}
569
+
570
 
571
  def delete_last_drag(tracking_points, first_frame_path, drag_mode):
572
+ if drag_mode == "object":
573
  color = (255, 0, 0, 255)
574
+ elif drag_mode == "camera":
575
  color = (0, 0, 255, 255)
576
  tracking_points.value.pop()
577
+ transparent_background = Image.open(first_frame_path).convert("RGBA")
578
  w, h = transparent_background.size
579
  transparent_layer = np.zeros((h, w, 4))
580
  for track in tracking_points.value:
581
  if len(track) > 1:
582
+ for i in range(len(track) - 1):
583
  start_point = track[i]
584
+ end_point = track[i + 1]
585
  vx = end_point[0] - start_point[0]
586
  vy = end_point[1] - start_point[1]
587
  arrow_length = np.sqrt(vx**2 + vy**2)
588
+ if i == len(track) - 2:
589
+ cv2.arrowedLine(
590
+ transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length
591
+ )
592
  else:
593
+ cv2.line(
594
+ transparent_layer,
595
+ tuple(start_point),
596
+ tuple(end_point),
597
+ color,
598
+ 2,
599
+ )
600
  else:
601
  cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
602
 
603
  transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
604
  trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
605
  return {tracking_points_var: tracking_points, input_image: trajectory_map}
606
+
607
 
608
  def delete_last_step(tracking_points, first_frame_path, drag_mode):
609
+ if drag_mode == "object":
610
  color = (255, 0, 0, 255)
611
+ elif drag_mode == "camera":
612
  color = (0, 0, 255, 255)
613
  tracking_points.value[-1].pop()
614
+ transparent_background = Image.open(first_frame_path).convert("RGBA")
615
  w, h = transparent_background.size
616
  transparent_layer = np.zeros((h, w, 4))
617
  for track in tracking_points.value:
618
  if len(track) > 1:
619
+ for i in range(len(track) - 1):
620
  start_point = track[i]
621
+ end_point = track[i + 1]
622
  vx = end_point[0] - start_point[0]
623
  vy = end_point[1] - start_point[1]
624
  arrow_length = np.sqrt(vx**2 + vy**2)
625
+ if i == len(track) - 2:
626
+ cv2.arrowedLine(
627
+ transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length
628
+ )
629
  else:
630
+ cv2.line(
631
+ transparent_layer,
632
+ tuple(start_point),
633
+ tuple(end_point),
634
+ color,
635
+ 2,
636
+ )
637
  else:
638
+ cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)
639
 
640
  transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
641
  trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
642
  return {tracking_points_var: tracking_points, input_image: trajectory_map}
643
 
644
 
645
+ block = gr.Blocks(theme=gr.themes.Soft(radius_size=gr.themes.sizes.radius_none, text_size=gr.themes.sizes.text_md))
 
 
 
 
 
646
  with block:
647
  with gr.Row():
648
  with gr.Column():
 
652
 
653
  with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
654
  with gr.Row(equal_height=True):
655
+ gr.Markdown(instructions)
 
656
 
657
  # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
658
  device = torch.device("cuda")
659
+ unet_path = "models/unet.ckpt"
660
+ image_controlnet_path = "models/image_controlnet.ckpt"
661
+ flow_controlnet_path = "models/flow_controlnet.ckpt"
662
+ ImageConductor_net = ImageConductor(
663
+ device=device,
664
+ unet_path=unet_path,
665
+ image_controlnet_path=image_controlnet_path,
666
+ flow_controlnet_path=flow_controlnet_path,
667
+ height=256,
668
+ width=384,
669
+ model_length=16,
670
+ )
671
  first_frame_path_var = gr.State(value=None)
672
  tracking_points_var = gr.State([])
673
 
674
  with gr.Row():
675
  with gr.Column(scale=1):
676
+ image_upload_button = gr.UploadButton(label="Upload Image", file_types=["image"])
677
  add_drag_button = gr.Button(value="Add Drag")
678
  reset_button = gr.Button(value="Reset")
679
  delete_last_drag_button = gr.Button(value="Delete last drag")
680
  delete_last_step_button = gr.Button(value="Delete last step")
 
 
681
 
682
  with gr.Column(scale=7):
683
  with gr.Row():
684
  with gr.Column(scale=6):
685
+ input_image = gr.Image(
686
+ label="Input Image",
687
+ interactive=True,
688
+ height=300,
689
+ width=384,
690
+ )
691
  with gr.Column(scale=6):
692
+ output_image = gr.Image(
693
+ label="Motion Path",
694
+ interactive=False,
695
+ height=256,
696
+ width=384,
697
+ )
698
  with gr.Row():
699
  with gr.Column(scale=1):
700
+ prompt = gr.Textbox(
701
+ value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True
702
+ )
703
  negative_prompt = gr.Text(
704
+ label="Negative Prompt",
705
+ max_lines=5,
706
+ placeholder="Please input your negative prompt",
707
+ value="worst quality, low quality, letterboxed",
708
+ lines=1,
709
+ )
710
+ drag_mode = gr.Radio(["camera", "object"], label="Drag mode: ", value="object", scale=2)
711
  run_button = gr.Button(value="Run")
712
 
713
  with gr.Accordion("More input params", open=False, elem_id="accordion1"):
714
  with gr.Group():
715
  seed = gr.Textbox(
716
+ label="Seed: ",
717
+ value=561793204,
718
  )
719
  randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
720
+
721
  with gr.Group():
722
  with gr.Row():
723
  guidance_scale = gr.Slider(
 
734
  step=1,
735
  value=25,
736
  )
737
+
738
  with gr.Group():
739
+ personalized = gr.Dropdown(label="Personalized", choices=["", "HelloObject", "TUSUN"], value="")
740
+ examples_type = gr.Textbox(label="Examples Type (Ignore) ", value="", visible=False)
741
 
742
  with gr.Column(scale=7):
743
+ output_video = gr.Video(label="Output Video", width=384, height=256)
 
 
 
744
  # output_video = gr.Image(label="Output Video",
745
  # height=256,
746
  # width=384,)
747
+
 
748
  with gr.Row():
 
749
 
750
  example = gr.Examples(
751
  label="Input Example",
 
754
  examples_per_page=10,
755
  cache_examples=False,
756
  )
757
+
 
758
  with gr.Row():
759
  gr.Markdown(citation)
760
 
761
+ image_upload_button.upload(
762
+ preprocess_image,
763
+ [image_upload_button, tracking_points_var],
764
+ [input_image, first_frame_path_var, tracking_points_var, personalized],
765
+ )
766
 
767
  add_drag_button.click(add_drag, tracking_points_var, tracking_points_var)
768
 
769
+ delete_last_drag_button.click(
770
+ delete_last_drag, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image]
771
+ )
772
+
773
+ delete_last_step_button.click(
774
+ delete_last_step, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image]
775
+ )
776
+
777
+ reset_button.click(
778
+ reset_states,
779
+ [first_frame_path_var, tracking_points_var],
780
+ [input_image, first_frame_path_var, tracking_points_var],
781
+ )
782
+
783
+ input_image.select(
784
+ add_tracking_points, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image]
785
+ )
786
+
787
+ run_button.click(
788
+ ImageConductor_net.run,
789
+ [
790
+ first_frame_path_var,
791
+ tracking_points_var,
792
+ prompt,
793
+ drag_mode,
794
+ negative_prompt,
795
+ seed,
796
+ randomize_seed,
797
+ guidance_scale,
798
+ num_inference_steps,
799
+ personalized,
800
+ examples_type,
801
+ ],
802
+ [output_image, output_video],
803
+ )
804
 
805
  block.queue().launch()
requirements.txt CHANGED
@@ -1,29 +1,28 @@
1
- torch
2
- torchvision
3
- torchaudio
4
- transformers==4.32.1
5
- gradio==4.38.1
6
- ftfy
7
- tensorboard
8
- datasets
9
- Pillow==9.5.0
10
- opencv-python
11
- imgaug
12
  accelerate==0.23.0
13
- image-reward
14
- hpsv2
15
- torchmetrics
16
- open-clip-torch
17
- clip
18
  av2
19
- peft
20
- imageio-ffmpeg
21
- scipy
22
- tqdm
23
- einops
24
  diffusers==0.28.0
 
 
 
 
 
 
 
 
25
  omegaconf
 
 
 
 
26
  scikit-image
27
  scikit-learn
28
- numpy==1.26.2
29
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  accelerate==0.23.0
 
 
 
 
 
2
  av2
3
+ clip
4
+ datasets
 
 
 
5
  diffusers==0.28.0
6
+ einops
7
+ ftfy
8
+ gradio==4.38.1
9
+ hpsv2
10
+ image-reward
11
+ imageio-ffmpeg
12
+ imgaug
13
+ numpy==1.26.2
14
  omegaconf
15
+ open-clip-torch
16
+ opencv-python
17
+ peft
18
+ Pillow==9.5.0
19
  scikit-image
20
  scikit-learn
21
+ scipy
22
+ tensorboard
23
+ torch
24
+ torchaudio
25
+ torchmetrics
26
+ torchvision
27
+ tqdm
28
+ transformers==4.32.1