reedmayhew committed on
Commit
0b3b09b
·
verified ·
1 Parent(s): 7d5d19b

Upload app.py

Files changed (1)
app.py +754 -0
app.py ADDED
@@ -0,0 +1,754 @@
#!/usr/bin/env python
"""
This is the full application script for VideoPainter.
It first checks for and, if necessary, installs missing dependencies.
When installing the custom packages (diffusers and app), it uses the
--no-build-isolation flag so that the already-installed torch is visible
to the build. If the custom diffusers package fails to provide the expected
submodules (e.g. CogvideoXBranchModel), the script attempts to patch the
installation in place.
"""
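# Illustrative usage note (not part of the original upload): the script is intended to be
# run directly as the entry point of a GPU-enabled Space or host with network access.
# The environment variable below is an assumption based on the OpenRouter client set up
# further down; without it the app falls back to a dummy captioning client.
#
#   export OPENROUTER_API_KEY=<your key>   # optional, enables caption enhancement
#   python app.py                          # installs deps, downloads models, launches Gradio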

import os
import sys
import subprocess
import warnings
import time
import importlib
from collections import OrderedDict  # used by init_state() below

warnings.filterwarnings("ignore")

# Set Gradio temp directory via environment variable
GRADIO_TEMP_DIR = "./tmp_gradio"
os.makedirs(GRADIO_TEMP_DIR, exist_ok=True)
os.makedirs(f"{GRADIO_TEMP_DIR}/track", exist_ok=True)
os.makedirs(f"{GRADIO_TEMP_DIR}/inpaint", exist_ok=True)
os.environ["GRADIO_TEMP_DIR"] = GRADIO_TEMP_DIR

def install_package(package_spec):
    print(f"Installing {package_spec} ...")
    try:
        # Split so that multi-package specs such as "torch>=2.0.0 torchvision>=0.15.0"
        # are passed to pip as separate requirements.
        subprocess.check_call([sys.executable, "-m", "pip", "install"] + package_spec.split())
        print(f"Successfully installed {package_spec}")
        return True
    except Exception as e:
        print(f"Failed to install {package_spec}: {e}")
        return False

print("Checking for PyTorch ...")
try:
    import torch
    print("PyTorch is already installed.")
except ImportError:
    print("PyTorch not found, installing...")
    if not install_package("torch>=2.0.0 torchvision>=0.15.0"):
        print("Failed to install PyTorch, which is required.")
        sys.exit(1)
    import torch  # bind the freshly installed package so later code can use it

# First, install the wheel package, which is needed for the bdist_wheel command
install_package("wheel")

# Install ninja for faster builds
install_package("ninja")

# Check and install other critical dependencies
critical_dependencies = [
    ("hydra", "hydra-core>=1.3.2"),
    ("omegaconf", "omegaconf>=2.3.0"),
    ("decord", "decord>=0.6.0"),
    ("diffusers", "diffusers>=0.24.0"),  # will be replaced with the custom build below
    ("transformers", "transformers>=4.35.0"),
    ("gradio", "gradio>=4.0.0"),
    ("numpy", "numpy>=1.24.0"),
    ("cv2", "opencv-python>=4.8.0"),
    ("PIL", "Pillow>=10.0.0"),
    ("scipy", "scipy>=1.11.0"),
    ("einops", "einops>=0.7.0"),
    ("onnxruntime", "onnxruntime>=1.16.0"),
    ("timm", "timm>=0.9.0"),
    ("safetensors", "safetensors>=0.4.0"),
    ("moviepy", "moviepy>=1.0.3"),
    ("imageio", "imageio>=2.30.0"),
    ("tqdm", "tqdm>=4.64.0"),
    ("openai", "openai>=1.5.0"),
    ("psutil", "psutil>=5.9.0")
]

for mod_name, pkg_spec in critical_dependencies:
    try:
        if mod_name == "PIL":
            from PIL import Image
        elif mod_name == "cv2":
            import cv2
        else:
            __import__(mod_name)
        print(f"{mod_name} is already installed.")
    except ImportError:
        print(f"{mod_name} not found, installing {pkg_spec} ...")
        install_package(pkg_spec)

print("Setting up environment...")
# Clone the VideoPainter repository if not present
if not os.path.exists("VideoPainter"):
    print("Cloning VideoPainter repository...")
    os.system("git clone https://github.com/TencentARC/VideoPainter.git")

# Add necessary paths to sys.path
sys.path.append(os.path.join(os.getcwd(), "VideoPainter"))
sys.path.append(os.path.join(os.getcwd(), "VideoPainter/app"))
sys.path.append(os.path.join(os.getcwd(), "app"))
sys.path.append(".")

# Ensure the custom diffusers is importable
if os.path.exists("VideoPainter/diffusers"):
    print("Installing custom diffusers...")
    # First, remove any existing diffusers installation
    subprocess.call([sys.executable, "-m", "pip", "uninstall", "-y", "diffusers"])

    # Copy the files directly into the site-packages directory instead of using pip install -e
    import site
    site_packages = site.getsitepackages()[0]
    diffusers_src = os.path.join(os.getcwd(), "VideoPainter/diffusers/src/diffusers")
    diffusers_dst = os.path.join(site_packages, "diffusers")

    print(f"Copying diffusers from {diffusers_src} to {diffusers_dst}")
    if not os.path.exists(diffusers_dst):
        os.makedirs(diffusers_dst, exist_ok=True)

    # Copy diffusers files directly
    os.system(f"cp -r {diffusers_src}/* {diffusers_dst}/")

    # Also add VideoPainter/diffusers/src to sys.path
    sys.path.append(os.path.join(os.getcwd(), "VideoPainter/diffusers/src"))

    # Verify the custom model is available
    try:
        # Force-reload diffusers to pick up the new files
        if "diffusers" in sys.modules:
            del sys.modules["diffusers"]
        import diffusers
        print(f"Diffusers version: {diffusers.__version__}")
        print(f"Available modules in diffusers: {dir(diffusers)}")

        # Check if the models directory exists in the custom diffusers
        models_dir = os.path.join(diffusers_dst, "models")
        if os.path.exists(models_dir):
            print(f"Models in diffusers: {os.listdir(models_dir)}")
    except Exception as e:
        print(f"Error verifying diffusers installation: {e}")

# Copy the app directory if needed
if not os.path.exists("app"):
    os.makedirs("app", exist_ok=True)
    print("Copying VideoPainter/app to local app directory...")
    os.system("cp -r VideoPainter/app/* app/")

# Don't try to install the app package, just add it to the path
print("Adding app directory to Python path...")
app_path = os.path.join(os.getcwd(), "app")
sys.path.insert(0, app_path)

# Insert the VideoPainter path at the beginning of sys.path so it takes precedence
sys.path.insert(0, os.path.join(os.getcwd(), "VideoPainter"))

print("Importing standard modules and dependencies ...")
try:
    import gradio as gr
    import cv2
    import numpy as np
    import scipy
    import scipy.ndimage  # explicitly load the ndimage submodule (used for mask dilation later)
    import torchvision
    from PIL import Image
    from huggingface_hub import snapshot_download
    from decord import VideoReader
except ImportError as e:
    print(f"Error importing basic modules: {e}")
    sys.exit(1)

# Import specialized modules with better error handling
try:
    # Import our custom modules
    from sam2.build_sam import build_sam2_video_predictor

    # Force a reload of diffusers after the direct copy
    if "diffusers" in sys.modules:
        del sys.modules["diffusers"]

    # Now import diffusers with an explicit path to the files we need
    sys.path.insert(0, os.path.join(os.getcwd(), "VideoPainter/app"))

    # Import utils after setting up the correct paths
    from utils import load_model, generate_frames
    print("All modules imported successfully!")
except ImportError as e:
    print(f"Error importing specialized modules: {e}")
    print("Paths:", sys.path)

    # Try to diagnose and fix the specific issue
    if "CogvideoXBranchModel" in str(e):
        print("Trying to fix missing CogvideoXBranchModel...")

        # Check if the model file exists in the repository
        branch_model_file = "VideoPainter/diffusers/src/diffusers/models/cogvideox_branch.py"
        if os.path.exists(branch_model_file):
            print(f"Found branch model file at {branch_model_file}")

            # Manually import the module
            sys.path.insert(0, os.path.join(os.getcwd(), "VideoPainter/diffusers/src"))

            # Add the import to __init__.py if it is not already there
            init_file = os.path.join(site_packages, "diffusers/__init__.py")
            with open(init_file, 'r') as f:
                init_content = f.read()

            if "CogvideoXBranchModel" not in init_content:
                print("Adding CogvideoXBranchModel to diffusers/__init__.py")
                with open(init_file, 'a') as f:
                    f.write("\nfrom .models.cogvideox_branch import CogvideoXBranchModel\n")

            # Force-reload diffusers
            if "diffusers" in sys.modules:
                del sys.modules["diffusers"]

            # Try importing again
            from utils import load_model, generate_frames
            print("Fixed CogvideoXBranchModel import issue!")
        else:
            print(f"Could not find {branch_model_file}")
            sys.exit(1)
    else:
        sys.exit(1)


###############################
# Begin Application Code (VideoPainter demo)
###############################

def download_models():
    print("Downloading models from Hugging Face Hub...")
    models = {
        "CogVideoX-5b-I2V": "THUDM/CogVideoX-5b-I2V",
        "VideoPainter": "TencentARC/VideoPainter"
    }
    model_paths = {}
    os.makedirs("ckpt", exist_ok=True)
    for name, repo_id in models.items():
        print(f"Downloading {name} from {repo_id}...")
        path = snapshot_download(repo_id=repo_id)
        model_paths[name] = path
        print(f"Downloaded {name} to {path}")
    try:
        flux_path = snapshot_download(repo_id="black-forest-labs/FLUX.1-Fill-dev")
        model_paths["FLUX"] = flux_path
    except Exception as e:
        print(f"Failed to download FLUX model: {e}")
        model_paths["FLUX"] = None
    os.makedirs("ckpt/Grounded-SAM-2", exist_ok=True)
    sam2_path = "ckpt/Grounded-SAM-2/sam2_hiera_large.pt"
    if not os.path.exists(sam2_path):
        print(f"Downloading SAM2 to {sam2_path}...")
        os.system(f"wget -O {sam2_path} https://huggingface.co/spaces/sam2/sam2/resolve/main/sam2_hiera_large.pt")
    model_paths["SAM2"] = sam2_path
    return model_paths

print("Initializing application environment...")
if not os.path.exists("app"):
    print("Setting up app folder from VideoPainter repository ...")
    os.system("git clone https://github.com/TencentARC/VideoPainter.git")
    os.makedirs("app", exist_ok=True)
    os.system("cp -r VideoPainter/app/* app/")
    os.system("pip install --no-build-isolation -e VideoPainter/diffusers")
    os.chdir("app")
    os.system("pip install --no-build-isolation -e .")
    os.chdir("..")

sys.path.append("app")
sys.path.append(".")

# Import project modules (again, to be safe)
try:
    from decord import VideoReader
    from sam2.build_sam import build_sam2_video_predictor
    from utils import load_model, generate_frames
except ImportError as e:
    print(f"Failed to import specialized modules: {e}")
    sys.exit(1)

# Set up OpenRouter / OpenAI (for caption generation)
try:
    from openai import OpenAI
    vlm_model = OpenAI(
        api_key=os.getenv("OPENROUTER_API_KEY", ""),
        base_url="https://openrouter.ai/api/v1"
    )
    print("OpenRouter client initialized successfully")
except Exception as e:
    print(f"OpenRouter API not available: {e}")

    class DummyModel:
        def __getattr__(self, name):
            return self
        def __call__(self, *args, **kwargs):
            return self
        def create(self, *args, **kwargs):
            class DummyResponse:
                choices = [type('obj', (object,), {'message': type('obj', (object,), {'content': "OpenRouter API not available. Using default prompt."})})]
            return DummyResponse()

    vlm_model = DummyModel()
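# Illustrative sketch (not in the original upload) of how this client would typically be
# used for caption enhancement; the model id below is an assumption. DummyModel mimics the
# same .chat.completions.create(...).choices[0].message.content access pattern, so the
# calling code works with or without an API key.
#
#   response = vlm_model.chat.completions.create(
#       model="openai/gpt-4o",  # hypothetical OpenRouter model id
#       messages=[{"role": "user", "content": "Describe this video scene."}],
#   )
#   caption = response.choices[0].message.content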

###############################
# Download models and initialize predictors
###############################
model_paths = download_models()
base_model_path = model_paths["CogVideoX-5b-I2V"]
videopainter_path = model_paths["VideoPainter"]
inpainting_branch = os.path.join(videopainter_path, "checkpoints/branch")
id_adapter = os.path.join(videopainter_path, "VideoPainterID/checkpoints")
img_inpainting_model = model_paths.get("FLUX")
sam2_checkpoint = "ckpt/Grounded-SAM-2/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

try:
    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
    print("Build SAM2 predictor done!")
    validation_pipeline, validation_pipeline_img = load_model(
        model_path=base_model_path,
        inpainting_branch=inpainting_branch,
        id_adapter=id_adapter,
        img_inpainting_model=img_inpainting_model
    )
    print("Load model done!")
except Exception as e:
    print(f"Error initializing models: {e}")
    sys.exit(1)

###############################
# Helper functions & state definitions
###############################
EXAMPLES = [
    [
        "https://huggingface.co/spaces/TencentARC/VideoPainter/resolve/main/examples/ferry.mp4",
        "A white ferry with red and blue accents, named 'COLONIA', cruises on a calm river...",
        "White and red passenger ferry boat labeled 'COLONIA 6' with multiple windows, life buoys, and upper deck seating.",
        "Positive",
        "Inpaint",
        "",
        42,
        6.0,
        16,
        [[[320, 240]], [1]],
    ],
    [
        "https://huggingface.co/spaces/TencentARC/VideoPainter/resolve/main/examples/street.mp4",
        "A bustling city street at night illuminated by festive lights, a red double-decker bus...",
        "The rear of a black car with illuminated red tail lights and a visible license plate.",
        "Positive",
        "Inpaint",
        "",
        42,
        6.0,
        16,
        [[[200, 400]], [1]],
    ],
]

class StatusMessage:
    INFO = "Info"
    WARNING = "Warning"
    ERROR = "Error"
    SUCCESS = "Success"

def create_status(message, status_type=StatusMessage.INFO):
    timestamp = time.strftime("%H:%M:%S")
    return [("", ""), (f"[{timestamp}]: {message}\n", status_type)]

def update_status(previous_status, new_message, status_type=StatusMessage.INFO):
    timestamp = time.strftime("%H:%M:%S")
    history = previous_status[-3:]
    history.append((f"[{timestamp}]: {new_message}\n", status_type))
    return [("", "")] + history

def init_state(offload_video_to_cpu=False, offload_state_to_cpu=False):
    inference_state = {}
    inference_state["images"] = torch.zeros([1, 3, 100, 100])
    inference_state["num_frames"] = 1
    inference_state["offload_video_to_cpu"] = offload_video_to_cpu
    inference_state["offload_state_to_cpu"] = offload_state_to_cpu
    inference_state["video_height"] = 100
    inference_state["video_width"] = 100
    inference_state["device"] = torch.device("cuda")
    inference_state["storage_device"] = torch.device("cpu") if offload_state_to_cpu else torch.device("cuda")
    inference_state["point_inputs_per_obj"] = {}
    inference_state["mask_inputs_per_obj"] = {}
    inference_state["cached_features"] = {}
    inference_state["constants"] = {}
    inference_state["obj_id_to_idx"] = OrderedDict()
    inference_state["obj_idx_to_id"] = OrderedDict()
    inference_state["obj_ids"] = []
    inference_state["output_dict"] = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
    inference_state["output_dict_per_obj"] = {}
    inference_state["temp_output_dict_per_obj"] = {}
    inference_state["consolidated_frame_inds"] = {"cond_frame_outputs": set(), "non_cond_frame_outputs": set()}
    inference_state["tracking_has_started"] = False
    inference_state["frames_already_tracked"] = {}
    inference_state = gr.State(inference_state)
    return inference_state
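# Note (added for clarity, not in the original upload): the keys above appear to mirror the
# fields that SAM2's video predictor keeps in its internal inference state; the 1x3x100x100
# placeholder image is only a stub so that a gr.State exists before any video is loaded.
# Once a video is uploaded, get_frames_from_video() replaces this state with the real one
# returned by predictor.init_state(images=frames, ...).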

# Additional helper functions: get_frames_from_video, sam_refine, vos_tracking_video,
# inpaint_video, generate_video_from_frames, process_example, reset_all.

def get_frames_from_video(video_input, video_state):
    video_path = video_input
    frames = []
    user_name = time.time()
    vr = VideoReader(video_path)
    original_fps = vr.get_avg_fps()
    if original_fps > 8:
        total_frames = len(vr)
        sample_interval = max(1, int(original_fps / 8))
        frame_indices = list(range(0, total_frames, sample_interval))
        frames = vr.get_batch(frame_indices).asnumpy()
    else:
        frames = vr.get_batch(list(range(len(vr)))).asnumpy()
    frames = frames[:49]
    resized_frames = [cv2.resize(frame, (720, 480)) for frame in frames]
    frames = np.array(resized_frames)
    init_start = time.time()
    inference_state = predictor.init_state(images=frames, offload_video_to_cpu=True, async_loading_frames=True)
    init_time = time.time() - init_start
    print(f"Inference state initialization took {init_time:.2f}s")
    fps = 8
    image_size = (frames[0].shape[0], frames[0].shape[1])
    video_state = {
        "user_name": user_name,
        "video_name": os.path.split(video_path)[-1],
        "origin_images": frames,
        "painted_images": frames.copy(),
        "masks": [np.zeros((frames[0].shape[0], frames[0].shape[1]), np.uint8)] * len(frames),
        "logits": [None] * len(frames),
        "select_frame_number": 0,
        "fps": fps,
        "ann_obj_id": 0
    }
    video_info = f"Video Name: {video_state['video_name']}, FPS: {video_state['fps']}, Total Frames: {len(frames)}, Image Size: {image_size}"
    video_input_path = generate_video_from_frames(frames, output_path=f"{GRADIO_TEMP_DIR}/inpaint/original_{video_state['video_name']}", fps=fps)
    return (gr.update(visible=True), gr.update(visible=True), inference_state, video_state, video_info,
            video_state["origin_images"][0], gr.update(visible=False, maximum=len(frames), value=1, interactive=True),
            gr.update(visible=False, maximum=len(frames), value=len(frames), interactive=True), gr.update(visible=True, interactive=True),
            gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=True),
            gr.update(visible=True, interactive=False), create_status("Upload video complete. Ready to select targets.", StatusMessage.SUCCESS), video_input_path)

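# Note (added for clarity, not in the original upload): frames are resampled to roughly
# 8 fps, capped at 49 frames, and resized to 720x480 — the clip length and resolution that
# the CogVideoX-5b-I2V based inpainting pipeline below expects.
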
def select_template(image_selection_slider, video_state, interactive_state, previous_status):
    image_selection_slider -= 1
    video_state["select_frame_number"] = image_selection_slider
    return video_state["painted_images"][image_selection_slider], video_state, interactive_state, update_status(previous_status, f"Set tracking start at frame {image_selection_slider}.", StatusMessage.INFO)

def get_end_number(track_pause_number_slider, video_state, interactive_state, previous_status):
    interactive_state["track_end_number"] = track_pause_number_slider
    return video_state["painted_images"][track_pause_number_slider], interactive_state, update_status(previous_status, f"Set tracking finish at frame {track_pause_number_slider}.", StatusMessage.INFO)

def sam_refine(inference_state, video_state, point_prompt, click_state, interactive_state, evt: gr.SelectData, previous_status):
    # The gr.SelectData annotation lets Gradio inject the click event for template_frame.select();
    # process_example() below calls this function directly with a MockEvent instead.
    ann_obj_id = 0
    ann_frame_idx = video_state["select_frame_number"]
    if point_prompt == "Positive":
        coordinate = f"[[{evt.index[0]},{evt.index[1]},1]]"
        interactive_state["positive_click_times"] += 1
    else:
        coordinate = f"[[{evt.index[0]},{evt.index[1]},0]]"
        interactive_state["negative_click_times"] += 1
    print(f"sam_refine, point_prompt: {point_prompt}, click_state: {click_state}")
    prompt = {"prompt_type": ["click"], "input_point": click_state[0], "input_label": click_state[1], "multimask_output": "True"}
    points = np.array(prompt["input_point"])
    labels = np.array(prompt["input_label"])
    height, width = video_state["origin_images"][0].shape[0:2]
    for i in range(len(points)):
        points[i, 0] = int(points[i, 0])
        points[i, 1] = int(points[i, 1])
    print(f"sam_refine points: {points}, labels: {labels}")
    frame_idx, obj_ids, mask = predictor.add_new_points(inference_state=inference_state, frame_idx=ann_frame_idx, obj_id=ann_obj_id, points=points, labels=labels)
    mask_ = mask.cpu().squeeze().detach().numpy()
    mask_[mask_ <= 0] = 0
    mask_[mask_ > 0] = 1
    org_image = video_state["origin_images"][video_state["select_frame_number"]]
    mask_ = cv2.resize(mask_, (width, height))
    mask_ = mask_[:, :, None]
    mask_[mask_ > 0.5] = 1
    mask_[mask_ <= 0.5] = 0
    color = 63 * np.ones((height, width, 3)) * np.array([[[np.random.randint(5), np.random.randint(5), np.random.randint(5)]]])
    painted_image = np.uint8((1 - 0.5 * mask_) * org_image + 0.5 * mask_ * color)
    video_state["masks"][video_state["select_frame_number"]] = mask_
    video_state["painted_images"][video_state["select_frame_number"]] = painted_image
    return painted_image, video_state, interactive_state, update_status(previous_status, "Segmentation updated. Add more points or continue tracking.", StatusMessage.SUCCESS)

def clear_click(inference_state, video_state, click_state, previous_status):
    predictor.reset_state(inference_state)
    click_state = [[], []]
    template_frame = video_state["origin_images"][video_state["select_frame_number"]]
    return inference_state, template_frame, click_state, update_status(previous_status, "Click history cleared.", StatusMessage.INFO)

def vos_tracking_video(inference_state, video_state, interactive_state, previous_status):
    height, width = video_state["origin_images"][0].shape[0:2]
    masks = []
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        mask = np.zeros([480, 720, 1])
        for i in range(len(out_mask_logits)):
            out_mask = out_mask_logits[i].cpu().squeeze().detach().numpy()
            out_mask[out_mask > 0] = 1
            out_mask[out_mask <= 0] = 0
            out_mask = out_mask[:, :, None]
            mask += out_mask
        mask = cv2.resize(mask, (width, height))
        mask = mask[:, :, None]
        mask[mask > 0.5] = 1
        mask[mask < 1] = 0
        mask = scipy.ndimage.binary_dilation(mask, iterations=6)
        masks.append(mask)
    masks = np.array(masks)
    if interactive_state.get("track_end_number") is not None:
        video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
        org_images = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
        color = 255 * np.ones((1, org_images.shape[-3], org_images.shape[-2], 3)) * np.array([[[[0, 1, 1]]]])
        painted_images = np.uint8((1 - 0.5 * masks) * org_images + 0.5 * masks * color)
        video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images
    else:
        video_state["masks"] = masks
        org_images = video_state["origin_images"]
        color = 255 * np.ones((1, org_images.shape[-3], org_images.shape[-2], 3)) * np.array([[[[0, 1, 1]]]])
        painted_images = np.uint8((1 - 0.5 * masks) * org_images + 0.5 * masks * color)
        video_state["painted_images"] = painted_images
    video_output = generate_video_from_frames(video_state["painted_images"], output_path=f"{GRADIO_TEMP_DIR}/track/{video_state['video_name']}", fps=video_state["fps"])
    interactive_state["inference_times"] += 1
    print(f"vos_tracking_video output: {video_output}")
    # gr.update() is used instead of the removed gr.Button.update() for Gradio 4.x compatibility;
    # the trailing no-op update corresponds to the Notes accordion so the number of returned
    # values matches the outputs list of the Tracking button.
    return inference_state, video_output, video_state, interactive_state, update_status(previous_status, "Tracking complete.", StatusMessage.SUCCESS), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update()
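# Note (added for clarity, not in the original upload): each propagated SAM2 mask is resized
# back to the source frame size, binarized, and then dilated (scipy.ndimage.binary_dilation,
# 6 iterations) so the inpainting stage gets a small margin around the tracked object's boundary.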

def inpaint_video(video_state, video_caption, target_region_frame1_caption, interactive_state, previous_status, seed_param, cfg_scale, dilate_size):
    seed = int(seed_param) if int(seed_param) >= 0 else np.random.randint(0, 2**32 - 1)
    validation_images = video_state["origin_images"]
    validation_masks = video_state["masks"]
    validation_masks = [np.squeeze(mask) for mask in validation_masks]
    validation_masks = [(mask > 0).astype(np.uint8) * 255 for mask in validation_masks]
    validation_masks = [np.stack([m, m, m], axis=-1) for m in validation_masks]
    validation_images = [Image.fromarray(np.uint8(img)).convert('RGB') for img in validation_images]
    validation_masks = [Image.fromarray(np.uint8(mask)).convert('RGB') for mask in validation_masks]
    validation_images = [img.resize((720, 480)) for img in validation_images]
    validation_masks = [mask.resize((720, 480)) for mask in validation_masks]
    print("Inpainting: video_caption=", video_caption)
    images = generate_frames(
        images=validation_images,
        masks=validation_masks,
        pipe=validation_pipeline,
        pipe_img_inpainting=validation_pipeline_img,
        prompt=str(video_caption),
        image_inpainting_prompt=str(target_region_frame1_caption),
        seed=seed,
        cfg_scale=float(cfg_scale),
        dilate_size=int(dilate_size)
    )
    images = (images * 255).astype(np.uint8)
    video_output = generate_video_from_frames(images, output_path=f"{GRADIO_TEMP_DIR}/inpaint/{video_state['video_name']}", fps=8)
    print(f"Inpaint_video output: {video_output}")
    return video_output, update_status(previous_status, "Inpainting complete.", StatusMessage.SUCCESS)

def generate_video_from_frames(frames, output_path, fps=8):
    frames_tensor = torch.from_numpy(np.asarray(frames)).to(torch.uint8)
    if not os.path.exists(os.path.dirname(output_path)):
        os.makedirs(os.path.dirname(output_path))
    torchvision.io.write_video(output_path, frames_tensor, fps=fps, video_codec="libx264")
    return output_path
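# Note (added for clarity, not in the original upload): torchvision.io.write_video expects a
# uint8 tensor of shape [num_frames, height, width, channels], which is what the stacked RGB
# frames produced above provide. Illustrative call (hypothetical path):
#
#   generate_video_from_frames(np.zeros((49, 480, 720, 3), dtype=np.uint8), "./tmp_gradio/inpaint/demo.mp4")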

def process_example(video_input, video_caption, target_region_frame1_caption, prompt, click_state):
    if video_input is None or video_input == "":
        return (gr.update(value=""), gr.update(value=""), init_state(),
                {"user_name": "", "video_name": "", "origin_images": None, "painted_images": None, "masks": None, "inpaint_masks": None, "logits": None, "select_frame_number": 0, "fps": 8, "ann_obj_id": 0},
                "", None,
                gr.update(value=1, visible=False, interactive=False),
                gr.update(value=1, visible=False, interactive=False),
                gr.update(value="Positive", interactive=False),
                gr.update(visible=True, interactive=False),
                gr.update(visible=True, interactive=False),
                gr.update(value=None),
                gr.update(visible=True, interactive=False),
                create_status("Reset complete. Ready for new input.", StatusMessage.INFO),
                gr.update(value=None))
    video_state = gr.State({
        "user_name": "",
        "video_name": "",
        "origin_images": None,
        "painted_images": None,
        "masks": None,
        "inpaint_masks": None,
        "logits": None,
        "select_frame_number": 0,
        "fps": 8,
        "ann_obj_id": 0
    })
    results = get_frames_from_video(video_input, video_state)
    if click_state[0] and click_state[1]:
        print("Example detected, executing sam_refine")
        (video_caption, target_region_frame1_caption, inference_state, video_state, video_info, template_frame, image_selection_slider, track_pause_number_slider, point_prompt, clear_button, tracking_button, video_output, inpaint_button, run_status, video_input) = results
        class MockEvent:
            def __init__(self, points, point_idx=0):
                self.index = points[point_idx]
        for i_click in range(len(click_state[0])):
            evt = MockEvent(click_state[0], i_click)
            prompt_type = "Positive" if click_state[1][i_click] == 1 else "Negative"
            template_frame, video_state, interactive_state, run_status = sam_refine(inference_state, video_state, prompt_type, click_state, {"inference_times": 0, "negative_click_times": 0, "positive_click_times": 0, "mask_save": False, "multi_mask": {"mask_names": [], "masks": []}, "track_end_number": None}, evt, run_status)
        return (video_caption, target_region_frame1_caption, inference_state, video_state, video_info, template_frame, image_selection_slider, track_pause_number_slider, point_prompt, clear_button, tracking_button, video_output, inpaint_button, run_status, video_input)
    return results

def reset_all():
    # gr.update() replaces the removed gr.Button.update()/gr.Number.update()/gr.Slider.update()
    # for Gradio 4.x compatibility.
    return (gr.update(value=None), gr.update(value=""), gr.update(value=""), init_state(),
            {"user_name": "", "video_name": "", "origin_images": None, "painted_images": None, "masks": None, "inpaint_masks": None, "logits": None, "select_frame_number": 0, "fps": 8, "ann_obj_id": 0},
            {"inference_times": 0, "negative_click_times": 0, "positive_click_times": 0, "mask_save": False, "multi_mask": {"mask_names": [], "masks": []}, "track_end_number": None},
            [[], []], None, gr.update(visible=True, interactive=True), "",
            gr.update(value=1, visible=False, interactive=False), gr.update(value=1, visible=False, interactive=False),
            gr.update(value="Positive", interactive=False), gr.update(interactive=False),
            gr.update(interactive=False), gr.update(interactive=False),
            gr.update(interactive=False), gr.update(interactive=False),
            gr.update(interactive=False), gr.update(value=42),
            gr.update(value=6.0), gr.update(value=16),
            create_status("Reset complete. Ready for new input.", StatusMessage.INFO))

###############################
# Build Gradio Interface
###############################
title = """<p><h1 align="center">VideoPainter</h1></p>"""
with gr.Blocks() as iface:
    gr.HTML("""
    <div style="text-align: center;">
        <h1 style="color: #333;">🖌️ VideoPainter</h1>
        <h3 style="color: #333;">Any-length Video Inpainting and Editing with Plug-and-Play Context Control</h3>
        <p style="font-weight: bold;">
            <a href="https://yxbian23.github.io/project/video-painter/">🌍 Project Page</a> |
            <a href="https://arxiv.org/abs/2503.05639">📃 ArXiv Preprint</a> |
            <a href="https://github.com/TencentARC/VideoPainter">🧑‍💻 GitHub Repository</a>
        </p>
    </div>
    """)
    click_state = gr.State([[], []])
    interactive_state = gr.State({
        "inference_times": 0,
        "negative_click_times": 0,
        "positive_click_times": 0,
        "mask_save": False,
        "multi_mask": {"mask_names": [], "masks": []},
        "track_end_number": None,
    })
    video_state = gr.State({
        "user_name": "",
        "video_name": "",
        "origin_images": None,
        "painted_images": None,
        "masks": None,
        "inpaint_masks": None,
        "logits": None,
        "select_frame_number": 0,
        "fps": 8,
        "ann_obj_id": 0
    })
    inference_state = init_state()

    with gr.Row():
        with gr.Column():
            with gr.Row():
                video_input = gr.Video(label="Original Video", visible=True)
            with gr.Row():
                with gr.Column(scale=3):
                    template_frame = gr.Image(type="pil", interactive=True, elem_id="template_frame", visible=True)
                with gr.Column(scale=1):
                    with gr.Accordion("Segmentation Point Prompt", open=True):
                        point_prompt = gr.Radio(choices=["Positive", "Negative"], value="Positive", label="Point Type", interactive=False, visible=True)
                        clear_button_click = gr.Button(value="Clear clicks", interactive=False, visible=True)
                        gr.Markdown("✨ Positive: Include target region. <br> ✨ Negative: Exclude target region.")
                    image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False, interactive=False)
                    track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False, interactive=False)
            video_output = gr.Video(label="Generated Video", visible=True)
            with gr.Row():
                tracking_video_predict_button = gr.Button(value="Tracking", interactive=False, visible=True)
                inpaint_video_predict_button = gr.Button(value="Inpainting", interactive=False, visible=True)
                reset_button = gr.Button(value="Reset All", interactive=True, visible=True)

        with gr.Column():
            with gr.Accordion("Global Video Caption", open=True):
                video_caption = gr.Textbox(label="Global Video Caption", placeholder="Input global video caption...", interactive=True, visible=True, max_lines=5, show_copy_button=True)
                with gr.Row():
                    gr.Markdown("✨ Enhance prompt using GPT-4o (optional).")
                    enhance_button = gr.Button("✨ Enhance Prompt (Optional)", interactive=False)
            with gr.Accordion("Target Object Caption", open=True):
                target_region_frame1_caption = gr.Textbox(label="Target Object Caption", placeholder="Input target object caption...", interactive=True, visible=True, max_lines=5, show_copy_button=True)
                with gr.Row():
                    gr.Markdown("✨ Generate target caption (optional).")
                    enhance_target_region_frame1_button = gr.Button("✨ Target Prompt Generation (Optional)", interactive=False)
            with gr.Accordion("Editing Instruction", open=False):
                gr.Markdown("✨ Modify captions based on your instruction using GPT-4o.")
                with gr.Row():
                    editing_instruction = gr.Textbox(label="Editing Instruction", placeholder="Input editing instruction...", interactive=True, visible=True, max_lines=5, show_copy_button=True)
                    enhance_editing_instruction_button = gr.Button("✨ Modify Caption (For Editing)", interactive=False)
            with gr.Accordion("Advanced Sampling Settings", open=False):
                cfg_scale = gr.Slider(value=6.0, label="Classifier-Free Guidance Scale", minimum=1, maximum=10, step=0.1, interactive=True)
                seed_param = gr.Number(label="Inference Seed (>=0)", interactive=True, value=42)
                dilate_size = gr.Slider(value=16, label="Mask Dilate Size", minimum=0, maximum=32, step=1, interactive=True)
            video_info = gr.Textbox(label="Video Info", visible=True, interactive=False)
            model_type = gr.Textbox(label="Type", placeholder="Model type...", interactive=True, visible=False)
            notes_accordion = gr.Accordion("Notes", open=False)
            with notes_accordion:
                gr.HTML("<p style='font-size: 1.1em;'>🧐 Reminder: VideoPainter may produce unexpected outputs. Adjust settings if needed.</p>")
            run_status = gr.HighlightedText(value=[("", "")], visible=True, label="Operation Status", show_label=True,
                                            color_map={"Success": "green", "Error": "red", "Warning": "orange", "Info": "blue"})

    with gr.Row():
        examples = gr.Examples(label="Quick Examples", examples=EXAMPLES,
                               inputs=[video_input, video_caption, target_region_frame1_caption, point_prompt, model_type, editing_instruction, seed_param, cfg_scale, dilate_size, click_state],
                               examples_per_page=20, cache_examples=False)

    video_input.change(fn=process_example, inputs=[video_input, video_caption, target_region_frame1_caption, point_prompt, click_state],
                       outputs=[video_caption, target_region_frame1_caption, inference_state, video_state, video_info,
                                template_frame, image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click,
                                tracking_video_predict_button, video_output, inpaint_video_predict_button, run_status, video_input])

    image_selection_slider.release(fn=select_template, inputs=[image_selection_slider, video_state, interactive_state, run_status],
                                   outputs=[template_frame, video_state, interactive_state, run_status])

    track_pause_number_slider.release(fn=get_end_number, inputs=[track_pause_number_slider, video_state, interactive_state, run_status],
                                      outputs=[template_frame, interactive_state, run_status])

    template_frame.select(fn=sam_refine, inputs=[inference_state, video_state, point_prompt, click_state, interactive_state, run_status],
                          outputs=[template_frame, video_state, interactive_state, run_status])

    tracking_video_predict_button.click(fn=vos_tracking_video, inputs=[inference_state, video_state, interactive_state, run_status],
                                        outputs=[inference_state, video_output, video_state, interactive_state, run_status,
                                                 inpaint_video_predict_button, enhance_button, enhance_target_region_frame1_button, enhance_editing_instruction_button, notes_accordion])

    inpaint_video_predict_button.click(fn=inpaint_video, inputs=[video_state, video_caption, target_region_frame1_caption, interactive_state, run_status, seed_param, cfg_scale, dilate_size],
                                       outputs=[video_output, run_status], api_name=False, show_progress="full")

    def enhance_prompt_func(video_caption):
        return video_caption  # Replace with convert_prompt() if available

    def enhance_target_region_frame1_prompt_func(target_region_frame1_caption, video_state):
        return target_region_frame1_caption  # Replace with convert_prompt_target_region_frame1() if available

    def enhance_editing_instruction_prompt_func(editing_instruction, video_caption, target_region_frame1_caption, video_state):
        return video_caption, target_region_frame1_caption  # Replace with convert_prompt_editing_instruction() if available

    enhance_button.click(enhance_prompt_func, inputs=[video_caption], outputs=[video_caption])
    enhance_target_region_frame1_button.click(enhance_target_region_frame1_prompt_func, inputs=[target_region_frame1_caption, video_state], outputs=[target_region_frame1_caption])
    enhance_editing_instruction_button.click(enhance_editing_instruction_prompt_func, inputs=[editing_instruction, video_caption, target_region_frame1_caption, video_state],
                                             outputs=[video_caption, target_region_frame1_caption])

    # gr.update() replaces the removed gr.Button.update() for Gradio 4.x compatibility.
    video_input.clear(fn=lambda: (gr.update(visible=True), gr.update(visible=True), init_state(),
                                  {"user_name": "", "video_name": "", "origin_images": None, "painted_images": None, "masks": None, "inpaint_masks": None, "logits": None, "select_frame_number": 0, "fps": 8, "ann_obj_id": 0},
                                  {"inference_times": 0, "negative_click_times": 0, "positive_click_times": 0, "mask_save": False, "multi_mask": {"mask_names": [], "masks": []}, "track_end_number": 0},
                                  [[], []], None, None,
                                  gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True),
                                  gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True, value=[]),
                                  gr.update(visible=True), gr.update(visible=True), gr.update(visible=True),
                                  gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)),
                      outputs=[video_caption, target_region_frame1_caption, inference_state, video_state, interactive_state, click_state, video_output, template_frame, tracking_video_predict_button, image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click, template_frame, tracking_video_predict_button, video_output, inpaint_video_predict_button, run_status], queue=False, show_progress=False)

    clear_button_click.click(fn=clear_click, inputs=[inference_state, video_state, click_state, run_status],
                             outputs=[inference_state, template_frame, click_state, run_status])

    reset_button.click(fn=reset_all, inputs=[], outputs=[video_input, video_caption, target_region_frame1_caption, inference_state, video_state, interactive_state, click_state, video_output, template_frame, video_info, image_selection_slider, track_pause_number_slider, point_prompt, clear_button_click, tracking_video_predict_button, inpaint_video_predict_button, enhance_button, enhance_target_region_frame1_button, enhance_editing_instruction_button, seed_param, cfg_scale, dilate_size, run_status])

iface.queue().launch(share=False)