Yw22 committed on
Commit 230c8a7 (Parent: 759f1a7)
Files changed (1)
  1. app-1.py +678 -0
app-1.py ADDED
@@ -0,0 +1,678 @@
import os
import sys


print("Installing correct gradio version...")
os.system("pip uninstall -y gradio")
os.system("pip install gradio==4.7.0")
print("Installation finished!")


import gradio as gr
import numpy as np
import cv2
import uuid
import torch
import torchvision
import json
import spaces

from PIL import Image
from omegaconf import OmegaConf
from einops import rearrange, repeat
from torchvision import transforms, utils
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler

from pipelines.pipeline_imagecoductor import ImageConductorPipeline
from modules.unet import UNet3DConditionFlowModel
from utils.gradio_utils import ensure_dirname, split_filename, visualize_drag, image2pil
from utils.utils import create_image_controlnet, create_flow_controlnet, interpolate_trajectory, load_weights, load_model, bivariate_Gaussian, save_videos_grid
from utils.lora_utils import add_LoRA_to_controlnet
from utils.visualizer import Visualizer, vis_flow_to_video
#### Description ####
title = r"""<h1 align="center">Image Conductor: Precision Control for Interactive Video Synthesis</h1>"""

head = r"""
<div style="text-align: center;">
    <h1>Image Conductor: Precision Control for Interactive Video Synthesis</h1>
    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
        <a href=""></a>
        <a href='https://liyaowei-stu.github.io/project/ImageConductor/'><img src='https://img.shields.io/badge/Project_Page-ImageConductor-green' alt='Project Page'></a>
        <a href='https://arxiv.org/pdf/2406.15339'><img src='https://img.shields.io/badge/Paper-Arxiv-blue'></a>
        <a href='https://github.com/liyaowei-stu/ImageConductor'><img src='https://img.shields.io/badge/Code-Github-orange'></a>
    </div>
    </br>
</div>
"""


descriptions = r"""
Official Gradio Demo for <a href='https://github.com/liyaowei-stu/ImageConductor'><b>Image Conductor: Precision Control for Interactive Video Synthesis</b></a>.<br>
🧙 Image Conductor enables precise, fine-grained control for generating motion-controllable videos from images, advancing the practical application of interactive video synthesis.<br>
"""


instructions = r"""
- ⭐️ <b>step1: </b>Upload an image or select one from the examples.
- ⭐️ <b>step2: </b>Click 'Add Drag' to draw one or more drags.
- ⭐️ <b>step3: </b>Enter a text prompt that matches the image (required).
- ⭐️ <b>step4: </b>Select 'Drag Mode' to choose between camera transition and object movement control.
- ⭐️ <b>step5: </b>Click the 'Run' button to generate the video.
- ⭐️ <b>others: </b>Click 'Delete last drag' to remove the entire latest path. Click 'Delete last step' to remove the latest clicked control point.
"""

citation = r"""
If Image Conductor is helpful, please help to ⭐ the <a href='https://github.com/liyaowei-stu/ImageConductor' target='_blank'>Github Repo</a>. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/liyaowei-stu%2FImageConductor)](https://github.com/liyaowei-stu/ImageConductor)
---

📝 **Citation**
<br>
If our work is useful for your research, please consider citing:
```bibtex
@misc{li2024imageconductor,
    title={Image Conductor: Precision Control for Interactive Video Synthesis},
    author={Li, Yaowei and Wang, Xintao and Zhang, Zhaoyang and Wang, Zhouxia and Yuan, Ziyang and Xie, Liangbin and Zou, Yuexian and Shan, Ying},
    year={2024},
    eprint={2406.15339},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```

📧 **Contact**
<br>
If you have any questions, please feel free to reach out to me at <b>[email protected]</b>.

"""

os.makedirs("models/personalized", exist_ok=True)
os.makedirs("models/sd1-5", exist_ok=True)

if not os.path.exists("models/flow_controlnet.ckpt"):
    os.system('wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/flow_controlnet.ckpt?download=true -P models/')
    os.system('mv models/flow_controlnet.ckpt?download=true models/flow_controlnet.ckpt')
    print("flow_controlnet downloaded!")

if not os.path.exists("models/image_controlnet.ckpt"):
    os.system('wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/image_controlnet.ckpt?download=true -P models/')
    os.system('mv models/image_controlnet.ckpt?download=true models/image_controlnet.ckpt')
    print("image_controlnet downloaded!")

if not os.path.exists("models/unet.ckpt"):
    os.system('wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/unet.ckpt?download=true -P models/')
    os.system('mv models/unet.ckpt?download=true models/unet.ckpt')
    print("unet downloaded!")


if not os.path.exists("models/sd1-5/config.json"):
    os.system('wget -q https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/config.json?download=true -P models/sd1-5/')
    os.system('mv models/sd1-5/config.json?download=true models/sd1-5/config.json')
    print("config downloaded!")


if not os.path.exists("models/sd1-5/unet.ckpt"):
    os.system('cp -r models/unet.ckpt models/sd1-5/unet.ckpt')

# os.system('wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/diffusion_pytorch_model.bin?download=true -P models/sd1-5/')

if not os.path.exists("models/personalized/helloobjects_V12c.safetensors"):
    os.system('wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/helloobjects_V12c.safetensors?download=true -P models/personalized')
    os.system('mv models/personalized/helloobjects_V12c.safetensors?download=true models/personalized/helloobjects_V12c.safetensors')
    print("helloobjects_V12c downloaded!")


if not os.path.exists("models/personalized/TUSUN.safetensors"):
    os.system('wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/TUSUN.safetensors?download=true -P models/personalized')
    os.system('mv models/personalized/TUSUN.safetensors?download=true models/personalized/TUSUN.safetensors')
    print("TUSUN downloaded!")

# mv1 = os.system('mv /usr/local/lib/python3.10/site-packages/gradio/helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers_bkp.py')
# mv2 = os.system('mv helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers.py')


# # Check whether the file moves succeeded
# if mv1 == 0 and mv2 == 0:
#     print("file move succeeded!")
# else:
#     print("file move failed!")

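# Each gr.Examples entry below follows the input order used further down:
# [input image, prompt, drag mode, seed, personalized checkpoint name,
#  first-frame path, trajectory list loaded from a JSON file].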
# - - - - - examples - - - - - #

image_examples = [
    ["__asset__/images/object/turtle-1.jpg",
     "a sea turtle gracefully swimming over a coral reef in the clear blue ocean.",
     "object",
     11318446767408804497,
     "",
     "__asset__/images/object/turtle-1.jpg",
     json.load(open("__asset__/trajs/object/turtle-1.json"))
     ],

    # ["__asset__/images/object/rose-1.jpg",
    #  "a red rose engulfed in flames.",
    #  "object",
    #  6854275249656120509,
    #  "",
    #  "rose",
    #  ],

    # ["__asset__/images/object/jellyfish-1.jpg",
    #  "intricate detailing,photorealism,hyperrealistic, glowing jellyfish mushroom, flying, starry sky, bokeh, golden ratio composition.",
    #  "object",
    #  17966188172968903484,
    #  "HelloObject",
    #  "jellyfish"
    #  ],

    # ["__asset__/images/camera/lush-1.jpg",
    #  "detailed craftsmanship, photorealism, hyperrealistic, roaring waterfall, misty spray, lush greenery, vibrant rainbow, golden ratio composition.",
    #  "camera",
    #  7970487946960948963,
    #  "HelloObject",
    #  "lush",
    #  ],

    # ["__asset__/images/camera/tusun-1.jpg",
    #  "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing.",
    #  "camera",
    #  996953226890228361,
    #  "TUSUN",
    #  "tusun",
    #  ],

    # ["__asset__/images/camera/painting-1.jpg",
    #  "A oil painting.",
    #  "camera",
    #  16867854766769816385,
    #  "",
    #  "painting"
    #  ],
]


# Trajectory and image lookups used by the `examples_type` code path in run().
POINTS = {
    'turtle': "__asset__/trajs/object/turtle-1.json",
    'rose': "__asset__/trajs/object/rose-1.json",
    'jellyfish': "__asset__/trajs/object/jellyfish-1.json",
    'lush': "__asset__/trajs/camera/lush-1.json",
    'tusun': "__asset__/trajs/camera/tusun-1.json",
    'painting': "__asset__/trajs/camera/painting-1.json",
}

IMAGE_PATH = {
    'turtle': "__asset__/images/object/turtle-1.jpg",
    'rose': "__asset__/images/object/rose-1.jpg",
    'jellyfish': "__asset__/images/object/jellyfish-1.jpg",
    'lush': "__asset__/images/camera/lush-1.jpg",
    'tusun': "__asset__/images/camera/tusun-1.jpg",
    'painting': "__asset__/images/camera/painting-1.jpg",
}


DREAM_BOOTH = {
    'HelloObject': 'models/personalized/helloobjects_V12c.safetensors',
}

LORA = {
    'TUSUN': 'models/personalized/TUSUN.safetensors',
}

LORA_ALPHA = {
    'TUSUN': 0.6,
}

NPROMPT = {
    "HelloObject": 'FastNegativeV2,(bad-artist:1),(worst quality, low quality:1.4),(bad_prompt_version2:0.8),bad-hands-5,lowres,bad anatomy,bad hands,((text)),(watermark),error,missing fingers,extra digit,fewer digits,cropped,worst quality,low quality,normal quality,((username)),blurry,(extra limbs),bad-artist-anime,badhandv4,EasyNegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3,BadDream,(three hands:1.6),(three legs:1.2),(more than two hands:1.4),(more than two legs,:1.2)'
}

output_dir = "outputs"
ensure_dirname(output_dir)

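# points_to_flows rasterizes the user-drawn drag trajectories into a sparse
# optical-flow volume of shape (model_length - 1, H, W, 2): every trajectory is
# interpolated to model_length points, and for each consecutive pair the
# (dx, dy) displacement is written at the start point's pixel location while all
# other pixels stay zero. run() later blurs these maps with a large Gaussian
# kernel before passing them to the flow ControlNet.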
def points_to_flows(track_points, model_length, height, width):
    input_drag = np.zeros((model_length - 1, height, width, 2))
    for splited_track in track_points:
        if len(splited_track) == 1:  # stationary point
            displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
            splited_track = tuple([splited_track[0], displacement_point])
        # interpolate the track
        splited_track = interpolate_trajectory(splited_track, model_length)
        splited_track = splited_track[:model_length]
        if len(splited_track) < model_length:
            splited_track = splited_track + [splited_track[-1]] * (model_length - len(splited_track))
        for i in range(model_length - 1):
            start_point = splited_track[i]
            end_point = splited_track[i + 1]
            input_drag[i][int(start_point[1])][int(start_point[0])][0] = end_point[0] - start_point[0]
            input_drag[i][int(start_point[1])][int(start_point[0])][1] = end_point[1] - start_point[1]
    return input_drag

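# ImageConductor bundles the frozen Stable Diffusion 1.5 components (tokenizer,
# text encoder, VAE), the 3D UNet loaded from models/unet.ckpt, an image
# ControlNet conditioned on the first-frame latent, and a flow ControlNet with
# LoRA layers conditioned on the drag-derived flow maps, all wrapped in
# ImageConductorPipeline for sampling.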
class ImageConductor:
    def __init__(self, device, unet_path, image_controlnet_path, flow_controlnet_path, height, width, model_length, lora_rank=64):
        self.device = device
        tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").to(device)
        vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(device)
        inference_config = OmegaConf.load("configs/inference/inference.yaml")
        unet = UNet3DConditionFlowModel.from_pretrained_2d("models/sd1-5/", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))

        self.vae = vae

        ### >>> Initialize UNet module >>> ###
        load_model(unet, unet_path)

        ### >>> Initialize image controlnet module >>> ###
        image_controlnet = create_image_controlnet("configs/inference/image_condition.yaml", unet)
        load_model(image_controlnet, image_controlnet_path)

        ### >>> Initialize flow controlnet module >>> ###
        flow_controlnet = create_flow_controlnet("configs/inference/flow_condition.yaml", unet)
        add_LoRA_to_controlnet(lora_rank, flow_controlnet)
        load_model(flow_controlnet, flow_controlnet_path)

        unet.eval().to(device)
        image_controlnet.eval().to(device)
        flow_controlnet.eval().to(device)

        self.pipeline = ImageConductorPipeline(
            unet=unet,
            vae=vae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
            image_controlnet=image_controlnet,
            flow_controlnet=flow_controlnet,
        ).to(device)

        self.height = height
        self.width = width
        # _, model_step, _ = split_filename(model_path)
        # self.ouput_prefix = f'{model_step}_{width}X{height}'
        self.model_length = model_length

        blur_kernel = bivariate_Gaussian(kernel_size=99, sig_x=10, sig_y=10, theta=0, grid=None, isotropic=True)

        self.blur_kernel = blur_kernel

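    # run() is the main inference entry point: it rescales the drawn trajectories
    # from the 384x256 drawing canvas to the model resolution, renders the drag
    # visualization, encodes the first frame into VAE latents as the image
    # condition, converts the trajectories into blurred flow maps for the flow
    # ControlNet, optionally loads personalized DreamBooth/LoRA weights, and then
    # samples a 16-frame clip that is saved as a GIF.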
    @spaces.GPU(duration=120)
    def run(self, first_frame_path, tracking_points, prompt, drag_mode, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized, examples_type):
        print("Run!")
        if examples_type != "":
            ### for adapting high version gradio
            tracking_points = gr.State([])
            first_frame_path = IMAGE_PATH[examples_type]
            points = json.load(open(POINTS[examples_type]))
            tracking_points.value.extend(points)
            print("example first_frame_path", first_frame_path)
            print("example tracking_points", tracking_points.value)

        original_width, original_height = 384, 256
        if isinstance(tracking_points, list):
            input_all_points = tracking_points
        else:
            input_all_points = tracking_points.value

        print("input_all_points", input_all_points)
        resized_all_points = [tuple([tuple([float(e1[0] * self.width / original_width), float(e1[1] * self.height / original_height)]) for e1 in e]) for e in input_all_points]

        dir, base, ext = split_filename(first_frame_path)
        id = base.split('_')[-1]

        visualized_drag, _ = visualize_drag(first_frame_path, resized_all_points, drag_mode, self.width, self.height, self.model_length)

        ## image condition
        image_transforms = transforms.Compose([
            transforms.RandomResizedCrop(
                (self.height, self.width), (1.0, 1.0),
                ratio=(self.width / self.height, self.width / self.height)
            ),
            transforms.ToTensor(),
        ])

        image_paths = [first_frame_path]
        controlnet_images = [(image_transforms(Image.open(path).convert("RGB"))) for path in image_paths]
        controlnet_images = torch.stack(controlnet_images).unsqueeze(0).to(device)
        controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w")
        num_controlnet_images = controlnet_images.shape[2]
        controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w")
        self.vae.to(device)
        controlnet_images = self.vae.encode(controlnet_images * 2. - 1.).latent_dist.sample() * 0.18215
        controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images)

        # flow condition
        controlnet_flows = points_to_flows(resized_all_points, self.model_length, self.height, self.width)
        for i in range(0, self.model_length - 1):
            controlnet_flows[i] = cv2.filter2D(controlnet_flows[i], -1, self.blur_kernel)
        controlnet_flows = np.concatenate([np.zeros_like(controlnet_flows[0])[np.newaxis, ...], controlnet_flows], axis=0)  # pad the first frame with zero flow
        os.makedirs(os.path.join(output_dir, "control_flows"), exist_ok=True)
        trajs_video = vis_flow_to_video(controlnet_flows, num_frames=self.model_length)  # T-1 x H x W x 3
        torchvision.io.write_video(f'{output_dir}/control_flows/sample-{id}-train_flow.mp4', trajs_video, fps=8, video_codec='h264', options={'crf': '10'})
        controlnet_flows = torch.from_numpy(controlnet_flows)[None][:, :self.model_length, ...]
        controlnet_flows = rearrange(controlnet_flows, "b f h w c -> b c f h w").float().to(device)

        dreambooth_model_path = DREAM_BOOTH.get(personalized, '')
        lora_model_path = LORA.get(personalized, '')
        lora_alpha = LORA_ALPHA.get(personalized, 0.6)
        self.pipeline = load_weights(
            self.pipeline,
            dreambooth_model_path=dreambooth_model_path,
            lora_model_path=lora_model_path,
            lora_alpha=lora_alpha,
        ).to(device)

        if NPROMPT.get(personalized, '') != '':
            negative_prompt = NPROMPT.get(personalized)

        if randomize_seed:
            random_seed = torch.seed()
        else:
            seed = int(seed)
            random_seed = seed
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed_all(random_seed)
        print(f"current seed: {torch.initial_seed()}")
        sample = self.pipeline(
            prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=self.width,
            height=self.height,
            video_length=self.model_length,
            controlnet_images=controlnet_images,  # 1 4 1 32 48
            controlnet_image_index=[0],
            controlnet_flows=controlnet_flows,  # [1, 2, 16, 256, 384]
            control_mode=drag_mode,
            eval_mode=True,
        ).videos

        # outputs_path = os.path.join(output_dir, f'output_{i}_{id}.mp4')
        # vis_video = (rearrange(sample[0], 'c t h w -> t h w c') * 255.).clip(0, 255)
        # torchvision.io.write_video(outputs_path, vis_video, fps=8, video_codec='h264', options={'crf': '10'})

        outputs_path = os.path.join(output_dir, f'output_{i}_{id}.gif')
        save_videos_grid(sample[0][None], outputs_path)
        print("Done!")
        return {output_image: visualized_drag, output_video: outputs_path}

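# The functions below are Gradio event callbacks. They keep the per-session list
# of drag trajectories (tracking_points_var) in sync with user clicks and redraw
# the trajectory overlay on top of the uploaded first frame (red for object
# drags, blue for camera drags).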
def reset_states(first_frame_path, tracking_points):
    first_frame_path = gr.State()
    tracking_points = gr.State([])
    return {input_image: None, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points}


def preprocess_image(image, tracking_points):
    image_pil = image2pil(image.name)
    raw_w, raw_h = image_pil.size
    resize_ratio = max(384 / raw_w, 256 / raw_h)
    image_pil = image_pil.resize((int(raw_w * resize_ratio), int(raw_h * resize_ratio)), Image.BILINEAR)
    image_pil = transforms.CenterCrop((256, 384))(image_pil.convert('RGB'))
    id = str(uuid.uuid4())[:4]
    first_frame_path = os.path.join(output_dir, f"first_frame_{id}.jpg")
    image_pil.save(first_frame_path, quality=95)
    tracking_points = gr.State([])
    return {input_image: first_frame_path, first_frame_path_var: first_frame_path, tracking_points_var: tracking_points, personalized: ""}

def add_tracking_points(tracking_points, first_frame_path, drag_mode, evt: gr.SelectData):  # SelectData is a subclass of EventData
    if drag_mode == 'object':
        color = (255, 0, 0, 255)
    elif drag_mode == 'camera':
        color = (0, 0, 255, 255)

    if not isinstance(tracking_points, list):
        print(f"You selected {evt.value} at {evt.index} from {evt.target}")
        tracking_points.value[-1].append(evt.index)
        print(tracking_points.value)
        tracking_points_values = tracking_points.value
    else:
        try:
            tracking_points[-1].append(evt.index)
        except Exception as e:
            tracking_points.append([])
            tracking_points[-1].append(evt.index)
            print(f"Solved Error: {e}")

        tracking_points_values = tracking_points

    transparent_background = Image.open(first_frame_path).convert('RGBA')
    w, h = transparent_background.size
    transparent_layer = np.zeros((h, w, 4))

    for track in tracking_points_values:
        if len(track) > 1:
            for i in range(len(track) - 1):
                start_point = track[i]
                end_point = track[i + 1]
                vx = end_point[0] - start_point[0]
                vy = end_point[1] - start_point[1]
                arrow_length = np.sqrt(vx**2 + vy**2)
                if i == len(track) - 2:
                    cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
                else:
                    cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2)
        else:
            cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)

    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
    trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
    return {tracking_points_var: tracking_points, input_image: trajectory_map}


def add_drag(tracking_points):
    if not isinstance(tracking_points, list):
        # print("before", tracking_points.value)
        tracking_points.value.append([])
        # print(tracking_points.value)
    else:
        tracking_points.append([])
    return {tracking_points_var: tracking_points}


def delete_last_drag(tracking_points, first_frame_path, drag_mode):
    if drag_mode == 'object':
        color = (255, 0, 0, 255)
    elif drag_mode == 'camera':
        color = (0, 0, 255, 255)
    tracking_points.value.pop()
    transparent_background = Image.open(first_frame_path).convert('RGBA')
    w, h = transparent_background.size
    transparent_layer = np.zeros((h, w, 4))
    for track in tracking_points.value:
        if len(track) > 1:
            for i in range(len(track) - 1):
                start_point = track[i]
                end_point = track[i + 1]
                vx = end_point[0] - start_point[0]
                vy = end_point[1] - start_point[1]
                arrow_length = np.sqrt(vx**2 + vy**2)
                if i == len(track) - 2:
                    cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
                else:
                    cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2)
        else:
            cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)

    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
    trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
    return {tracking_points_var: tracking_points, input_image: trajectory_map}


def delete_last_step(tracking_points, first_frame_path, drag_mode):
    if drag_mode == 'object':
        color = (255, 0, 0, 255)
    elif drag_mode == 'camera':
        color = (0, 0, 255, 255)
    tracking_points.value[-1].pop()
    transparent_background = Image.open(first_frame_path).convert('RGBA')
    w, h = transparent_background.size
    transparent_layer = np.zeros((h, w, 4))
    for track in tracking_points.value:
        if len(track) > 1:
            for i in range(len(track) - 1):
                start_point = track[i]
                end_point = track[i + 1]
                vx = end_point[0] - start_point[0]
                vy = end_point[1] - start_point[1]
                arrow_length = np.sqrt(vx**2 + vy**2)
                if i == len(track) - 2:
                    cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), color, 2, tipLength=8 / arrow_length)
                else:
                    cv2.line(transparent_layer, tuple(start_point), tuple(end_point), color, 2)
        else:
            cv2.circle(transparent_layer, tuple(track[0]), 5, color, -1)

    transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
    trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
    return {tracking_points_var: tracking_points, input_image: trajectory_map}

block = gr.Blocks(
    theme=gr.themes.Soft(
        radius_size=gr.themes.sizes.radius_none,
        text_size=gr.themes.sizes.text_md
    )
)
with block:
    with gr.Row():
        with gr.Column():
            gr.HTML(head)

            gr.Markdown(descriptions)

            with gr.Accordion(label="🛠️ Instructions:", open=True, elem_id="accordion"):
                with gr.Row(equal_height=True):
                    gr.Markdown(instructions)

    # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    device = torch.device("cuda")
    unet_path = 'models/unet.ckpt'
    image_controlnet_path = 'models/image_controlnet.ckpt'
    flow_controlnet_path = 'models/flow_controlnet.ckpt'
    ImageConductor_net = ImageConductor(device=device,
                                        unet_path=unet_path,
                                        image_controlnet_path=image_controlnet_path,
                                        flow_controlnet_path=flow_controlnet_path,
                                        height=256,
                                        width=384,
                                        model_length=16
                                        )
    first_frame_path_var = gr.State(value=None)
    tracking_points_var = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_upload_button = gr.UploadButton(label="Upload Image", file_types=["image"])
            add_drag_button = gr.Button(value="Add Drag")
            reset_button = gr.Button(value="Reset")
            delete_last_drag_button = gr.Button(value="Delete last drag")
            delete_last_step_button = gr.Button(value="Delete last step")

        with gr.Column(scale=7):
            with gr.Row():
                with gr.Column(scale=6):
                    input_image = gr.Image(label="Input Image",
                                           interactive=True,
                                           height=300,
                                           width=384,)
                with gr.Column(scale=6):
                    output_image = gr.Image(label="Motion Path",
                                            interactive=False,
                                            height=256,
                                            width=384,)
            with gr.Row():
                with gr.Column(scale=1):
                    prompt = gr.Textbox(value="a wonderful elf.", label="Prompt (highly-recommended)", interactive=True, visible=True)
                    negative_prompt = gr.Text(
                        label="Negative Prompt",
                        max_lines=5,
                        placeholder="Please input your negative prompt",
                        value='worst quality, low quality, letterboxed', lines=1
                    )
                    drag_mode = gr.Radio(['camera', 'object'], label='Drag mode: ', value='object', scale=2)
                    run_button = gr.Button(value="Run")

                    with gr.Accordion("More input params", open=False, elem_id="accordion1"):
                        with gr.Group():
                            seed = gr.Textbox(
                                label="Seed: ", value=561793204,
                            )
                            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

                        with gr.Group():
                            with gr.Row():
                                guidance_scale = gr.Slider(
                                    label="Guidance scale",
                                    minimum=1,
                                    maximum=12,
                                    step=0.1,
                                    value=8.5,
                                )
                                num_inference_steps = gr.Slider(
                                    label="Number of inference steps",
                                    minimum=1,
                                    maximum=50,
                                    step=1,
                                    value=25,
                                )

                        with gr.Group():
                            personalized = gr.Dropdown(label="Personalized", choices=['HelloObject', 'TUSUN', ""], value="")
                            examples_type = gr.Textbox(label="Examples Type (Ignore) ", value="", visible=False)

                with gr.Column(scale=7):
                    # output_video = gr.Video(
                    #     label="Output Video",
                    #     width=384,
                    #     height=256)
                    output_video = gr.Image(label="Output Video",
                                            height=256,
                                            width=384,)

    with gr.Row():
        example = gr.Examples(
            label="Input Example",
            examples=image_examples,
            inputs=[input_image, prompt, drag_mode, seed, personalized, first_frame_path_var, tracking_points_var],
            examples_per_page=10,
            cache_examples=False,
        )

    with gr.Row():
        gr.Markdown(citation)
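    # Event wiring: connect the UI controls to the callbacks above and to
    # ImageConductor_net.run, which returns the drag visualization and the
    # generated GIF.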

    image_upload_button.upload(preprocess_image, [image_upload_button, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var, personalized])

    add_drag_button.click(add_drag, tracking_points_var, tracking_points_var)

    delete_last_drag_button.click(delete_last_drag, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])

    delete_last_step_button.click(delete_last_step, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])

    reset_button.click(reset_states, [first_frame_path_var, tracking_points_var], [input_image, first_frame_path_var, tracking_points_var])

    input_image.select(add_tracking_points, [tracking_points_var, first_frame_path_var, drag_mode], [tracking_points_var, input_image])

    run_button.click(ImageConductor_net.run, [first_frame_path_var, tracking_points_var, prompt, drag_mode,
                                              negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, personalized, examples_type],
                     [output_image, output_video])

block.queue().launch()