|
import gradio as gr |
|
import torch |
|
import os |
|
import tempfile |
|
import shutil |
|
import time |
|
import ffmpeg |
|
import numpy as np |
|
from PIL import Image |
|
from concurrent.futures import ThreadPoolExecutor |
|
import moviepy.editor as mp |
|
from infer import lotus |
|
import spaces |
|
|
|
|
|
# Inference device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
    """Load a video, shrink it to fit ``max_resolution``, and set its frame rate.

    Args:
        video_path: Path of the input video, passed to ``mp.VideoFileClip``.
        target_fps: Frame rate applied to the returned clip via ``set_fps``.
        max_resolution: ``(width, height)`` bounding box. A clip that exceeds
            it in either dimension is scaled down uniformly to fit inside it;
            smaller clips are left at their original size.

    Returns:
        The (possibly resized) clip with its fps set to ``target_fps``. The
        caller owns the clip and should close it when done.
    """
    video = mp.VideoFileClip(video_path)

    width, height = video.size
    max_width, max_height = max_resolution

    if width > max_width or height > max_height:
        # Use one scale factor for both axes so the aspect ratio is preserved;
        # resizing straight to ``max_resolution`` would stretch any clip that
        # does not already have exactly that shape.
        scale = min(max_width / width, max_height / height)
        # Round down to even dimensions (never below 2 px) so the result stays
        # inside the bounding box and is friendly to chroma-subsampled codecs.
        new_size = (
            max(2, int(width * scale) // 2 * 2),
            max(2, int(height * scale) // 2 * 2),
        )
        video = video.resize(newsize=new_size)

    return video.set_fps(target_fps)
|
|
|
def process_frame(frame, seed=0):
    """Run a single frame through the depth model and return its depth map.

    The frame is written to a temporary PNG because ``lotus`` is invoked with
    a file path. The temporary file is removed whether or not inference
    succeeds.

    Args:
        frame: Frame pixel data accepted by ``Image.fromarray``.
        seed: Seed value forwarded to ``lotus``.

    Returns:
        The second value returned by ``lotus`` converted with ``np.array``,
        or ``None`` if any step fails (the error is printed, not raised).
    """
    tmp_path = None
    try:
        image = Image.fromarray(frame)

        # delete=False: the file must outlive the handle so it can be passed
        # to lotus by path. Only the name is reserved here; the image is
        # written after the handle is closed.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            tmp_path = tmp.name
        image.save(tmp_path)

        _, output_d = lotus(tmp_path, 'depth', seed, device)

        return np.array(output_d)

    except Exception as e:
        # Best-effort per frame: report and let the caller skip this frame.
        print(f"Error processing frame: {e}")
        return None

    finally:
        # Always remove the temporary PNG, including when saving or inference
        # raised; otherwise every failed frame leaks a file.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as e:
                print(f"Error removing temporary frame file: {e}")
|
|
|
@spaces.GPU
def process_video(video_path, fps=0, seed=0, max_workers=32):
    """Generate depth maps for every frame of a video and package the results.

    Frames are submitted to a thread pool in batches of 50 and each one is
    processed by ``process_frame``. Successful depth maps are saved as PNGs,
    zipped, and encoded into an MP4 with ffmpeg.

    Args:
        video_path: Path of the input video.
        fps: Output frame rate; ``0`` means "use the preprocessed clip's fps".
        seed: Seed forwarded to ``process_frame``.
        max_workers: Thread pool size, also passed to ffmpeg as ``threads``.

    Yields:
        4-tuples ``(preview, zip_path, video_path, status)``:
        progress updates carry a preview depth map and a status string;
        the final update carries the ZIP path, the MP4 path (``None`` if
        encoding failed), and a completion message. On failure a single
        ``(None, None, None, error_message)`` tuple is yielded instead.
    """
    temp_dir = None
    video = None
    start_time = time.time()

    try:
        video = preprocess_video(video_path)

        if fps == 0:
            fps = video.fps

        frames = list(video.iter_frames(fps=fps))
        total_frames = len(frames)

        print(f"Processing {total_frames} frames at {fps} FPS...")

        temp_dir = tempfile.mkdtemp()
        frames_dir = os.path.join(temp_dir, "frames")
        os.makedirs(frames_dir, exist_ok=True)

        batch_size = 50
        # Number of depth maps written so far. Output files are numbered with
        # this counter (not the source frame index) so the sequence stays
        # contiguous even when some frames fail; a gap would make ffmpeg's
        # numbered-sequence input stop early.
        frames_written = 0

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for batch_start in range(0, total_frames, batch_size):
                batch = frames[batch_start:batch_start + batch_size]
                futures = [
                    executor.submit(process_frame, frame, seed)
                    for frame in batch
                ]

                for offset, future in enumerate(futures):
                    frame_number = batch_start + offset + 1
                    try:
                        result = future.result()
                        if result is None:
                            # process_frame already reported the failure.
                            continue

                        frame_path = os.path.join(
                            frames_dir, f"frame_{frames_written:06d}.png"
                        )
                        Image.fromarray(result).save(frame_path)
                        frames_written += 1

                        # Push a live preview every 10th source frame.
                        if frame_number % 10 == 0:
                            elapsed_time = time.time() - start_time
                            yield result, None, None, f"Processed {frame_number}/{total_frames} frames... Elapsed: {elapsed_time:.2f}s"

                    except Exception as e:
                        print(f"Error processing frame {frame_number}: {e}")

        if frames_written == 0:
            # Nothing to zip or encode; report it instead of claiming success.
            raise RuntimeError("No frames could be processed.")

        print("Creating output files...")

        output_dir = os.path.join(os.path.dirname(video_path), "output")
        os.makedirs(output_dir, exist_ok=True)

        timestamp = int(time.time())

        zip_base = os.path.join(output_dir, f"depth_frames_{timestamp}")
        zip_path = zip_base + ".zip"
        shutil.make_archive(zip_base, 'zip', frames_dir)

        # Separate name for the output so the input path parameter is not
        # silently overwritten.
        output_video_path = os.path.join(output_dir, f"depth_video_{timestamp}.mp4")

        try:
            stream = ffmpeg.input(
                os.path.join(frames_dir, 'frame_%06d.png'),
                pattern_type='sequence',
                start_number=0,
                framerate=fps
            )

            stream = ffmpeg.output(
                stream,
                output_video_path,
                vcodec='libx264',
                pix_fmt='yuv420p',
                crf=17,
                threads=max_workers
            )

            ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
            print("MP4 video created successfully!")

        except ffmpeg.Error as e:
            # Encoding is best-effort: the ZIP of frames is still returned.
            print(f"Error creating video: {e.stderr.decode() if e.stderr else str(e)}")
            output_video_path = None

        print("Processing complete!")
        yield None, zip_path, output_video_path, f"Processing complete! Total time: {time.time() - start_time:.2f} seconds"

    except Exception as e:
        print(f"Error: {e}")
        yield None, None, None, f"Error processing video: {e}"

    finally:
        # Release the clip's underlying reader; it was previously never closed.
        if video is not None:
            try:
                video.close()
            except Exception as e:
                print(f"Error closing video: {e}")

        if temp_dir and os.path.exists(temp_dir):
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                print(f"Error cleaning up temp directory: {e}")
|
|
|
def process_wrapper(video, fps=0, seed=0, max_workers=32):
    """Validate the uploaded video and stream updates from ``process_video``.

    Args:
        video: Path of the uploaded video, or ``None`` if nothing was uploaded.
        fps: Output frame rate forwarded to ``process_video``.
        seed: Seed forwarded to ``process_video``.
        max_workers: Worker count forwarded to ``process_video``.

    Yields:
        Each tuple produced by ``process_video``, unchanged.

    Returns:
        The last yielded tuple (as the generator's return value), or ``None``
        if ``process_video`` yielded nothing.

    Raises:
        gr.Error: If no video was provided, or if ``process_video`` raises.
    """
    if video is None:
        raise gr.Error("Please upload a video.")

    # Track only the most recent update; keeping every yielded tuple would
    # hold all preview arrays in memory for the duration of the run.
    last_output = None
    try:
        for output in process_video(video, fps, seed, max_workers):
            last_output = output
            yield output
    except Exception as e:
        raise gr.Error(f"Error processing video: {str(e)}") from e

    return last_output
|
|
|
|
|
# Stylesheet passed to gr.Blocks: centers the title and gives it an animated
# pastel gradient background.
custom_css = """
.title-container {
    text-align: center;
    padding: 10px 0;
}

#title {
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
    font-size: 36px;
    font-weight: bold;
    color: #000000;
    padding: 10px;
    border-radius: 10px;
    display: inline-block;
    background: linear-gradient(
        135deg,
        #e0f7fa, #e8f5e9, #fff9c4, #ffebee,
        #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
    );
    background-size: 400% 400%;
    animation: gradient-animation 15s ease infinite;
}

@keyframes gradient-animation {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}
"""
|
|
|
|
|
# Build the Gradio UI: inputs on the left, live preview and downloads on the right.
with gr.Blocks(css=custom_css) as demo:
    gr.HTML('''
    <div class="title-container">
        <div id="title">Video Depth Estimation</div>
    </div>
    ''')

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Upload Video", interactive=True, show_label=True)
            # An FPS of 0 tells process_video to use the preprocessed clip's frame rate.
            fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS")
            seed_slider = gr.Slider(minimum=0, maximum=999999999, step=1, value=0, label="Seed")
            max_workers_slider = gr.Slider(minimum=1, maximum=32, step=1, value=32, label="Max Workers")
            btn = gr.Button("Process Video", elem_id="submit-button")

        with gr.Column():
            preview_image = gr.Image(label="Live Preview", show_label=True)
            output_frames_zip = gr.File(label="Download Frame Sequence (ZIP)")
            output_video = gr.File(label="Download Video (MP4)")
            time_textbox = gr.Textbox(label="Status", interactive=False)

    # The output order must match the 4-tuples yielded by process_wrapper:
    # (preview, zip path, video path, status text).
    btn.click(
        fn=process_wrapper,
        inputs=[video_input, fps_slider, seed_slider, max_workers_slider],
        outputs=[preview_image, output_frames_zip, output_video, time_textbox],
    )

demo.queue()

if __name__ == "__main__":
    demo.launch(debug=True)