ghostsInTheMachine
committed on
Commit
•
afe7cc3
1
Parent(s):
8e07e41
Update app.py
Browse files
app.py
CHANGED
@@ -8,10 +8,8 @@ import ffmpeg
|
|
8 |
import numpy as np
|
9 |
from PIL import Image
|
10 |
import moviepy.editor as mp
|
11 |
-
from infer import lotus
|
12 |
import logging
|
13 |
-
import io
|
14 |
-
from multiprocessing import Pool, cpu_count
|
15 |
|
16 |
# Set up logging
|
17 |
logging.basicConfig(level=logging.INFO)
|
@@ -20,105 +18,98 @@ logger = logging.getLogger(__name__)
|
|
20 |
# Set device to use the L40s GPU explicitly
|
21 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
22 |
|
|
|
|
|
|
|
|
|
23 |
# Preprocess the video to adjust resolution and frame rate
|
24 |
def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
    """Load a video, shrink it to fit ``max_resolution`` and set its frame rate.

    Args:
        video_path: Path to the input video, passed to ``mp.VideoFileClip``.
        target_fps: Frame rate to assign to the clip. ``0``, ``None`` or a
            negative value leaves the clip's frame rate untouched.
        max_resolution: ``(max_width, max_height)`` bounding box in pixels.

    Returns:
        The loaded clip, resized and/or re-timed when required.
    """
    video = mp.VideoFileClip(video_path)

    width, height = video.size
    max_width, max_height = max_resolution

    # Shrink only, using the tighter of the two ratios so the aspect ratio is
    # kept and BOTH dimensions fit. Resizing by height alone left wide clips
    # too wide and could even upscale short, wide clips.
    scale = min(max_width / width, max_height / height)
    if scale < 1:
        new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
        video = video.resize(newsize=new_size)

    # Adjust FPS only if a positive target_fps is specified
    if target_fps and target_fps > 0:
        video = video.set_fps(target_fps)

    return video
|
37 |
|
38 |
-
# Process a
|
39 |
-
def
|
40 |
-
"""Process a
|
41 |
-
frame_index, frame_data, seed = args
|
42 |
try:
|
43 |
-
#
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
img_bytes = io.BytesIO()
|
52 |
-
image.save(img_bytes, format='PNG')
|
53 |
-
img_bytes.seek(0) # Reset file pointer to the beginning
|
54 |
-
|
55 |
-
# Process through the depth model
|
56 |
-
_, output_d = lotus(img_bytes, 'depth', seed, device)
|
57 |
-
|
58 |
-
# Convert depth output to numpy array
|
59 |
-
depth_array = np.array(output_d)
|
60 |
-
return (frame_index, depth_array)
|
61 |
-
|
62 |
except Exception as e:
|
63 |
-
logger.error(f"Error processing
|
64 |
-
return (
|
65 |
|
66 |
# Process video frames and generate depth maps
|
67 |
-
def process_video(video_path, fps=0, seed=0,
|
68 |
-
"""Process video frames in
|
69 |
# Create a persistent temporary directory
|
70 |
temp_dir = tempfile.mkdtemp()
|
71 |
try:
|
72 |
start_time = time.time()
|
73 |
-
|
74 |
# Preprocess the video
|
75 |
video = preprocess_video(video_path, target_fps=fps)
|
76 |
-
|
77 |
# Use original video FPS if not specified
|
78 |
if fps == 0:
|
79 |
fps = video.fps
|
80 |
-
|
81 |
frames = list(video.iter_frames(fps=video.fps))
|
82 |
total_frames = len(frames)
|
83 |
-
|
84 |
logger.info(f"Processing {total_frames} frames at {fps} FPS...")
|
85 |
-
|
86 |
# Create directory for frame sequence and outputs
|
87 |
frames_dir = os.path.join(temp_dir, "frames")
|
88 |
os.makedirs(frames_dir, exist_ok=True)
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
98 |
if depth_map is not None:
|
99 |
# Save frame
|
100 |
frame_path = os.path.join(frames_dir, f"frame_{frame_index:06d}.png")
|
101 |
-
|
102 |
-
|
103 |
-
# Update preview every 10% progress
|
104 |
-
if
|
105 |
elapsed_time = time.time() - start_time
|
106 |
-
progress = (
|
107 |
-
yield depth_map, None, None, f"Processed {frame_index
|
108 |
else:
|
109 |
-
logger.error(f"
|
110 |
-
|
111 |
logger.info("Creating output files...")
|
112 |
-
|
113 |
# Create ZIP of frame sequence
|
114 |
zip_filename = f"depth_frames_{int(time.time())}.zip"
|
115 |
zip_path = os.path.join(temp_dir, zip_filename)
|
116 |
shutil.make_archive(zip_path[:-4], 'zip', frames_dir)
|
117 |
-
|
118 |
# Create MP4 video
|
119 |
video_filename = f"depth_video_{int(time.time())}.mp4"
|
120 |
output_video_path = os.path.join(temp_dir, video_filename)
|
121 |
-
|
122 |
try:
|
123 |
# FFmpeg settings for high-quality MP4
|
124 |
(
|
@@ -128,31 +119,29 @@ def process_video(video_path, fps=0, seed=0, num_workers=4):
|
|
128 |
.run(overwrite_output=True, quiet=True)
|
129 |
)
|
130 |
logger.info("MP4 video created successfully!")
|
131 |
-
|
132 |
except ffmpeg.Error as e:
|
133 |
logger.error(f"Error creating video: {e.stderr.decode() if e.stderr else str(e)}")
|
134 |
output_video_path = None
|
135 |
-
|
136 |
total_time = time.time() - start_time
|
137 |
logger.info("Processing complete!")
|
138 |
-
|
139 |
# Yield the file paths
|
140 |
yield None, zip_path, output_video_path, f"Processing complete! Total time: {total_time:.2f} seconds"
|
141 |
-
|
142 |
except Exception as e:
|
143 |
logger.error(f"Error: {e}")
|
144 |
yield None, None, None, f"Error processing video: {e}"
|
145 |
-
# Do not delete temp_dir here; we need the files to persist
|
146 |
-
# Cleanup can be handled elsewhere if necessary
|
147 |
|
148 |
# Wrapper function with error handling
|
149 |
-
def process_wrapper(video, fps=0, seed=0,
|
150 |
if video is None:
|
151 |
raise gr.Error("Please upload a video.")
|
152 |
try:
|
153 |
outputs = []
|
154 |
# Use video directly, since it's the file path
|
155 |
-
for output in process_video(video, fps, seed,
|
156 |
outputs.append(output)
|
157 |
yield output
|
158 |
return outputs[-1]
|
@@ -161,7 +150,33 @@ def process_wrapper(video, fps=0, seed=0, num_workers=4):
|
|
161 |
|
162 |
# Custom CSS for styling (unchanged)
|
163 |
custom_css = """
|
164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
"""
|
166 |
|
167 |
# Gradio Interface
|
@@ -171,24 +186,24 @@ with gr.Blocks(css=custom_css) as demo:
|
|
171 |
<div id="title">Video Depth Estimation</div>
|
172 |
</div>
|
173 |
''')
|
174 |
-
|
175 |
with gr.Row():
|
176 |
with gr.Column():
|
177 |
video_input = gr.Video(label="Upload Video", interactive=True)
|
178 |
fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
|
179 |
seed_slider = gr.Number(value=0, label="Seed")
|
180 |
-
|
181 |
btn = gr.Button("Process Video")
|
182 |
-
|
183 |
with gr.Column():
|
184 |
preview_image = gr.Image(label="Live Preview")
|
185 |
output_frames_zip = gr.File(label="Download Frame Sequence (ZIP)")
|
186 |
output_video = gr.File(label="Download Video (MP4)")
|
187 |
time_textbox = gr.Textbox(label="Status", interactive=False)
|
188 |
-
|
189 |
btn.click(
|
190 |
fn=process_wrapper,
|
191 |
-
inputs=[video_input, fps_slider, seed_slider,
|
192 |
outputs=[preview_image, output_frames_zip, output_video, time_textbox]
|
193 |
)
|
194 |
|
|
|
8 |
import numpy as np
|
9 |
from PIL import Image
|
10 |
import moviepy.editor as mp
|
11 |
+
from infer import lotus, load_models
|
12 |
import logging
|
|
|
|
|
13 |
|
14 |
# Set up logging
|
15 |
logging.basicConfig(level=logging.INFO)
|
|
|
18 |
# Set device to use the L40s GPU explicitly
|
19 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
20 |
|
21 |
+
# Load models once
|
22 |
+
task_name = 'depth'
|
23 |
+
pipe_g, pipe_d = load_models(task_name, device)
|
24 |
+
|
25 |
# Preprocess the video to adjust resolution and frame rate
|
26 |
def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
    """Load a video, shrink it to fit ``max_resolution`` and set its frame rate.

    Args:
        video_path: Path to the input video, passed to ``mp.VideoFileClip``.
        target_fps: Frame rate to assign to the clip. ``0``, ``None`` or a
            negative value leaves the clip's frame rate untouched.
        max_resolution: ``(max_width, max_height)`` bounding box in pixels.

    Returns:
        The loaded clip, resized and/or re-timed when required.
    """
    video = mp.VideoFileClip(video_path)

    width, height = video.size
    max_width, max_height = max_resolution

    # Shrink only, using the tighter of the two ratios so the aspect ratio is
    # kept and BOTH dimensions fit. Resizing by height alone left wide clips
    # too wide and could even upscale short, wide clips.
    scale = min(max_width / width, max_height / height)
    if scale < 1:
        new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
        video = video.resize(newsize=new_size)

    # Adjust FPS only if a positive target_fps is specified
    if target_fps and target_fps > 0:
        video = video.set_fps(target_fps)

    return video
|
39 |
|
40 |
+
# Process a batch of frames through the depth model
|
41 |
+
def process_frames_batch(frames_batch, seed=0):
    """Run the depth model on a batch of frames.

    Args:
        frames_batch: Sequence of frames; each one is passed to
            ``Image.fromarray`` and converted to RGB. Presumably these are
            the arrays yielded by ``iter_frames`` -- confirm against caller.
        seed: Seed value forwarded to ``lotus``.

    Returns:
        The result of ``lotus`` for the batch on success. If anything raises,
        a list of ``None`` with one entry per input frame. An empty batch
        returns ``[]`` without invoking the model.
    """
    # Nothing to process: skip image conversion and the model call entirely.
    if len(frames_batch) == 0:
        return []

    try:
        # Convert frames to PIL Images
        images_batch = [Image.fromarray(frame).convert('RGB') for frame in frames_batch]

        # Run batch inference
        return lotus(images_batch, 'depth', seed, device, pipe_g, pipe_d)
    except Exception as e:
        # Deliberate best-effort: one failed batch must not abort the run.
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Error processing batch: %s", e)
        return [None] * len(frames_batch)
|
55 |
|
56 |
# Process video frames and generate depth maps
|
57 |
+
def process_video(video_path, fps=0, seed=0, batch_size=16):
|
58 |
+
"""Process video frames in batches and generate depth maps."""
|
59 |
# Create a persistent temporary directory
|
60 |
temp_dir = tempfile.mkdtemp()
|
61 |
try:
|
62 |
start_time = time.time()
|
63 |
+
|
64 |
# Preprocess the video
|
65 |
video = preprocess_video(video_path, target_fps=fps)
|
66 |
+
|
67 |
# Use original video FPS if not specified
|
68 |
if fps == 0:
|
69 |
fps = video.fps
|
70 |
+
|
71 |
frames = list(video.iter_frames(fps=video.fps))
|
72 |
total_frames = len(frames)
|
73 |
+
|
74 |
logger.info(f"Processing {total_frames} frames at {fps} FPS...")
|
75 |
+
|
76 |
# Create directory for frame sequence and outputs
|
77 |
frames_dir = os.path.join(temp_dir, "frames")
|
78 |
os.makedirs(frames_dir, exist_ok=True)
|
79 |
+
|
80 |
+
processed_frames = []
|
81 |
+
|
82 |
+
# Process frames in batches
|
83 |
+
for i in range(0, total_frames, batch_size):
|
84 |
+
frames_batch = frames[i:i+batch_size]
|
85 |
+
depth_maps = process_frames_batch(frames_batch, seed)
|
86 |
+
|
87 |
+
for j, depth_map in enumerate(depth_maps):
|
88 |
+
frame_index = i + j
|
89 |
if depth_map is not None:
|
90 |
# Save frame
|
91 |
frame_path = os.path.join(frames_dir, f"frame_{frame_index:06d}.png")
|
92 |
+
depth_map.save(frame_path)
|
93 |
+
|
94 |
+
# Update live preview every 10% progress
|
95 |
+
if frame_index % max(1, total_frames // 10) == 0:
|
96 |
elapsed_time = time.time() - start_time
|
97 |
+
progress = (frame_index / total_frames) * 100
|
98 |
+
yield depth_map, None, None, f"Processed {frame_index}/{total_frames} frames... ({progress:.2f}%) Elapsed: {elapsed_time:.2f}s"
|
99 |
else:
|
100 |
+
logger.error(f"Error processing frame {frame_index}")
|
101 |
+
|
102 |
logger.info("Creating output files...")
|
103 |
+
|
104 |
# Create ZIP of frame sequence
|
105 |
zip_filename = f"depth_frames_{int(time.time())}.zip"
|
106 |
zip_path = os.path.join(temp_dir, zip_filename)
|
107 |
shutil.make_archive(zip_path[:-4], 'zip', frames_dir)
|
108 |
+
|
109 |
# Create MP4 video
|
110 |
video_filename = f"depth_video_{int(time.time())}.mp4"
|
111 |
output_video_path = os.path.join(temp_dir, video_filename)
|
112 |
+
|
113 |
try:
|
114 |
# FFmpeg settings for high-quality MP4
|
115 |
(
|
|
|
119 |
.run(overwrite_output=True, quiet=True)
|
120 |
)
|
121 |
logger.info("MP4 video created successfully!")
|
122 |
+
|
123 |
except ffmpeg.Error as e:
|
124 |
logger.error(f"Error creating video: {e.stderr.decode() if e.stderr else str(e)}")
|
125 |
output_video_path = None
|
126 |
+
|
127 |
total_time = time.time() - start_time
|
128 |
logger.info("Processing complete!")
|
129 |
+
|
130 |
# Yield the file paths
|
131 |
yield None, zip_path, output_video_path, f"Processing complete! Total time: {total_time:.2f} seconds"
|
132 |
+
|
133 |
except Exception as e:
|
134 |
logger.error(f"Error: {e}")
|
135 |
yield None, None, None, f"Error processing video: {e}"
|
|
|
|
|
136 |
|
137 |
# Wrapper function with error handling
|
138 |
+
def process_wrapper(video, fps=0, seed=0, batch_size=16):
|
139 |
if video is None:
|
140 |
raise gr.Error("Please upload a video.")
|
141 |
try:
|
142 |
outputs = []
|
143 |
# Use video directly, since it's the file path
|
144 |
+
for output in process_video(video, fps, seed, batch_size):
|
145 |
outputs.append(output)
|
146 |
yield output
|
147 |
return outputs[-1]
|
|
|
150 |
|
151 |
# Custom CSS for styling (unchanged)
|
152 |
custom_css = """
|
153 |
+
.title-container {
|
154 |
+
text-align: center;
|
155 |
+
padding: 10px 0;
|
156 |
+
}
|
157 |
+
|
158 |
+
#title {
|
159 |
+
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
|
160 |
+
font-size: 36px;
|
161 |
+
font-weight: bold;
|
162 |
+
color: #000000;
|
163 |
+
padding: 10px;
|
164 |
+
border-radius: 10px;
|
165 |
+
display: inline-block;
|
166 |
+
background: linear-gradient(
|
167 |
+
135deg,
|
168 |
+
#e0f7fa, #e8f5e9, #fff9c4, #ffebee,
|
169 |
+
#f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
|
170 |
+
);
|
171 |
+
background-size: 400% 400%;
|
172 |
+
animation: gradient-animation 15s ease infinite;
|
173 |
+
}
|
174 |
+
|
175 |
+
@keyframes gradient-animation {
|
176 |
+
0% { background-position: 0% 50%; }
|
177 |
+
50% { background-position: 100% 50%; }
|
178 |
+
100% { background-position: 0% 50%; }
|
179 |
+
}
|
180 |
"""
|
181 |
|
182 |
# Gradio Interface
|
|
|
186 |
<div id="title">Video Depth Estimation</div>
|
187 |
</div>
|
188 |
''')
|
189 |
+
|
190 |
with gr.Row():
|
191 |
with gr.Column():
|
192 |
video_input = gr.Video(label="Upload Video", interactive=True)
|
193 |
fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
|
194 |
seed_slider = gr.Number(value=0, label="Seed")
|
195 |
+
batch_size_slider = gr.Slider(minimum=1, maximum=64, step=1, value=16, label="Batch Size")
|
196 |
btn = gr.Button("Process Video")
|
197 |
+
|
198 |
with gr.Column():
|
199 |
preview_image = gr.Image(label="Live Preview")
|
200 |
output_frames_zip = gr.File(label="Download Frame Sequence (ZIP)")
|
201 |
output_video = gr.File(label="Download Video (MP4)")
|
202 |
time_textbox = gr.Textbox(label="Status", interactive=False)
|
203 |
+
|
204 |
btn.click(
|
205 |
fn=process_wrapper,
|
206 |
+
inputs=[video_input, fps_slider, seed_slider, batch_size_slider],
|
207 |
outputs=[preview_image, output_frames_zip, output_video, time_textbox]
|
208 |
)
|
209 |
|