ghostsInTheMachine commited on
Commit
48a87fd
1 Parent(s): 3d7224a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -304
app.py CHANGED
@@ -1,339 +1,115 @@
1
  import gradio as gr
2
  import torch
3
- import spaces
4
- import moviepy.editor as mp
5
- from PIL import Image
6
- import numpy as np
7
- import tempfile
8
- import time
9
  import os
 
 
 
10
  import shutil
11
- import ffmpeg
12
  from concurrent.futures import ThreadPoolExecutor
13
- from gradio.themes.base import Base
14
- from gradio.themes.utils import colors, fonts
15
- from infer import lotus # Import the depth model inference function
16
-
17
- # Custom Theme Definition
18
- class WhiteTheme(Base):
19
- def __init__(
20
- self,
21
- *,
22
- primary_hue: colors.Color | str = colors.orange,
23
- font: fonts.Font | str | tuple[fonts.Font | str, ...] = (
24
- fonts.GoogleFont("Inter"),
25
- "ui-sans-serif",
26
- "system-ui",
27
- "sans-serif",
28
- ),
29
- font_mono: fonts.Font | str | tuple[fonts.Font | str, ...] = (
30
- fonts.GoogleFont("Inter"),
31
- "ui-monospace",
32
- "system-ui",
33
- "monospace",
34
- )
35
- ):
36
- super().__init__(
37
- primary_hue=primary_hue,
38
- font=font,
39
- font_mono=font_mono,
40
- )
41
-
42
- self.set(
43
- background_fill_primary="*primary_50",
44
- background_fill_secondary="white",
45
- border_color_primary="*primary_300",
46
- body_background_fill="white",
47
- body_background_fill_dark="white",
48
- block_background_fill="white",
49
- block_background_fill_dark="white",
50
- panel_background_fill="white",
51
- panel_background_fill_dark="white",
52
- body_text_color="black",
53
- body_text_color_dark="black",
54
- block_label_text_color="black",
55
- block_label_text_color_dark="black",
56
- block_border_color="white",
57
- panel_border_color="white",
58
- input_border_color="lightgray",
59
- input_background_fill="white",
60
- input_background_fill_dark="white",
61
- shadow_drop="none"
62
- )
63
 
64
  # Set device
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
 
67
- # Add the preprocess_video function to limit video resolution and frame rate
68
- def preprocess_video(video_path, target_fps=24, max_resolution=(640, 360)):
69
- """Preprocess the video to reduce its resolution and frame rate."""
70
- video = mp.VideoFileClip(video_path)
 
 
 
 
 
71
 
72
- # Resize video if it's larger than the target resolution
73
- if video.size[0] > max_resolution[0] or video.size[1] > max_resolution[1]:
74
- video = video.resize(newsize=max_resolution)
75
 
76
- # Limit FPS
77
- video = video.set_fps(target_fps)
78
 
79
- return video
80
 
81
- def process_frame(frame, seed=0, start_time=None):
82
- """
83
- Process a single frame through the depth model.
84
- Returns the discriminative depth map.
85
- """
86
- try:
87
- # Convert frame to PIL Image
88
- image = Image.fromarray(frame)
89
-
90
- # Save temporary image (lotus requires a file path)
91
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
92
- image.save(tmp.name)
93
-
94
- # Process through lotus model
95
- _, output_d = lotus(tmp.name, 'depth', seed, device)
96
-
97
- # Clean up temp file
98
- os.unlink(tmp.name)
99
-
100
- # Convert depth output to numpy array
101
- depth_array = np.array(output_d)
102
- return depth_array
103
-
104
- except Exception as e:
105
- print(f"Error processing frame: {e}")
106
- return None
107
 
108
- @spaces.GPU
109
- def process_video(video_path, fps=0, seed=0, max_workers=2):
110
  """
111
- Process video to create depth map sequence and video.
112
- Maintains original resolution and framerate if fps=0.
113
  """
114
- temp_dir = None
115
- try:
116
- # Initialize start_time here for use in process_frame
117
- start_time = time.time()
118
-
119
- # Preprocess the video
120
- video = preprocess_video(video_path)
121
-
122
- # Use original video FPS if not specified
123
- if fps == 0:
124
- fps = video.fps
125
-
126
- frames = list(video.iter_frames(fps=fps))
127
- total_frames = len(frames)
128
-
129
- print(f"Processing {total_frames} frames at {fps} FPS...")
130
-
131
- # Create temporary directory for frame sequence
132
- temp_dir = tempfile.mkdtemp()
133
- frames_dir = os.path.join(temp_dir, "frames")
134
- os.makedirs(frames_dir, exist_ok=True)
135
-
136
- # Process frames in batches of 10
137
- processed_frames = []
138
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
139
- for i in range(0, total_frames, 10): # Process 10 frames at a time
140
- futures = [executor.submit(process_frame, frames[j], seed, start_time) for j in range(i, min(i + 10, total_frames))]
141
- for j, future in enumerate(futures):
142
- try:
143
- result = future.result()
144
- if result is not None:
145
- # Save frame
146
- frame_path = os.path.join(frames_dir, f"frame_{i+j:06d}.png")
147
- Image.fromarray(result).save(frame_path)
148
-
149
- # Collect processed frame for preview
150
- processed_frames.append(result)
151
-
152
- # Update preview
153
- elapsed_time = time.time() - start_time
154
- yield processed_frames[-1], None, None, f"Processing frame {i+j+1}/{total_frames}... Elapsed time: {elapsed_time:.2f} seconds"
155
-
156
- if (i + j + 1) % 10 == 0:
157
- print(f"Processed {i + j + 1}/{total_frames} frames")
158
- except Exception as e:
159
- print(f"Error processing frame {i + j + 1}: {e}")
160
-
161
- print("Creating output files...")
162
- # Create output directory
163
- output_dir = os.path.join(os.path.dirname(video_path), "output")
164
- os.makedirs(output_dir, exist_ok=True)
165
 
166
- # Create ZIP of frame sequence
167
- zip_filename = f"depth_frames_{int(time.time())}.zip"
168
- zip_path = os.path.join(output_dir, zip_filename)
169
- shutil.make_archive(zip_path[:-4], 'zip', frames_dir)
170
 
171
- # Create MP4 video
172
- print("Creating MP4 video...")
173
- video_filename = f"depth_video_{int(time.time())}.mp4"
174
- video_path = os.path.join(output_dir, video_filename)
175
 
176
- try:
177
- # FFmpeg settings for high-quality MP4
178
- stream = ffmpeg.input(
179
- os.path.join(frames_dir, 'frame_%06d.png'),
180
- pattern_type='sequence',
181
- framerate=fps
182
- )
183
-
184
- stream = ffmpeg.output(
185
- stream,
186
- video_path,
187
- vcodec='libx264',
188
- pix_fmt='yuv420p',
189
- crf=17, # High quality
190
- threads=max_workers
191
- )
192
-
193
- ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
194
- print("MP4 video created successfully!")
195
-
196
- except ffmpeg.Error as e:
197
- print(f"Error creating video: {e.stderr.decode() if e.stderr else str(e)}")
198
- video_path = None
199
 
200
- print("Processing complete!")
201
- yield None, zip_path, video_path, f"Processing complete! Total time: {time.time() - start_time:.2f} seconds"
202
-
203
- except Exception as e:
204
- print(f"Error: {e}")
205
- yield None, None, None, f"Error processing video: {e}"
206
- finally:
207
- if temp_dir and os.path.exists(temp_dir):
208
- try:
209
- shutil.rmtree(temp_dir)
210
- except Exception as e:
211
- print(f"Error cleaning up temp directory: {e}")
212
-
213
- def process_wrapper(video, fps=0, seed=0, max_workers=6):
214
- if video is None:
215
- raise gr.Error("Please upload a video.")
216
- try:
217
- outputs = []
218
- for output in process_video(video, fps, seed, max_workers):
219
- outputs.append(output)
220
- yield output
221
- return outputs[-1]
222
- except Exception as e:
223
- raise gr.Error(f"Error processing video: {str(e)}")
224
-
225
- # Custom CSS for styling
226
- custom_css = """
227
- .title-container {
228
- text-align: center;
229
- padding: 10px 0;
230
- }
231
 
232
- #title {
233
- font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
234
- font-size: 36px;
235
- font-weight: bold;
236
- color: #000000;
237
- padding: 10px;
238
- border-radius: 10px;
239
- display: inline-block;
240
- background: linear-gradient(
241
- 135deg,
242
- #e0f7fa, #e8f5e9, #fff9c4, #ffebee,
243
- #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
244
- );
245
- background-size: 400% 400%;
246
- animation: gradient-animation 15s ease infinite;
247
- }
248
 
249
- @keyframes gradient-animation {
250
- 0% { background-position: 0% 50%; }
251
- 50% { background-position: 100% 50%; }
252
- 100% { background-position: 0% 50%; }
253
- }
254
- """
255
 
256
  # Gradio Interface
257
- with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
258
- gr.HTML('''
259
- <div class="title-container">
260
- <div id="title">Video Depth Estimation</div>
261
- </div>
262
- ''')
263
 
264
  with gr.Row():
265
  with gr.Column():
266
  video_input = gr.Video(
267
  label="Upload Video",
268
  interactive=True,
269
- show_label=True,
270
- height=360,
271
- width=640
 
 
 
272
  )
273
- with gr.Row():
274
- fps_slider = gr.Slider(
275
- minimum=0,
276
- maximum=60,
277
- step=1,
278
- value=0,
279
- label="Output FPS (0 will inherit the original fps value)",
280
- )
281
- seed_slider = gr.Slider(
282
- minimum=0,
283
- maximum=999999999,
284
- step=1,
285
- value=0,
286
- label="Seed",
287
- )
288
- max_workers_slider = gr.Slider(
289
- minimum=1,
290
- maximum=32,
291
- step=1,
292
- value=6,
293
- label="Max Workers",
294
- info="Determines how many frames to process in parallel"
295
- )
296
- btn = gr.Button("Process Video", elem_id="submit-button")
297
 
298
  with gr.Column():
299
- preview_image = gr.Image(label="Live Preview", show_label=True)
300
- output_frames_zip = gr.File(label="Download Frame Sequence (ZIP)")
301
- output_video = gr.File(label="Download Video (MP4)")
302
- time_textbox = gr.Textbox(label="Status", interactive=False)
303
-
304
- gr.Markdown("""
305
- ### Output Information
306
- - High-quality MP4 video output
307
- - Original resolution and framerate are maintained
308
- - Frame sequence provided for maximum compatibility
309
- """)
310
-
311
- btn.click(
312
- fn=process_wrapper,
313
- inputs=[video_input, fps_slider, seed_slider, max_workers_slider],
314
- outputs=[preview_image, output_frames_zip, output_video, time_textbox]
315
- )
316
-
317
- demo.queue()
318
-
319
- api = gr.Interface(
320
- fn=process_wrapper,
321
- inputs=[
322
- gr.Video(label="Upload Video"),
323
- gr.Number(label="FPS", value=0),
324
- gr.Number(label="Seed", value=0),
325
- gr.Number(label="Max Workers", value=6)
326
- ],
327
- outputs=[
328
- gr.Image(label="Preview"),
329
- gr.File(label="Frame Sequence"),
330
- gr.File(label="Video"),
331
- gr.Textbox(label="Status")
332
- ],
333
- title="Video Depth Estimation API",
334
- description="Generate depth maps from videos",
335
- api_name="/process_video"
336
  )
337
 
338
- if __name__ == "__main__":
339
- demo.launch(debug=True, show_error=True, share=False, server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
  import torch
 
 
 
 
 
 
3
  import os
4
+ import tempfile
5
+ import imageio
6
+ import numpy as np
7
  import shutil
8
+ from PIL import Image
9
  from concurrent.futures import ThreadPoolExecutor
10
+ import ffmpeg
11
+ from infer import lotus, lotus_video # Import the depth model inference function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Set device
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
+ def process_frame(path_input, seed):
17
+ """
18
+ Process a single frame through the depth model.
19
+ Returns the original and depth-processed images.
20
+ """
21
+ name_base, name_ext = os.path.splitext(os.path.basename(path_input))
22
+
23
+ # Process the frame with the model
24
+ output_g, output_d = lotus(path_input, 'depth', seed, device)
25
 
26
+ # Save generated and depth maps to temporary paths
27
+ g_save_path = os.path.join(tempfile.gettempdir(), f"{name_base}_g{name_ext}")
28
+ d_save_path = os.path.join(tempfile.gettempdir(), f"{name_base}_d{name_ext}")
29
 
30
+ output_g.save(g_save_path)
31
+ output_d.save(d_save_path)
32
 
33
+ return [path_input, g_save_path], [path_input, d_save_path]
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ def process_video_live(path_input, seed):
 
37
  """
38
+ Process video frame-by-frame, showing each processed frame live in the preview and compile the final video.
 
39
  """
40
+ temp_dir = tempfile.mkdtemp()
41
+
42
+ # Extract video frames
43
+ video = imageio.get_reader(path_input)
44
+ fps = video.get_meta_data()['fps']
45
+ frames = [frame for frame in video]
46
+ total_frames = len(frames)
47
+
48
+ print(f"Processing {total_frames} frames at {fps} FPS...")
49
+
50
+ processed_frames_g = []
51
+ processed_frames_d = []
52
+
53
+ for i, frame in enumerate(frames):
54
+ frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png")
55
+ Image.fromarray(frame).save(frame_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ # Process the frame using the lotus model
58
+ output_g_paths, output_d_paths = process_frame(frame_path, seed)
 
 
59
 
60
+ # Append processed frames for final video compilation
61
+ processed_frames_g.append(imageio.imread(output_g_paths[1]))
62
+ processed_frames_d.append(imageio.imread(output_d_paths[1]))
 
63
 
64
+ # Update the live preview
65
+ yield output_g_paths[1], output_d_paths[1], f"Processing frame {i+1}/{total_frames}..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
+ # Compile final videos
68
+ g_video_path = os.path.join(temp_dir, "output_g.mp4")
69
+ d_video_path = os.path.join(temp_dir, "output_d.mp4")
70
+
71
+ imageio.mimsave(g_video_path, processed_frames_g, fps=fps)
72
+ imageio.mimsave(d_video_path, processed_frames_d, fps=fps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
+ # Clean up temporary directory
75
+ if os.path.exists(temp_dir):
76
+ try:
77
+ shutil.rmtree(temp_dir)
78
+ except Exception as e:
79
+ print(f"Error cleaning up temp directory: {e}")
 
 
 
 
 
 
 
 
 
 
80
 
81
+ return g_video_path, d_video_path
82
+
 
 
 
 
83
 
84
  # Gradio Interface
85
+ with gr.Blocks() as demo:
86
+ gr.Markdown("# Video Depth Estimation: Live Frame Processing and Video Compilation")
 
 
 
 
87
 
88
  with gr.Row():
89
  with gr.Column():
90
  video_input = gr.Video(
91
  label="Upload Video",
92
  interactive=True,
93
+ show_label=True
94
+ )
95
+ seed_input = gr.Number(
96
+ label="Seed",
97
+ value=0,
98
+ interactive=True
99
  )
100
+ process_btn = gr.Button("Process Video")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  with gr.Column():
103
+ live_preview_g = gr.Image(label="Live Preview (Generative)", show_label=True)
104
+ live_preview_d = gr.Image(label="Live Preview (Discriminative)", show_label=True)
105
+ status_text = gr.Textbox(label="Status", interactive=False)
106
+ final_g_video = gr.Video(label="Final Generative Video")
107
+ final_d_video = gr.Video(label="Final Discriminative Video")
108
+
109
+ process_btn.click(
110
+ fn=process_video_live,
111
+ inputs=[video_input, seed_input],
112
+ outputs=[live_preview_g, live_preview_d, status_text, final_g_video, final_d_video]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  )
114
 
115
+ demo.launch(debug=True)