import gradio as gr
import torch
import spaces
import moviepy.editor as mp
from PIL import Image
import numpy as np
import tempfile
import time
import os
import shutil
import ffmpeg
from concurrent.futures import ThreadPoolExecutor
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts
from infer import lotus  # Import the depth model inference function

# Custom Theme Definition
class WhiteTheme(Base):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.orange,
        font: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-monospace",
            "system-ui",
            "monospace",
        )
    ):
        super().__init__(
            primary_hue=primary_hue,
            font=font,
            font_mono=font_mono,
        )
        
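        # Force an all-white, light-only appearance (the *_dark variants are overridden as well)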
        self.set(
            background_fill_primary="*primary_50",
            background_fill_secondary="white",
            border_color_primary="*primary_300",
            body_background_fill="white",
            body_background_fill_dark="white",
            block_background_fill="white",
            block_background_fill_dark="white",
            panel_background_fill="white",
            panel_background_fill_dark="white",
            body_text_color="black",
            body_text_color_dark="black",
            block_label_text_color="black",
            block_label_text_color_dark="black",
            block_border_color="white",
            panel_border_color="white",
            input_border_color="lightgray",
            input_background_fill="white",
            input_background_fill_dark="white",
            shadow_drop="none"
        )

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def process_frame(frame, seed=0):
    """
    Process a single frame through the depth model.
    Returns the discriminative depth map.
    """
    try:
        # Convert frame to PIL Image
        image = Image.fromarray(frame)
        
        # Save a temporary image (lotus requires a file path); close the handle first so
        # the write also works on platforms that lock open temp files
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            tmp_path = tmp.name
        image.save(tmp_path)

        try:
            # Process through the lotus model; output_d is the discriminative depth map
            _, output_d = lotus(tmp_path, 'depth', seed, device)
        finally:
            # Clean up temp file
            os.unlink(tmp_path)
            
        # Convert depth output to numpy array
        depth_array = np.array(output_d)
        return depth_array
        
    except Exception as e:
        print(f"Error processing frame: {e}")
        return None

@spaces.GPU
def process_video(video_path, fps=0, seed=0, max_workers=6):
    """
    Process video to create depth map sequence and video.
    Maintains original resolution and framerate if fps=0.
    """
    temp_dir = None
    try:
        start_time = time.time()
        video = mp.VideoFileClip(video_path)
        
        # Use original video FPS if not specified
        if fps == 0:
            fps = video.fps
                
        frames = list(video.iter_frames(fps=fps))
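        # NOTE: every frame is decoded into memory up front, so long or high-resolution clips can be RAM-heavy.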
        total_frames = len(frames)
        
        print(f"Processing {total_frames} frames at {fps} FPS...")
        
        # Create temporary directory for frame sequence
        temp_dir = tempfile.mkdtemp()
        frames_dir = os.path.join(temp_dir, "frames")
        os.makedirs(frames_dir, exist_ok=True)
        
        # Process frames with parallel execution
        processed_frames = []
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(process_frame, frame, seed) for frame in frames]
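            # Iterate futures in submission order so the saved frames keep their original sequence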
            for i, future in enumerate(futures):
                try:
                    result = future.result()
                    if result is not None:
                        # Save frame
                        frame_path = os.path.join(frames_dir, f"frame_{i:06d}.png")
                        Image.fromarray(result).save(frame_path)
                        
                        # Collect processed frame for preview
                        processed_frames.append(result)
                        
                        # Update preview
                        elapsed_time = time.time() - start_time
                        yield processed_frames[-1], None, None, f"Processing frame {i+1}/{total_frames}... Elapsed time: {elapsed_time:.2f} seconds"
                        
                    if (i + 1) % 10 == 0:
                        print(f"Processed {i+1}/{total_frames} frames")
                except Exception as e:
                    print(f"Error processing frame {i+1}: {e}")
        
        print("Creating output files...")
        # Create output directory
        output_dir = os.path.join(os.path.dirname(video_path), "output")
        os.makedirs(output_dir, exist_ok=True)
        
        # Create ZIP of frame sequence
        zip_filename = f"depth_frames_{int(time.time())}.zip"
        zip_path = os.path.join(output_dir, zip_filename)
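        # shutil.make_archive appends ".zip" itself, so pass the base path without the extension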
        shutil.make_archive(zip_path[:-4], 'zip', frames_dir)
        
        # Create MP4 video
        print("Creating MP4 video...")
        video_filename = f"depth_video_{int(time.time())}.mp4"
        output_video_path = os.path.join(output_dir, video_filename)
        
        try:
            # FFmpeg settings for high-quality MP4
            stream = ffmpeg.input(
                os.path.join(frames_dir, 'frame_%06d.png'),
                pattern_type='sequence',
                framerate=fps
            )
            
            stream = ffmpeg.output(
                stream,
                output_video_path,
                vcodec='libx264',
                pix_fmt='yuv420p',
                crf=17,  # High quality
                threads=max_workers
            )
            
            ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
            print("MP4 video created successfully!")
            
        except ffmpeg.Error as e:
            print(f"Error creating video: {e.stderr.decode() if e.stderr else str(e)}")
            output_video_path = None
    
        print("Processing complete!")
        yield None, zip_path, output_video_path, f"Processing complete! Total time: {time.time() - start_time:.2f} seconds"
        
    except Exception as e:
        print(f"Error: {e}")
        yield None, None, None, f"Error processing video: {e}"
    finally:
        if temp_dir and os.path.exists(temp_dir):
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                print(f"Error cleaning up temp directory: {e}")

def process_wrapper(video, fps=0, seed=0, max_workers=6):
    if video is None:
        raise gr.Error("Please upload a video.")
    try:
        # Re-yield every intermediate update so Gradio streams progress to the UI
        yield from process_video(video, fps, seed, max_workers)
    except Exception as e:
        raise gr.Error(f"Error processing video: {str(e)}")

# Custom CSS for styling
custom_css = """
    .title-container {
        text-align: center;
        padding: 10px 0;
    }
    
    #title {
        font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
        font-size: 36px;
        font-weight: bold;
        color: #000000;
        padding: 10px;
        border-radius: 10px;
        display: inline-block;
        background: linear-gradient(
            135deg,
            #e0f7fa, #e8f5e9, #fff9c4, #ffebee,
            #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
        );
        background-size: 400% 400%;
        animation: gradient-animation 15s ease infinite;
    }
    
    @keyframes gradient-animation {
        0% { background-position: 0% 50%; }
        50% { background-position: 100% 50%; }
        100% { background-position: 0% 50%; }
    }
"""

# Gradio Interface
with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
    gr.HTML('''
        <div class="title-container">
            <div id="title">Video Depth Estimation</div>
        </div>
    ''')
    
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(
                label="Upload Video",
                interactive=True,
                show_label=True,
                height=360,
                width=640
            )
            with gr.Row():
                fps_slider = gr.Slider(
                    minimum=0,
                    maximum=60,
                    step=1,
                    value=0,
                    label="Output FPS (0 will inherit the original fps value)",
                )
                seed_slider = gr.Slider(
                    minimum=0,
                    maximum=999999999,
                    step=1,
                    value=0,
                    label="Seed",
                )
                max_workers_slider = gr.Slider(
                    minimum=1,
                    maximum=32,
                    step=1,
                    value=6,
                    label="Max Workers",
                    info="Determines how many frames to process in parallel"
                )
            btn = gr.Button("Process Video", elem_id="submit-button")
        
        with gr.Column():
            preview_image = gr.Image(label="Live Preview", show_label=True)
            output_frames_zip = gr.File(label="Download Frame Sequence (ZIP)")
            output_video = gr.File(label="Download Video (MP4)")
            time_textbox = gr.Textbox(label="Status", interactive=False)
            
            gr.Markdown("""
            ### Output Information
            - High-quality MP4 video output
            - Original resolution and framerate are maintained
            - Frame sequence provided for maximum compatibility
            """)

    btn.click(
        fn=process_wrapper,
        inputs=[video_input, fps_slider, seed_slider, max_workers_slider],
        outputs=[preview_image, output_frames_zip, output_video, time_textbox]
    )

    demo.queue()

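    # Companion gr.Interface wrapping the same processing function for API-style access;
    # only the `demo` Blocks app is actually launched below.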
    api = gr.Interface(
        fn=process_wrapper,
        inputs=[
            gr.Video(label="Upload Video"),
            gr.Number(label="FPS", value=0),
            gr.Number(label="Seed", value=0),
            gr.Number(label="Max Workers", value=6)
        ],
        outputs=[
            gr.Image(label="Preview"),
            gr.File(label="Frame Sequence"),
            gr.File(label="Video"),
            gr.Textbox(label="Status")
        ],
        title="Video Depth Estimation API",
        description="Generate depth maps from videos",
        api_name="/process_video"
    )

if __name__ == "__main__":
    demo.launch(debug=True, show_error=True, share=False, server_name="0.0.0.0", server_port=7860)