ghostsInTheMachine
committed on
Commit
•
d5d8098
1
Parent(s):
8c25de0
Update app.py
Browse files
app.py
CHANGED
@@ -23,7 +23,7 @@ task_name = 'depth'
|
|
23 |
pipe_g, pipe_d = load_models(task_name, device)
|
24 |
|
25 |
# Preprocess the video to adjust resolution and frame rate
|
26 |
-
def preprocess_video(video_path, target_fps=24, max_resolution=(
|
27 |
"""Preprocess the video to resize and adjust its frame rate."""
|
28 |
video = mp.VideoFileClip(video_path)
|
29 |
|
@@ -38,11 +38,13 @@ def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
|
|
38 |
return video
|
39 |
|
40 |
# Process a batch of frames through the depth model
|
41 |
-
def process_frames_batch(frames_batch, seed=0):
|
42 |
"""Process a batch of frames and return depth maps."""
|
43 |
try:
|
44 |
-
#
|
45 |
-
|
|
|
|
|
46 |
|
47 |
# Run batch inference
|
48 |
depth_maps = lotus(images_batch, 'depth', seed, device, pipe_g, pipe_d)
|
@@ -54,7 +56,7 @@ def process_frames_batch(frames_batch, seed=0):
|
|
54 |
return [None] * len(frames_batch)
|
55 |
|
56 |
# Process video frames and generate depth maps
|
57 |
-
def process_video(video_path, fps=0, seed=0, batch_size=
|
58 |
"""Process video frames in batches and generate depth maps."""
|
59 |
# Create a persistent temporary directory
|
60 |
temp_dir = tempfile.mkdtemp()
|
@@ -62,7 +64,7 @@ def process_video(video_path, fps=0, seed=0, batch_size=16):
|
|
62 |
start_time = time.time()
|
63 |
|
64 |
# Preprocess the video
|
65 |
-
video = preprocess_video(video_path, target_fps=fps)
|
66 |
|
67 |
# Use original video FPS if not specified
|
68 |
if fps == 0:
|
@@ -81,8 +83,19 @@ def process_video(video_path, fps=0, seed=0, batch_size=16):
|
|
81 |
|
82 |
# Process frames in batches
|
83 |
for i in range(0, total_frames, batch_size):
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
for j, depth_map in enumerate(depth_maps):
|
88 |
frame_index = i + j
|
@@ -135,7 +148,7 @@ def process_video(video_path, fps=0, seed=0, batch_size=16):
|
|
135 |
yield None, None, None, f"Error processing video: {e}"
|
136 |
|
137 |
# Wrapper function with error handling
|
138 |
-
def process_wrapper(video, fps=0, seed=0, batch_size=
|
139 |
if video is None:
|
140 |
raise gr.Error("Please upload a video.")
|
141 |
try:
|
@@ -192,7 +205,7 @@ with gr.Blocks(css=custom_css) as demo:
|
|
192 |
video_input = gr.Video(label="Upload Video", interactive=True)
|
193 |
fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
|
194 |
seed_slider = gr.Number(value=0, label="Seed")
|
195 |
-
batch_size_slider = gr.Slider(minimum=1, maximum=
|
196 |
btn = gr.Button("Process Video")
|
197 |
|
198 |
with gr.Column():
|
|
|
23 |
pipe_g, pipe_d = load_models(task_name, device)
|
24 |
|
25 |
# Preprocess the video to adjust resolution and frame rate
|
26 |
+
def preprocess_video(video_path, target_fps=24, max_resolution=(512, 512)):
|
27 |
"""Preprocess the video to resize and adjust its frame rate."""
|
28 |
video = mp.VideoFileClip(video_path)
|
29 |
|
|
|
38 |
return video
|
39 |
|
40 |
# Process a batch of frames through the depth model
|
41 |
+
def process_frames_batch(frames_batch, seed=0, target_size=(512, 512)):
|
42 |
"""Process a batch of frames and return depth maps."""
|
43 |
try:
|
44 |
+
torch.cuda.empty_cache() # Clear GPU cache
|
45 |
+
|
46 |
+
# Resize frames to the target size
|
47 |
+
images_batch = [Image.fromarray(frame).convert('RGB').resize(target_size, Image.BILINEAR) for frame in frames_batch]
|
48 |
|
49 |
# Run batch inference
|
50 |
depth_maps = lotus(images_batch, 'depth', seed, device, pipe_g, pipe_d)
|
|
|
56 |
return [None] * len(frames_batch)
|
57 |
|
58 |
# Process video frames and generate depth maps
|
59 |
+
def process_video(video_path, fps=0, seed=0, batch_size=4):
|
60 |
"""Process video frames in batches and generate depth maps."""
|
61 |
# Create a persistent temporary directory
|
62 |
temp_dir = tempfile.mkdtemp()
|
|
|
64 |
start_time = time.time()
|
65 |
|
66 |
# Preprocess the video
|
67 |
+
video = preprocess_video(video_path, target_fps=fps, max_resolution=(512, 512))
|
68 |
|
69 |
# Use original video FPS if not specified
|
70 |
if fps == 0:
|
|
|
83 |
|
84 |
# Process frames in batches
|
85 |
for i in range(0, total_frames, batch_size):
|
86 |
+
current_batch_size = batch_size
|
87 |
+
while current_batch_size > 0:
|
88 |
+
try:
|
89 |
+
frames_batch = frames[i:i+current_batch_size]
|
90 |
+
depth_maps = process_frames_batch(frames_batch, seed)
|
91 |
+
break
|
92 |
+
except RuntimeError as e:
|
93 |
+
if 'out of memory' in str(e):
|
94 |
+
current_batch_size = current_batch_size // 2
|
95 |
+
logger.warning(f"Reducing batch size to {current_batch_size} due to out of memory error.")
|
96 |
+
torch.cuda.empty_cache()
|
97 |
+
else:
|
98 |
+
raise e
|
99 |
|
100 |
for j, depth_map in enumerate(depth_maps):
|
101 |
frame_index = i + j
|
|
|
148 |
yield None, None, None, f"Error processing video: {e}"
|
149 |
|
150 |
# Wrapper function with error handling
|
151 |
+
def process_wrapper(video, fps=0, seed=0, batch_size=4):
|
152 |
if video is None:
|
153 |
raise gr.Error("Please upload a video.")
|
154 |
try:
|
|
|
205 |
video_input = gr.Video(label="Upload Video", interactive=True)
|
206 |
fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
|
207 |
seed_slider = gr.Number(value=0, label="Seed")
|
208 |
+
batch_size_slider = gr.Slider(minimum=1, maximum=16, step=1, value=4, label="Batch Size")
|
209 |
btn = gr.Button("Process Video")
|
210 |
|
211 |
with gr.Column():
|