reedmayhew committed (verified)
Commit 7d5d19b · 1 Parent(s): 7b309dd

Delete app.py

Files changed (1): app.py (+0 -702)
app.py DELETED
@@ -1,702 +0,0 @@
-#!/usr/bin/env python
-"""
-This is the full application script for VideoPainter.
-It first checks for and, if necessary, installs missing dependencies.
-When installing the custom packages (diffusers and app),
-it uses the flag --no-build-isolation so that the already-installed torch is visible during the build.
-If the custom diffusers package fails to provide the expected submodules,
-the script will force-install the official diffusers package.
-"""
-
-import os
-import sys
-import subprocess
-import warnings
-import time
-import json
-from collections import OrderedDict
-
-warnings.filterwarnings("ignore")
-
-###############################
-# Set up temporary directories
-###############################
-GRADIO_TEMP_DIR = "./tmp_gradio"
-os.makedirs(GRADIO_TEMP_DIR, exist_ok=True)
-os.makedirs(os.path.join(GRADIO_TEMP_DIR, "track"), exist_ok=True)
-os.makedirs(os.path.join(GRADIO_TEMP_DIR, "inpaint"), exist_ok=True)
-os.environ["GRADIO_TEMP_DIR"] = GRADIO_TEMP_DIR
-
-###############################
-# Helper: Install package via pip
-###############################
-def install_package(package_spec):
-    print(f"Installing {package_spec} ...")
-    try:
-        subprocess.check_call([sys.executable, "-m", "pip", "install", package_spec])
-        print(f"Successfully installed {package_spec}")
-        return True
-    except Exception as e:
-        print(f"Failed to install {package_spec}: {e}")
-        return False
-
-###############################
-# Ensure PyTorch is present
-###############################
-print("Checking for PyTorch ...")
-try:
-    import torch
-    print("PyTorch is already installed.")
-except ImportError:
-    print("PyTorch not found, installing...")
-    if not install_package("torch>=2.0.0 torchvision>=0.15.0"):
-        print("Failed to install PyTorch, which is required.")
-        sys.exit(1)
-
-###############################
-# Check/install critical dependencies
-###############################
-critical_dependencies = [
-    ("hydra", "hydra-core>=1.3.2"),
-    ("omegaconf", "omegaconf>=2.3.0"),
-    ("decord", "decord>=0.6.0"),
-    ("diffusers", "diffusers>=0.24.0"),  # This one is later replaced by our custom version.
-    ("transformers", "transformers>=4.35.0"),
-    ("gradio", "gradio>=4.0.0"),
-    ("numpy", "numpy>=1.24.0"),
-    ("cv2", "opencv-python>=4.8.0"),
-    ("PIL", "Pillow>=10.0.0"),
-    ("scipy", "scipy>=1.11.0"),
-    ("einops", "einops>=0.7.0"),
-    ("onnxruntime", "onnxruntime>=1.16.0"),
-    ("timm", "timm>=0.9.0"),
-    ("safetensors", "safetensors>=0.4.0"),
-    ("moviepy", "moviepy>=1.0.3"),
-    ("imageio", "imageio>=2.30.0"),
-    ("tqdm", "tqdm>=4.64.0"),
-    ("openai", "openai>=1.5.0"),
-    ("psutil", "psutil>=5.9.0")
-]
-
-for mod_name, pkg_spec in critical_dependencies:
-    try:
-        if mod_name == "PIL":
-            from PIL import Image
-        elif mod_name == "cv2":
-            import cv2
-        else:
-            __import__(mod_name)
-        print(f"{mod_name} is already installed.")
-    except ImportError:
-        print(f"{mod_name} not found, installing {pkg_spec} ...")
-        install_package(pkg_spec)
-
-###############################
-# Environment setup: Clone repository, install custom packages
-###############################
-print("Setting up environment...")
-
-# Clone the VideoPainter repository if not present
-if not os.path.exists("VideoPainter"):
-    print("Cloning VideoPainter repository...")
-    os.system("git clone https://github.com/TencentARC/VideoPainter.git")
-
-# Append repository folders to sys.path (if not already)
-sys.path.append(os.path.join(os.getcwd(), "VideoPainter"))
-sys.path.append(os.path.join(os.getcwd(), "VideoPainter/app"))
-sys.path.append(os.path.join(os.getcwd(), "app"))
-sys.path.append(".")
-
-# Install the custom diffusers package from VideoPainter/diffusers.
-if os.path.exists("VideoPainter/diffusers"):
-    print("Installing custom diffusers (editable, no-build-isolation)...")
-    os.system("pip install --no-build-isolation -e VideoPainter/diffusers")
-
-# Copy VideoPainter/app to local 'app' directory if needed.
-if not os.path.exists("app"):
-    os.makedirs("app", exist_ok=True)
-    print("Copying VideoPainter/app to local app directory...")
-    os.system("cp -r VideoPainter/app/* app/")
-
-# Install the app package in editable mode.
-if os.path.exists("app"):
-    curr_dir = os.getcwd()
-    os.chdir("app")
-    print("Installing app package (editable, no-build-isolation)...")
-    ret = os.system("pip install --no-build-isolation -e .")
-    if ret != 0:
-        print("Warning: Installing the app package failed; continuing by adding 'app' to sys.path.")
-    os.chdir(curr_dir)
-
-###############################
-# Import modules – if any critical module is missing, exit.
-###############################
-try:
-    print("Importing modules...")
-    import gradio as gr
-    import cv2
-    import numpy as np
-    import scipy
-    import torchvision
-    from PIL import Image
-    from huggingface_hub import snapshot_download
-    from decord import VideoReader
-    from sam2.build_sam import build_sam2_video_predictor
-    from utils import load_model, generate_frames
-    print("Standard and specialized modules imported successfully!")
-except ImportError as e:
-    print(f"Error importing modules: {e}")
-    sys.exit(1)
-
-###############################
-# Validate diffusers installation.
-###############################
-try:
-    from diffusers import pipelines  # Expect this to work.
-    print("Custom diffusers installation appears complete.")
-except Exception as e:
-    print("Custom diffusers installation appears broken:")
-    print(e)
-    print("Installing official diffusers package from PyPI (>=0.24.0)...")
-    if install_package("diffusers>=0.24.0 --force-reinstall"):
-        try:
-            from diffusers import pipelines
-            print("Official diffusers package installed successfully.")
-        except Exception as e2:
-            print("Failed to import diffusers even after installing official version.")
-            sys.exit(1)
-    else:
-        sys.exit(1)
-
-###############################
-# Begin Application Code (VideoPainter demo)
-###############################
-
-def download_models():
-    print("Downloading models from Hugging Face Hub...")
-    models = {
-        "CogVideoX-5b-I2V": "THUDM/CogVideoX-5b-I2V",
-        "VideoPainter": "TencentARC/VideoPainter"
-    }
-    model_paths = {}
-    os.makedirs("ckpt", exist_ok=True)
-    for name, repo_id in models.items():
-        print(f"Downloading {name} from {repo_id}...")
-        path = snapshot_download(repo_id=repo_id)
-        model_paths[name] = path
-        print(f"Downloaded {name} to {path}")
-    try:
-        flux_path = snapshot_download(repo_id="black-forest-labs/FLUX.1-Fill-dev")
-        model_paths["FLUX"] = flux_path
-    except Exception as e:
-        print(f"Failed to download FLUX model: {e}")
-        model_paths["FLUX"] = None
-    os.makedirs("ckpt/Grounded-SAM-2", exist_ok=True)
-    sam2_path = "ckpt/Grounded-SAM-2/sam2_hiera_large.pt"
-    if not os.path.exists(sam2_path):
-        print(f"Downloading SAM2 to {sam2_path}...")
-        os.system(f"wget -O {sam2_path} https://huggingface.co/spaces/sam2/sam2/resolve/main/sam2_hiera_large.pt")
-    model_paths["SAM2"] = sam2_path
-    return model_paths
-
-print("Initializing application environment...")
-if not os.path.exists("app"):
-    print("Setting up app folder from VideoPainter repository ...")
-    os.system("git clone https://github.com/TencentARC/VideoPainter.git")
-    os.makedirs("app", exist_ok=True)
-    os.system("cp -r VideoPainter/app/* app/")
-    os.system("pip install --no-build-isolation -e VideoPainter/diffusers")
-    os.chdir("app")
-    os.system("pip install --no-build-isolation -e .")
-    os.chdir("..")
-
-sys.path.append("app")
-sys.path.append(".")
-
-# Import project modules (again, to be safe)
-try:
-    from decord import VideoReader
-    from sam2.build_sam import build_sam2_video_predictor
-    from utils import load_model, generate_frames
-except ImportError as e:
-    print(f"Failed to import specialized modules: {e}")
-    sys.exit(1)
-
-# Set up OpenRouter / OpenAI (for caption generation)
-try:
-    from openai import OpenAI
-    vlm_model = OpenAI(
-        api_key=os.getenv("OPENROUTER_API_KEY", ""),
-        base_url="https://openrouter.ai/api/v1"
-    )
-    print("OpenRouter client initialized successfully")
-except Exception as e:
-    print(f"OpenRouter API not available: {e}")
-    class DummyModel:
-        def __getattr__(self, name):
-            return self
-        def __call__(self, *args, **kwargs):
-            return self
-        def create(self, *args, **kwargs):
-            class DummyResponse:
-                choices = [type('obj', (object,), {'message': type('obj', (object,), {'content': "OpenRouter API not available. Using default prompt."})})]
-            return DummyResponse()
-    vlm_model = DummyModel()
-
-###############################
-# Download models and initialize predictors
-###############################
-model_paths = download_models()
-base_model_path = model_paths["CogVideoX-5b-I2V"]
-videopainter_path = model_paths["VideoPainter"]
-inpainting_branch = os.path.join(videopainter_path, "checkpoints/branch")
-id_adapter = os.path.join(videopainter_path, "VideoPainterID/checkpoints")
-img_inpainting_model = model_paths.get("FLUX")
-sam2_checkpoint = "ckpt/Grounded-SAM-2/sam2_hiera_large.pt"
-model_cfg = "sam2_hiera_l.yaml"
-
-try:
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
-    print("Build SAM2 predictor done!")
-    validation_pipeline, validation_pipeline_img = load_model(
-        model_path=base_model_path,
-        inpainting_branch=inpainting_branch,
-        id_adapter=id_adapter,
-        img_inpainting_model=img_inpainting_model
-    )
-    print("Load model done!")
-except Exception as e:
-    print(f"Error initializing models: {e}")
-    sys.exit(1)
-
-###############################
-# Helper functions & state definitions
-###############################
-EXAMPLES = [
-    [
-        "https://huggingface.co/spaces/TencentARC/VideoPainter/resolve/main/examples/ferry.mp4",
-        "A white ferry with red and blue accents, named 'COLONIA', cruises on a calm river...",
-        "White and red passenger ferry boat labeled 'COLONIA 6' with multiple windows, life buoys, and upper deck seating.",
-        "Positive",
-        "Inpaint",
-        "",
-        42,
-        6.0,
-        16,
-        [[[320, 240]], [1]],
-    ],
-    [
-        "https://huggingface.co/spaces/TencentARC/VideoPainter/resolve/main/examples/street.mp4",
-        "A bustling city street at night illuminated by festive lights, a red double-decker bus...",
-        "The rear of a black car with illuminated red tail lights and a visible license plate.",
-        "Positive",
-        "Inpaint",
-        "",
-        42,
-        6.0,
-        16,
-        [[[200, 400]], [1]],
-    ],
-]
-
-class StatusMessage:
-    INFO = "Info"
-    WARNING = "Warning"
-    ERROR = "Error"
-    SUCCESS = "Success"
-
-def create_status(message, status_type=StatusMessage.INFO):
-    timestamp = time.strftime("%H:%M:%S")
-    return [("", ""), (f"[{timestamp}]: {message}\n", status_type)]
-
-def update_status(previous_status, new_message, status_type=StatusMessage.INFO):
-    timestamp = time.strftime("%H:%M:%S")
-    history = previous_status[-3:]
-    history.append((f"[{timestamp}]: {new_message}\n", status_type))
-    return [("", "")] + history
-
-def init_state(offload_video_to_cpu=False, offload_state_to_cpu=False):
-    inference_state = {}
-    inference_state["images"] = torch.zeros([1, 3, 100, 100])
-    inference_state["num_frames"] = 1
-    inference_state["offload_video_to_cpu"] = offload_video_to_cpu
-    inference_state["offload_state_to_cpu"] = offload_state_to_cpu
-    inference_state["video_height"] = 100
-    inference_state["video_width"] = 100
-    inference_state["device"] = torch.device("cuda")
-    inference_state["storage_device"] = torch.device("cpu") if offload_state_to_cpu else torch.device("cuda")
-    inference_state["point_inputs_per_obj"] = {}
-    inference_state["mask_inputs_per_obj"] = {}
-    inference_state["cached_features"] = {}
-    inference_state["constants"] = {}
-    inference_state["obj_id_to_idx"] = OrderedDict()
-    inference_state["obj_idx_to_id"] = OrderedDict()
-    inference_state["obj_ids"] = []
-    inference_state["output_dict"] = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
-    inference_state["output_dict_per_obj"] = {}
-    inference_state["temp_output_dict_per_obj"] = {}
-    inference_state["consolidated_frame_inds"] = {"cond_frame_outputs": set(), "non_cond_frame_outputs": set()}
-    inference_state["tracking_has_started"] = False
-    inference_state["frames_already_tracked"] = {}
-    inference_state = gr.State(inference_state)
-    return inference_state
-
-# (All additional helper functions such as get_frames_from_video, sam_refine, vos_tracking_video,
-# inpaint_video, generate_video_from_frames, process_example, reset_all, etc. are defined below.)
-# For brevity, they are included here in full as in your original code.
-
-def get_frames_from_video(video_input, video_state):
-    video_path = video_input
-    frames = []
-    user_name = time.time()
-    vr = VideoReader(video_path)
-    original_fps = vr.get_avg_fps()
-    if original_fps > 8:
-        total_frames = len(vr)
-        sample_interval = max(1, int(original_fps / 8))
-        frame_indices = list(range(0, total_frames, sample_interval))
-        frames = vr.get_batch(frame_indices).asnumpy()
-    else:
-        frames = vr.get_batch(list(range(len(vr)))).asnumpy()
-    frames = frames[:49]
-    resized_frames = [cv2.resize(frame, (720, 480)) for frame in frames]
-    frames = np.array(resized_frames)
-    init_start = time.time()
-    inference_state = predictor.init_state(images=frames, offload_video_to_cpu=True, async_loading_frames=True)
-    init_time = time.time() - init_start
-    print(f"Inference state initialization took {init_time:.2f}s")
-    fps = 8
-    image_size = (frames[0].shape[0], frames[0].shape[1])
-    video_state = {
-        "user_name": user_name,
-        "video_name": os.path.split(video_path)[-1],
-        "origin_images": frames,
-        "painted_images": frames.copy(),
-        "masks": [np.zeros((frames[0].shape[0], frames[0].shape[1]), np.uint8)] * len(frames),
-        "logits": [None] * len(frames),
-        "select_frame_number": 0,
-        "fps": fps,
-        "ann_obj_id": 0
-    }
-    video_info = f"Video Name: {video_state['video_name']}, FPS: {video_state['fps']}, Total Frames: {len(frames)}, Image Size: {image_size}"
-    video_input_path = generate_video_from_frames(frames, output_path=f"{GRADIO_TEMP_DIR}/inpaint/original_{video_state['video_name']}", fps=fps)
-    return (gr.update(visible=True), gr.update(visible=True), inference_state, video_state, video_info,
-            video_state["origin_images"][0], gr.update(visible=False, maximum=len(frames), value=1, interactive=True),
-            gr.update(visible=False, maximum=len(frames), value=len(frames), interactive=True), gr.update(visible=True, interactive=True),
-            gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=True),
-            gr.update(visible=True, interactive=False), create_status("Upload video complete. Ready to select targets.", StatusMessage.SUCCESS), video_input_path)
-
-def select_template(image_selection_slider, video_state, interactive_state, previous_status):
-    image_selection_slider -= 1
-    video_state["select_frame_number"] = image_selection_slider
-    return video_state["painted_images"][image_selection_slider], video_state, interactive_state, update_status(previous_status, f"Set tracking start at frame {image_selection_slider}.", StatusMessage.INFO)
-
-def get_end_number(track_pause_number_slider, video_state, interactive_state, previous_status):
-    interactive_state["track_end_number"] = track_pause_number_slider
-    return video_state["painted_images"][track_pause_number_slider], interactive_state, update_status(previous_status, f"Set tracking finish at frame {track_pause_number_slider}.", StatusMessage.INFO)
-
-def sam_refine(inference_state, video_state, point_prompt, click_state, interactive_state, evt, previous_status):
-    ann_obj_id = 0
-    ann_frame_idx = video_state["select_frame_number"]
-    if point_prompt == "Positive":
-        coordinate = f"[[{evt.index[0]},{evt.index[1]},1]]"
-        interactive_state["positive_click_times"] += 1
-    else:
-        coordinate = f"[[{evt.index[0]},{evt.index[1]},0]]"
-        interactive_state["negative_click_times"] += 1
-    print(f"sam_refine, point_prompt: {point_prompt}, click_state: {click_state}")
-    prompt = {"prompt_type":["click"], "input_point": click_state[0], "input_label": click_state[1], "multimask_output": "True"}
-    points = np.array(prompt["input_point"])
-    labels = np.array(prompt["input_label"])
-    height, width = video_state["origin_images"][0].shape[0:2]
-    for i in range(len(points)):
-        points[i, 0] = int(points[i, 0])
-        points[i, 1] = int(points[i, 1])
-    print(f"sam_refine points: {points}, labels: {labels}")
-    frame_idx, obj_ids, mask = predictor.add_new_points(inference_state=inference_state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels)
-    mask_ = mask.cpu().squeeze().detach().numpy()
-    mask_[mask_ <= 0] = 0
-    mask_[mask_ > 0] = 1
-    org_image = video_state["origin_images"][video_state["select_frame_number"]]
-    mask_ = cv2.resize(mask_, (width, height))
-    mask_ = mask_[:, :, None]
-    mask_[mask_ > 0.5] = 1
-    mask_[mask_ <= 0.5] = 0
-    color = 63 * np.ones((height, width, 3)) * np.array([[[np.random.randint(5), np.random.randint(5), np.random.randint(5)]]])
-    painted_image = np.uint8((1 - 0.5 * mask_) * org_image + 0.5 * mask_ * color)
-    video_state["masks"][video_state["select_frame_number"]] = mask_
-    video_state["painted_images"][video_state["select_frame_number"]] = painted_image
-    return painted_image, video_state, interactive_state, update_status(previous_status, "Segmentation updated. Add more points or continue tracking.", StatusMessage.SUCCESS)
-
-def clear_click(inference_state, video_state, click_state, previous_status):
-    predictor.reset_state(inference_state)
-    click_state = [[], []]
-    template_frame = video_state["origin_images"][video_state["select_frame_number"]]
-    return inference_state, template_frame, click_state, update_status(previous_status, "Click history cleared.", StatusMessage.INFO)
-
-def vos_tracking_video(inference_state, video_state, interactive_state, previous_status):
-    height, width = video_state["origin_images"][0].shape[0:2]
-    masks = []
-    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
-        mask = np.zeros([480, 720, 1])
-        for i in range(len(out_mask_logits)):
-            out_mask = out_mask_logits[i].cpu().squeeze().detach().numpy()
-            out_mask[out_mask > 0] = 1
-            out_mask[out_mask <= 0] = 0
-            out_mask = out_mask[:, :, None]
-            mask += out_mask
-        mask = cv2.resize(mask, (width, height))
-        mask = mask[:, :, None]
-        mask[mask > 0.5] = 1
-        mask[mask < 1] = 0
-        mask = scipy.ndimage.binary_dilation(mask, iterations=6)
-        masks.append(mask)
-    masks = np.array(masks)
-    if interactive_state.get("track_end_number") is not None:
-        video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
-        org_images = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
-        color = 255 * np.ones((1, org_images.shape[-3], org_images.shape[-2], 3)) * np.array([[[[0, 1, 1]]]])
-        painted_images = np.uint8((1 - 0.5 * masks) * org_images + 0.5 * masks * color)
-        video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images
-    else:
-        video_state["masks"] = masks
-        org_images = video_state["origin_images"]
-        color = 255 * np.ones((1, org_images.shape[-3], org_images.shape[-2], 3)) * np.array([[[[0, 1, 1]]]])
-        painted_images = np.uint8((1 - 0.5 * masks) * org_images + 0.5 * masks * color)
-        video_state["painted_images"] = painted_images
-    video_output = generate_video_from_frames(video_state["painted_images"], output_path=f"{GRADIO_TEMP_DIR}/track/{video_state['video_name']}", fps=video_state["fps"])
-    interactive_state["inference_times"] += 1
-    print(f"vos_tracking_video output: {video_output}")
-    return inference_state, video_output, video_state, interactive_state, update_status(previous_status, "Tracking complete.", StatusMessage.SUCCESS), gr.Button.update(interactive=True), gr.Button.update(interactive=True), gr.Button.update(interactive=True), gr.Button.update(interactive=True)
-
-def inpaint_video(video_state, video_caption, target_region_frame1_caption, interactive_state, previous_status, seed_param, cfg_scale, dilate_size):
-    seed = int(seed_param) if int(seed_param) >= 0 else np.random.randint(0, 2**32 - 1)
-    validation_images = video_state["origin_images"]
-    validation_masks = video_state["masks"]
-    validation_masks = [np.squeeze(mask) for mask in validation_masks]
-    validation_masks = [(mask > 0).astype(np.uint8) * 255 for mask in validation_masks]
-    validation_masks = [np.stack([m, m, m], axis=-1) for m in validation_masks]
-    validation_images = [Image.fromarray(np.uint8(img)).convert('RGB') for img in validation_images]
-    validation_masks = [Image.fromarray(np.uint8(mask)).convert('RGB') for mask in validation_masks]
-    validation_images = [img.resize((720, 480)) for img in validation_images]
-    validation_masks = [mask.resize((720, 480)) for mask in validation_masks]
-    print("Inpainting: video_caption=", video_caption)
-    images = generate_frames(
-        images=validation_images,
-        masks=validation_masks,
-        pipe=validation_pipeline,
-        pipe_img_inpainting=validation_pipeline_img,
-        prompt=str(video_caption),
-        image_inpainting_prompt=str(target_region_frame1_caption),
-        seed=seed,
-        cfg_scale=float(cfg_scale),
-        dilate_size=int(dilate_size)
-    )
-    images = (images * 255).astype(np.uint8)
-    video_output = generate_video_from_frames(images, output_path=f"{GRADIO_TEMP_DIR}/inpaint/{video_state['video_name']}", fps=8)
-    print(f"Inpaint_video output: {video_output}")
-    return video_output, update_status(previous_status, "Inpainting complete.", StatusMessage.SUCCESS)
-
-def generate_video_from_frames(frames, output_path, fps=8):
-    frames_tensor = torch.from_numpy(np.asarray(frames)).to(torch.uint8)
-    if not os.path.exists(os.path.dirname(output_path)):
-        os.makedirs(os.path.dirname(output_path))
-    torchvision.io.write_video(output_path, frames_tensor, fps=fps, video_codec="libx264")
-    return output_path
-
-def process_example(video_input, video_caption, target_region_frame1_caption, prompt, click_state):
-    if video_input is None or video_input == "":
-        return (gr.update(value=""), gr.update(value=""), init_state(),
-                {"user_name": "", "video_name": "", "origin_images": None, "painted_images": None, "masks": None, "inpaint_masks": None, "logits": None, "select_frame_number": 0, "fps": 8, "ann_obj_id": 0},
-                "", None,
-                gr.update(value=1, visible=False, interactive=False),
-                gr.update(value=1, visible=False, interactive=False),
-                gr.update(value="Positive", interactive=False),
-                gr.update(visible=True, interactive=False),
-                gr.update(visible=True, interactive=False),
-                gr.update(value=None),
-                gr.update(visible=True, interactive=False),
-                create_status("Reset complete. Ready for new input.", StatusMessage.INFO),
-                gr.update(value=None))
-    video_state = gr.State({
-        "user_name": "",
-        "video_name": "",
-        "origin_images": None,
-        "painted_images": None,
-        "masks": None,
-        "inpaint_masks": None,
-        "logits": None,
-        "select_frame_number": 0,
-        "fps": 8,
-        "ann_obj_id": 0
-    })
-    results = get_frames_from_video(video_input, video_state)
-    if click_state[0] and click_state[1]:
-        print("Example detected, executing sam_refine")
-        (video_caption, target_region_frame1_caption, inference_state, video_state, video_info, template_frame, image_selection_slider, track_pause_number_slider, point_prompt, clear_button, tracking_button, video_output, inpaint_button, run_status, video_input) = results
-        class MockEvent:
-            def __init__(self, points, point_idx=0):
-                self.index = points[point_idx]
-        for i_click in range(len(click_state[0])):
-            evt = MockEvent(click_state[0], i_click)
-            prompt_type = "Positive" if click_state[1][i_click] == 1 else "Negative"
-            template_frame, video_state, interactive_state, run_status = sam_refine(inference_state, video_state, prompt_type, click_state, {"inference_times": 0, "negative_click_times": 0, "positive_click_times": 0, "mask_save": False, "multi_mask": {"mask_names": [], "masks": []}, "track_end_number": None}, evt, run_status)
-        return (video_caption, target_region_frame1_caption, inference_state, video_state, video_info, template_frame, image_selection_slider, track_pause_number_slider, point_prompt, clear_button, tracking_button, video_output, inpaint_button, run_status, video_input)
-    return results
-
-def reset_all():
-    return (gr.update(value=None), gr.update(value=""), gr.update(value=""), init_state(),
-            {"user_name": "", "video_name": "", "origin_images": None, "painted_images": None, "masks": None, "inpaint_masks": None, "logits": None, "select_frame_number": 0, "fps": 8, "ann_obj_id": 0},
-            {"inference_times": 0, "negative_click_times": 0, "positive_click_times": 0, "mask_save": False, "multi_mask": {"mask_names": [], "masks": []}, "track_end_number": None},
-            [[], []], None, gr.update(visible=True, interactive=True), "",
-            gr.update(value=1, visible=False, interactive=False), gr.update(value=1, visible=False, interactive=False),
-            gr.update(value="Positive", interactive=False), gr.Button.update(interactive=False),
-            gr.Button.update(interactive=False), gr.Button.update(interactive=False),
-            gr.Button.update(interactive=False), gr.Button.update(interactive=False),
-            gr.Button.update(interactive=False), gr.Number.update(value=42),
-            gr.Slider.update(value=6.0), gr.Slider.update(value=16),
-            create_status("Reset complete. Ready for new input.", StatusMessage.INFO))
-
-###############################
-# Build Gradio Interface
-###############################
-title = """<p><h1 align="center">VideoPainter</h1></p>"""
-with gr.Blocks() as iface:
-    gr.HTML("""
-    <div style="text-align: center;">
-        <h1 style="color: #333;">🖌️ VideoPainter</h1>
-        <h3 style="color: #333;">Any-length Video Inpainting and Editing with Plug-and-Play Context Control</h3>
-        <p style="font-weight: bold;">
-            <a href="https://yxbian23.github.io/project/video-painter/">🌍 Project Page</a> |
-            <a href="https://arxiv.org/abs/2503.05639">📃 ArXiv Preprint</a> |
-            <a href="https://github.com/TencentARC/VideoPainter">🧑‍💻 Github Repository</a>
-        </p>
-    </div>
-    """)
-    click_state = gr.State([[], []])
-    interactive_state = gr.State({
-        "inference_times": 0,
-        "negative_click_times": 0,
-        "positive_click_times": 0,
-        "mask_save": False,
-        "multi_mask": {"mask_names": [], "masks": []},
-        "track_end_number": None,
-    })
-    video_state = gr.State({
-        "user_name": "",
-        "video_name": "",
-        "origin_images": None,
-        "painted_images": None,
-        "masks": None,
-        "inpaint_masks": None,
-        "logits": None,
-        "select_frame_number": 0,
-        "fps": 8,
-        "ann_obj_id": 0
-    })
-    inference_state = init_state()
-
-    with gr.Row():
-        with gr.Column():
-            with gr.Row():
-                video_input = gr.Video(label="Original Video", visible=True)
-            with gr.Row():
-                with gr.Column(scale=3):
-                    template_frame = gr.Image(type="pil", interactive=True, elem_id="template_frame", visible=True)
-                with gr.Column(scale=1):
-                    with gr.Accordion("Segmentation Point Prompt", open=True):
-                        point_prompt = gr.Radio(choices=["Positive", "Negative"], value="Positive", label="Point Type", interactive=False, visible=True)
-                        clear_button_click = gr.Button(value="Clear clicks", interactive=False, visible=True)
-                        gr.Markdown("✨ Positive: Include target region. <br> ✨ Negative: Exclude target region.")
-                    image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False, interactive=False)
-                    track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False, interactive=False)
-            video_output = gr.Video(label="Generated Video", visible=True)
-            with gr.Row():
-                tracking_video_predict_button = gr.Button(value="Tracking", interactive=False, visible=True)
-                inpaint_video_predict_button = gr.Button(value="Inpainting", interactive=False, visible=True)
-                reset_button = gr.Button(value="Reset All", interactive=True, visible=True)
-
-        with gr.Column():
-            with gr.Accordion("Global Video Caption", open=True):
-                video_caption = gr.Textbox(label="Global Video Caption", placeholder="Input global video caption...", interactive=True, visible=True, max_lines=5, show_copy_button=True)
-                with gr.Row():
-                    gr.Markdown("✨ Enhance prompt using GPT-4o (optional).")
-                    enhance_button = gr.Button("✨ Enhance Prompt(Optional)", interactive=False)
-            with gr.Accordion("Target Object Caption", open=True):
-                target_region_frame1_caption = gr.Textbox(label="Target Object Caption", placeholder="Input target object caption...", interactive=True, visible=True, max_lines=5, show_copy_button=True)
-                with gr.Row():
-                    gr.Markdown("✨ Generate target caption (optional).")
-                    enhance_target_region_frame1_button = gr.Button("✨ Target Prompt Generation (Optional)", interactive=False)
-            with gr.Accordion("Editing Instruction", open=False):
-                gr.Markdown("✨ Modify captions based on your instruction using GPT-4o.")
-                with gr.Row():
-                    editing_instruction = gr.Textbox(label="Editing Instruction", placeholder="Input editing instruction...", interactive=True, visible=True, max_lines=5, show_copy_button=True)
-                    enhance_editing_instruction_button = gr.Button("✨ Modify Caption(For Editing)", interactive=False)
-            with gr.Accordion("Advanced Sampling Settings", open=False):
-                cfg_scale = gr.Slider(value=6.0, label="Classifier-Free Guidance Scale", minimum=1, maximum=10, step=0.1, interactive=True)
-                seed_param = gr.Number(label="Inference Seed (>=0)", interactive=True, value=42)
-                dilate_size = gr.Slider(value=16, label="Mask Dilate Size", minimum=0, maximum=32, step=1, interactive=True)
-            video_info = gr.Textbox(label="Video Info", visible=True, interactive=False)
-            model_type = gr.Textbox(label="Type", placeholder="Model type...", interactive=True, visible=False)
-            notes_accordion = gr.Accordion("Notes", open=False)
-            with notes_accordion:
-                gr.HTML("<p style='font-size: 1.1em;'>🧐 Reminder: VideoPainter may produce unexpected outputs. Adjust settings if needed.</p>")
-            run_status = gr.HighlightedText(value=[("", "")], visible=True, label="Operation Status", show_label=True,
-                                            color_map={"Success": "green", "Error": "red", "Warning": "orange", "Info": "blue"})
-
-    with gr.Row():
-        examples = gr.Examples(label="Quick Examples", examples=EXAMPLES,
-                               inputs=[video_input, video_caption, target_region_frame1_caption, point_prompt, model_type, editing_instruction, seed_param, cfg_scale, dilate_size, click_state],
-                               examples_per_page=20, cache_examples=False)
-
-    video_input.change(fn=process_example, inputs=[video_input, video_caption, target_region_frame1_caption, point_prompt, click_state],
-                       outputs=[video_caption, target_region_frame1_caption, inference_state, video_state, video_info,
-                                template_frame, image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click,
-                                tracking_video_predict_button, video_output, inpaint_video_predict_button, run_status, video_input])
-
-    image_selection_slider.release(fn=select_template, inputs=[image_selection_slider, video_state, interactive_state, run_status],
-                                   outputs=[template_frame, video_state, interactive_state, run_status])
-
-    track_pause_number_slider.release(fn=get_end_number, inputs=[track_pause_number_slider, video_state, interactive_state, run_status],
-                                      outputs=[template_frame, interactive_state, run_status])
-
-    template_frame.select(fn=sam_refine, inputs=[inference_state, video_state, point_prompt, click_state, interactive_state, run_status],
-                          outputs=[template_frame, video_state, interactive_state, run_status])
-
-    tracking_video_predict_button.click(fn=vos_tracking_video, inputs=[inference_state, video_state, interactive_state, run_status],
-                                        outputs=[inference_state, video_output, video_state, interactive_state, run_status,
-                                                 inpaint_video_predict_button, enhance_button, enhance_target_region_frame1_button, enhance_editing_instruction_button, notes_accordion])
-
-    inpaint_video_predict_button.click(fn=inpaint_video, inputs=[video_state, video_caption, target_region_frame1_caption, interactive_state, run_status, seed_param, cfg_scale, dilate_size],
-                                       outputs=[video_output, run_status], api_name=False, show_progress="full")
-
-    def enhance_prompt_func(video_caption):
-        return video_caption  # Replace with your convert_prompt() if available
-
-    def enhance_target_region_frame1_prompt_func(target_region_frame1_caption, video_state):
-        return target_region_frame1_caption  # Replace with your convert_prompt_target_region_frame1() if available
-
-    def enhance_editing_instruction_prompt_func(editing_instruction, video_caption, target_region_frame1_caption, video_state):
-        return video_caption, target_region_frame1_caption  # Replace with your convert_prompt_editing_instruction() if available
-
-    enhance_button.click(enhance_prompt_func, inputs=[video_caption], outputs=[video_caption])
-    enhance_target_region_frame1_button.click(enhance_target_region_frame1_prompt_func, inputs=[target_region_frame1_caption, video_state], outputs=[target_region_frame1_caption])
-    enhance_editing_instruction_button.click(enhance_editing_instruction_prompt_func, inputs=[editing_instruction, video_caption, target_region_frame1_caption, video_state],
-                                             outputs=[video_caption, target_region_frame1_caption])
-
-    video_input.clear(fn=lambda: (gr.update(visible=True), gr.update(visible=True), init_state(),
-                                  {"user_name": "", "video_name": "", "origin_images": None, "painted_images": None, "masks": None, "inpaint_masks": None, "logits": None, "select_frame_number": 0, "fps": 8, "ann_obj_id": 0},
-                                  {"inference_times": 0, "negative_click_times": 0, "positive_click_times": 0, "mask_save": False, "multi_mask": {"mask_names": [], "masks": []}, "track_end_number": 0},
-                                  [[], []], None, None,
-                                  gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True),
-                                  gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True, value=[]),
-                                  gr.update(visible=True), gr.update(visible=True), gr.update(visible=True),
-                                  gr.Button.update(interactive=False), gr.Button.update(interactive=False), gr.Button.update(interactive=False)),
-                      outputs=[video_caption, target_region_frame1_caption, inference_state, video_state, interactive_state, click_state, video_output, template_frame, tracking_video_predict_button, image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click, template_frame, tracking_video_predict_button, video_output, inpaint_video_predict_button, run_status], queue=False, show_progress=False)
-
-    clear_button_click.click(fn=clear_click, inputs=[inference_state, video_state, click_state, run_status],
-                             outputs=[inference_state, template_frame, click_state, run_status])
-
-    reset_button.click(fn=reset_all, inputs=[], outputs=[video_input, video_caption, target_region_frame1_caption, inference_state, video_state, interactive_state, click_state, video_output, template_frame, video_info, image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click, tracking_video_predict_button, inpaint_video_predict_button, enhance_button, enhance_target_region_frame1_button, enhance_editing_instruction_button, seed_param, cfg_scale, dilate_size, run_status])
-
-iface.queue().launch(share=False)
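
Note on the installer helper in the deleted script: `install_package` forwards its whole argument to pip as a single token, so calls such as `install_package("torch>=2.0.0 torchvision>=0.15.0")` and `install_package("diffusers>=0.24.0 --force-reinstall")` hand pip one combined requirement string rather than separate requirements and flags. A minimal sketch of a variant that passes each specifier as its own argument (illustrative only; this helper is not part of the deleted file):

    import subprocess
    import sys

    def install_packages(*specs):
        # Pass each requirement or flag as its own argv entry so pip parses them individually.
        # (Hypothetical replacement for the single-string install_package in the deleted app.py.)
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", *specs])
            return True
        except subprocess.CalledProcessError as exc:
            print(f"pip install failed: {exc}")
            return False

    # Usage mirroring the calls in the deleted script:
    # install_packages("torch>=2.0.0", "torchvision>=0.15.0")
    # install_packages("--force-reinstall", "diffusers>=0.24.0")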