|
import gradio as gr |
|
import torch |
|
import os |
|
import tempfile |
|
import shutil |
|
import time |
|
import ffmpeg |
|
import numpy as np |
|
from PIL import Image |
|
import moviepy.editor as mp |
|
from infer import lotus |
|
import logging |
|
|
|
|
|
# Emit INFO-level (and above) log records for the whole process, and use a
# module-scoped logger for everything this app reports.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Inference device: the first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
|
|
|
|
|
def _fit_within_resolution(size, max_resolution):
    """Return a (width, height) no larger than max_resolution with the same aspect ratio.

    Sizes already within the bounding box are returned unchanged. When scaling
    is needed, both dimensions are floored to even numbers (minimum 2) because
    the libx264 / yuv420p encoding used for the output MP4 rejects odd frame
    sizes. This assumes the depth model keeps the input resolution (TODO confirm).
    Integer arithmetic is used so float rounding cannot push the binding
    dimension one pixel below its limit.
    """
    width, height = int(size[0]), int(size[1])
    max_w, max_h = int(max_resolution[0]), int(max_resolution[1])
    if width <= max_w and height <= max_h:
        return width, height
    if max_w * height <= max_h * width:
        # Width is the binding constraint: width/max_w >= height/max_h.
        new_w, new_h = max_w, height * max_w // width
    else:
        new_w, new_h = width * max_h // height, max_h
    return max(2, new_w - new_w % 2), max(2, new_h - new_h % 2)


def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
    """Open a video and normalize its size and frame rate.

    Args:
        video_path: path to a video file readable by moviepy.
        target_fps: frame rate to resample to; values <= 0 keep the source rate.
        max_resolution: (max_width, max_height) bounding box. Larger videos are
            scaled down to fit inside it while keeping their aspect ratio
            (scaling straight to max_resolution would stretch portrait or
            non-16:9 footage).

    Returns:
        The (possibly resized / re-timed) moviepy clip. The caller owns it and
        should call ``close()`` on it when finished.
    """
    video = mp.VideoFileClip(video_path)

    width, height = video.size
    if width > max_resolution[0] or height > max_resolution[1]:
        video = video.resize(newsize=_fit_within_resolution((width, height), max_resolution))

    if target_fps > 0:
        video = video.set_fps(target_fps)

    return video
|
|
|
|
|
def process_frame(image, seed=0):
    """Run a single frame through the Lotus depth model.

    Args:
        image: the frame to process (callers in this file pass a PIL image).
        seed: RNG seed applied to torch and numpy before inference and also
            forwarded to ``lotus``. It is coerced to ``int`` because Gradio's
            Number component delivers floats by default, and
            ``np.random.seed`` raises TypeError for float seeds.

    Returns:
        ``np.array`` of the depth output (the second value returned by
        ``lotus``; presumably a PIL image — TODO confirm), or ``None`` if any
        step failed. Failures are logged with their traceback rather than raised.
    """
    try:
        seed = int(seed)

        # Reseed per frame so every frame is processed under identical RNG state.
        torch.manual_seed(seed)
        np.random.seed(seed)

        _, output_d = lotus(image, 'depth', seed, device)
        return np.array(output_d)

    except Exception as e:
        # Best-effort per frame: the caller skips frames that return None.
        logger.exception(f"Error processing frame: {e}")
        return None
|
|
|
|
|
def _zip_depth_frames(frames_dir, output_dir, stamp):
    """Archive the contents of frames_dir as output_dir/depth_frames_<stamp>.zip and return its path."""
    base_name = os.path.join(output_dir, f"depth_frames_{stamp}")
    return shutil.make_archive(base_name, 'zip', frames_dir)


def _encode_depth_video(frames_dir, output_path, fps):
    """Encode frames_dir/frame_%06d.png into an H.264 MP4 at ``fps``.

    Returns output_path on success, or None if ffmpeg failed (the failure is
    logged together with ffmpeg's stderr).
    """
    try:
        (
            ffmpeg
            .input(os.path.join(frames_dir, 'frame_%06d.png'), pattern_type='sequence', framerate=fps)
            .output(output_path, vcodec='libx264', pix_fmt='yuv420p', crf=17)
            # Capture output so ffmpeg.Error.stderr actually carries the diagnostics.
            .run(overwrite_output=True, capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        details = e.stderr.decode(errors='replace') if e.stderr else str(e)
        logger.error(f"Error creating video: {details}")
        return None
    logger.info("MP4 video created successfully!")
    return output_path


def process_video(video_path, fps=0, seed=0, batch_size=16):
    """Estimate depth for every frame of a video and package the results.

    Generator used as a Gradio event handler. Every yielded value is a
    4-tuple ``(preview, zip_path, mp4_path, status)`` matching the UI outputs:

    * while processing: ``(latest_depth_map, None, None, progress_message)``,
      emitted roughly every 10% of frames;
    * on success: ``(None, zip_path, mp4_path_or_None, summary)``. The paths
      point into a fresh ``tempfile.mkdtemp`` directory so the files outlive
      this generator and Gradio can still serve them;
    * on failure: ``(None, None, None, error_message)``.

    Args:
        video_path: path of the uploaded video file.
        fps: frame rate to resample to; 0 keeps the source frame rate.
        seed: RNG seed passed to ``process_frame`` for every frame.
        batch_size: number of frames run through ``process_frame`` before
            their results are written; inference itself is per frame.
    """
    start_time = time.time()
    video = None
    try:
        # UI sliders may deliver floats; range() needs a positive int step.
        batch_size = max(1, int(batch_size))

        video = preprocess_video(video_path, target_fps=fps)
        if fps == 0:
            fps = video.fps

        frames = list(video.iter_frames(fps=video.fps))
        # Every frame is now decoded in memory, so release the clip's reader.
        video.close()
        video = None

        total_frames = len(frames)
        logger.info(f"Processing {total_frames} frames at {fps} FPS...")
        if total_frames == 0:
            yield None, None, None, "Error processing video: no frames could be decoded."
            return

        progress_step = max(1, total_frames // 10)

        with tempfile.TemporaryDirectory() as temp_dir:
            frames_dir = os.path.join(temp_dir, "frames")
            os.makedirs(frames_dir, exist_ok=True)

            # PNGs are numbered contiguously rather than by source index: ffmpeg's
            # image-sequence input stops at the first missing number, so a gap
            # left by a failed frame would silently truncate the MP4.
            written = 0
            for batch_start in range(0, total_frames, batch_size):
                batch = frames[batch_start:batch_start + batch_size]
                depth_maps = [process_frame(Image.fromarray(frame), seed) for frame in batch]

                for offset, depth_map in enumerate(depth_maps):
                    frame_index = batch_start + offset
                    if depth_map is None:
                        logger.error(f"Error processing frame {frame_index}")
                        continue

                    frame_path = os.path.join(frames_dir, f"frame_{written:06d}.png")
                    Image.fromarray(depth_map).save(frame_path)
                    written += 1

                    if frame_index % progress_step == 0:
                        done = frame_index + 1
                        elapsed_time = time.time() - start_time
                        progress = done / total_frames * 100
                        yield depth_map, None, None, f"Processed {done}/{total_frames} frames... ({progress:.2f}%) Elapsed: {elapsed_time:.2f}s"

            if written == 0:
                yield None, None, None, "Error processing video: every frame failed depth estimation."
                return

            logger.info("Creating output files...")
            # Outputs live outside temp_dir so they survive its cleanup below.
            output_dir = tempfile.mkdtemp(prefix="depth_output_")
            stamp = int(time.time())
            zip_path = _zip_depth_frames(frames_dir, output_dir, stamp)
            video_output_path = _encode_depth_video(
                frames_dir, os.path.join(output_dir, f"depth_video_{stamp}.mp4"), fps
            )

        total_time = time.time() - start_time
        logger.info("Processing complete!")

        status = f"Processing complete! Total time: {total_time:.2f} seconds"
        if written < total_frames:
            status += f" ({total_frames - written} frame(s) failed and were skipped)"
        if video_output_path is None:
            # Still hand back the ZIP instead of discarding all the work.
            status += " (MP4 encoding failed; the frame ZIP is still available)"
        yield None, zip_path, video_output_path, status

    except Exception as e:
        logger.exception(f"Error: {e}")
        yield None, None, None, f"Error processing video: {e}"
    finally:
        # Release the clip if an error happened before it was closed above.
        if video is not None:
            video.close()
|
|
|
|
|
def process_wrapper(video, fps=0, seed=0, batch_size=16):
    """Gradio click handler for the "Process Video" button.

    Checks that a video was uploaded, then re-yields every update produced by
    ``process_video`` so the UI can stream progress.

    Raises:
        gr.Error: if no video was uploaded, or if processing raises.
    """
    if video is None:
        raise gr.Error("Please upload a video.")
    try:
        # Stream updates straight through; buffering them would keep every
        # preview depth map alive until the run ends.
        yield from process_video(video, fps, seed, batch_size)
    except Exception as e:
        raise gr.Error(f"Error processing video: {str(e)}") from e
|
|
|
|
|
# Stylesheet passed to gr.Blocks for the page header: .title-container centers
# the title and adds vertical padding; #title is a bold 36px label whose pastel
# linear-gradient background is panned back and forth by the gradient-animation
# keyframes (15s cycle, repeating indefinitely).
custom_css = """

.title-container {

text-align: center;

padding: 10px 0;

}



#title {

font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;

font-size: 36px;

font-weight: bold;

color: #000000;

padding: 10px;

border-radius: 10px;

display: inline-block;

background: linear-gradient(

135deg,

#e0f7fa, #e8f5e9, #fff9c4, #ffebee,

#f3e5f5, #e1f5fe, #fff3e0, #e8eaf6

);

background-size: 400% 400%;

animation: gradient-animation 15s ease infinite;

}



@keyframes gradient-animation {

0% { background-position: 0% 50%; }

50% { background-position: 100% 50%; }

100% { background-position: 0% 50%; }

}

"""
|
|
|
|
|
# Build the UI: inputs (video, FPS, seed, batch size) on the left, and the live
# preview, download files and status text on the right. The button streams
# updates from process_wrapper into those four outputs.
with gr.Blocks(css=custom_css) as demo:
    gr.HTML('''

<div class="title-container">

<div id="title">Video Depth Estimation</div>

</div>

''')

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Upload Video", interactive=True)
            fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
            # precision=0 makes Gradio deliver an int seed; the default float
            # value is rejected by np.random.seed.
            seed_slider = gr.Number(value=0, precision=0, label="Seed")
            batch_size_slider = gr.Slider(minimum=1, maximum=64, step=1, value=16, label="Batch Size")
            btn = gr.Button("Process Video")

        with gr.Column():
            preview_image = gr.Image(label="Live Preview")
            output_frames_zip = gr.File(label="Download Frame Sequence (ZIP)")
            output_video = gr.File(label="Download Video (MP4)")
            time_textbox = gr.Textbox(label="Status", interactive=False)

    # Event listeners must be registered inside the Blocks context.
    btn.click(
        fn=process_wrapper,
        inputs=[video_input, fps_slider, seed_slider, batch_size_slider],
        outputs=[preview_image, output_frames_zip, output_video, time_textbox],
    )

# Enable the request queue (used for streaming generator handlers).
demo.queue()
|
|
|
def _launch_app():
    """Start the Gradio app defined above with debug=True."""
    demo.launch(debug=True)


# Only launch when run as a script, not when imported.
if __name__ == "__main__":
    _launch_app()