ghostsInTheMachine committed on
Commit
09e9f28
1 Parent(s): e2ac6fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +305 -177
app.py CHANGED
@@ -1,193 +1,321 @@
1
- from gradio_imageslider import ImageSlider
2
- import functools
3
- import os
4
- import tempfile
5
- import diffusers
6
  import gradio as gr
7
- import imageio as imageio
8
- import numpy as np
9
  import spaces
10
- import torch as torch
11
  from PIL import Image
12
- from tqdm import tqdm
13
- from pathlib import Path
14
- import gradio
15
- from gradio.utils import get_cache_folder
16
- from infer import lotus, lotus_video
17
- import transformers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- transformers.utils.move_cache()
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
def infer(path_input, seed):
    """Run LOTUS depth prediction on one image and save both results.

    Args:
        path_input: Path of the input image file.
        seed: Seed forwarded to ``lotus``.

    Returns:
        Two ``[input_path, output_path]`` pairs, one for each of the two
        images returned by ``lotus`` (saved with ``_g`` and ``_d`` suffixes).
    """
    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    output_g, output_d = lotus(path_input, 'depth', seed, device)

    # exist_ok avoids the check-then-create race between concurrent requests.
    os.makedirs("files/output", exist_ok=True)

    g_save_path = os.path.join("files/output", f"{name_base}_g{name_ext}")
    d_save_path = os.path.join("files/output", f"{name_base}_d{name_ext}")
    output_g.save(g_save_path)
    output_d.save(d_save_path)
    return [path_input, g_save_path], [path_input, d_save_path]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
def infer_video(path_input, seed):
    """Run LOTUS depth prediction on a video and save both results as MP4.

    Args:
        path_input: Path of the input video file.
        seed: Seed forwarded to ``lotus_video``.

    Returns:
        ``[g_save_path, d_save_path]``: paths of the two videos written from
        the two frame sequences returned by ``lotus_video``.
    """
    frames_g, frames_d = lotus_video(path_input, 'depth', seed, device)

    # exist_ok avoids the check-then-create race between concurrent requests.
    os.makedirs("files/output", exist_ok=True)

    name_base, _ = os.path.splitext(os.path.basename(path_input))
    g_save_path = os.path.join("files/output", f"{name_base}_g.mp4")
    d_save_path = os.path.join("files/output", f"{name_base}_d.mp4")
    imageio.mimsave(g_save_path, frames_g)
    imageio.mimsave(d_save_path, frames_d)
    return [g_save_path, d_save_path]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
def run_demo_server():
    """Build the LOTUS depth demo UI and launch it on 0.0.0.0:7860.

    The image inference function is wrapped with ``spaces.GPU`` and wired to
    the submit button and to the example gallery. This call blocks inside
    ``launch``.
    """
    infer_gpu = spaces.GPU(functools.partial(infer))
    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="LOTUS (Depth)",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            .tabs button.selected {
                font-size: 20px !important;
                color: crimson !important;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
            .md_feedback li {
                margin-bottom: 0px !important;
            }
        """,
        head="""
            <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
            <script>
                window.dataLayer = window.dataLayer || [];
                function gtag() {dataLayer.push(arguments);}
                gtag('js', new Date());
                gtag('config', 'G-1FWSVCGZTG');
            </script>
        """,
    ) as demo:
        gr.Markdown(
            """
            # LOTUS: Diffusion-based Visual Foundation Model for High-quality Dense Prediction
            <p align="center">
            <a title="Page" href="https://lotus3d.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://img.shields.io/badge/Project-Website-pink?logo=googlechrome&logoColor=white">
            </a>
            <a title="arXiv" href="https://arxiv.org/abs/2409.18124" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=white">
            </a>
            <a title="Github" href="https://github.com/EnVision-Research/Lotus" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://img.shields.io/github/stars/EnVision-Research/Lotus?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
            </a>
            <a title="Social" href="https://x.com/Jingheya/status/1839553365870784563" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
            </a>
            <a title="Social" href="https://x.com/haodongli00/status/1839524569058582884" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
            </a>
            <br>
            <strong>Please consider starring <span style="color: orange">&#9733;</span> the <a href="https://github.com/EnVision-Research/Lotus" target="_blank" rel="noopener noreferrer">GitHub Repo</a> if you find this useful!</strong>
            """
        )
        with gr.Tabs(elem_classes=["tabs"]):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(
                        label="Input Image",
                        type="filepath",
                    )
                    seed = gr.Number(
                        label="Seed (only for Generative mode)",
                        minimum=0,
                        maximum=999999999,
                    )
                    with gr.Row():
                        image_submit_btn = gr.Button(
                            value="Predict Depth!", variant="primary"
                        )
                        image_reset_btn = gr.Button(value="Reset")
                with gr.Column():
                    image_output_g = ImageSlider(
                        label="Output (Generative)",
                        type="filepath",
                        interactive=False,
                        elem_classes="slider",
                        position=0.25,
                    )
                    with gr.Row():
                        image_output_d = ImageSlider(
                            label="Output (Discriminative)",
                            type="filepath",
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )

            gr.Examples(
                fn=infer_gpu,
                examples=sorted([
                    [os.path.join("files", "images", name), 0]
                    for name in os.listdir(os.path.join("files", "images"))
                ]),
                inputs=[image_input, seed],
                outputs=[image_output_g, image_output_d],
                cache_examples=False,
            )

        ### Image
        image_submit_btn.click(
            fn=infer_gpu,
            inputs=[image_input, seed],
            outputs=[image_output_g, image_output_d],
            concurrency_limit=1,
        )
        image_reset_btn.click(
            fn=lambda: (
                None,
                None,
                None,
            ),
            inputs=[],
            # Three return values need three targets; previously only the two
            # output sliders were listed, so the handler's arity did not match
            # and the input image was never cleared.
            outputs=[image_input, image_output_g, image_output_d],
            queue=False,
        )

        ### Server launch
        demo.queue(
            api_open=False,
        ).launch(
            server_name="0.0.0.0",
            server_port=7860,
        )
185
 
186
def main():
    """Log the installed packages, clear old outputs, and start the demo."""
    # Local import: shutil is not among this file's top-level imports.
    import shutil

    # Prints the environment's package list to stdout.
    os.system("pip freeze")

    # Portable replacement for `rm -rf files/output`; a missing directory
    # or a failed removal is ignored, as it was with the shell command.
    shutil.rmtree("files/output", ignore_errors=True)

    run_demo_server()
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
  if __name__ == "__main__":
193
- main()
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
 
3
  import spaces
4
+ import moviepy.editor as mp
5
  from PIL import Image
6
+ import numpy as np
7
+ import tempfile
8
+ import time
9
+ import os
10
+ import shutil
11
+ import ffmpeg
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ from gradio.themes.base import Base
14
+ from gradio.themes.utils import colors, fonts
15
+ from infer import lotus # Import the depth model inference function
16
+
17
+ # Custom Theme Definition
18
class WhiteTheme(Base):
    """Gradio theme that forces white backgrounds and black text.

    The same fill and text colours are set for the regular and the ``_dark``
    variables, so both modes render the same.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.orange,
        font: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | tuple[fonts.Font | str, ...] = (
            fonts.GoogleFont("Inter"),
            "ui-monospace",
            "system-ui",
            "monospace",
        )
    ):
        super().__init__(primary_hue=primary_hue, font=font, font_mono=font_mono)

        # Theme variable overrides, grouped by the part of the UI they affect.
        overrides = {
            # Backgrounds
            "background_fill_primary": "*primary_50",
            "background_fill_secondary": "white",
            "border_color_primary": "*primary_300",
            "body_background_fill": "white",
            "body_background_fill_dark": "white",
            "block_background_fill": "white",
            "block_background_fill_dark": "white",
            "panel_background_fill": "white",
            "panel_background_fill_dark": "white",
            # Text
            "body_text_color": "black",
            "body_text_color_dark": "black",
            "block_label_text_color": "black",
            "block_label_text_color_dark": "black",
            # Borders
            "block_border_color": "white",
            "panel_border_color": "white",
            # Inputs
            "input_border_color": "lightgray",
            "input_background_fill": "white",
            "input_background_fill_dark": "white",
            # Shadows
            "shadow_drop": "none",
        }
        self.set(**overrides)
63
 
64
+ # Set device
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
 
67
def process_frame(frame, seed=0):
    """Process a single frame through the depth model.

    Args:
        frame: Image data accepted by ``Image.fromarray``.
        seed: Seed forwarded to ``lotus``.

    Returns:
        The discriminative depth map as a numpy array, or ``None`` if any
        step fails (the error is printed, not raised).
    """
    tmp_path = None
    try:
        # Convert frame to PIL Image
        image = Image.fromarray(frame)

        # lotus requires a file path, so write the frame to a temp file.
        # The descriptor is closed first so the file can be reopened by name.
        fd, tmp_path = tempfile.mkstemp(suffix='.png')
        os.close(fd)
        image.save(tmp_path)

        # Process through lotus model; only the second output is used here.
        _, output_d = lotus(tmp_path, 'depth', seed, device)

        # Convert depth output to numpy array
        return np.array(output_d)

    except Exception as e:
        print(f"Error processing frame: {e}")
        return None
    finally:
        # Remove the temp file on failure too; previously it leaked whenever
        # saving or inference raised.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as e:
                print(f"Error removing temp file {tmp_path}: {e}")
93
 
94
@spaces.GPU
def process_video(video_path, fps=0, seed=0, max_workers=6):
    """Create a depth-map frame sequence (ZIP) and an MP4 from a video.

    This is a generator. While frames are processed it yields
    ``(preview_frame, None, None, status_text)``; at the end it yields
    ``(None, zip_path, video_path, status_text)``. On failure it yields
    ``(None, None, None, error_text)`` instead of raising.

    Args:
        video_path: Path of the input video.
        fps: Frame rate used for sampling and for the output video; 0 (or
            ``None``) keeps the source frame rate.
        seed: Seed forwarded to ``process_frame``.
        max_workers: Thread-pool size, also passed to ffmpeg as ``threads``.
    """
    temp_dir = None
    video = None
    try:
        start_time = time.time()

        # UI number widgets may deliver floats (or None when cleared).
        seed = int(seed or 0)
        max_workers = max(1, int(max_workers or 1))

        video = mp.VideoFileClip(video_path)

        # Use original video FPS if not specified
        if not fps:
            fps = video.fps

        frames = list(video.iter_frames(fps=fps))
        total_frames = len(frames)

        print(f"Processing {total_frames} frames at {fps} FPS...")

        # Create temporary directory for frame sequence
        temp_dir = tempfile.mkdtemp()
        frames_dir = os.path.join(temp_dir, "frames")
        os.makedirs(frames_dir, exist_ok=True)

        # Frames are numbered by how many were saved, not by source index,
        # so a failed frame does not leave a gap in the frame_%06d.png
        # sequence that ffmpeg reads below.
        saved_frames = 0
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(process_frame, frame, seed) for frame in frames]
            for i, future in enumerate(futures):
                try:
                    result = future.result()
                    if result is not None:
                        # Save frame
                        frame_path = os.path.join(frames_dir, f"frame_{saved_frames:06d}.png")
                        Image.fromarray(result).save(frame_path)
                        saved_frames += 1

                        # Update preview (only the current frame is kept)
                        elapsed_time = time.time() - start_time
                        yield result, None, None, f"Processing frame {i+1}/{total_frames}... Elapsed time: {elapsed_time:.2f} seconds"

                    if (i + 1) % 10 == 0:
                        print(f"Processed {i+1}/{total_frames} frames")
                except Exception as e:
                    print(f"Error processing frame {i+1}: {e}")

        if saved_frames == 0:
            # Nothing to archive or encode; reported by the handler below.
            raise RuntimeError("No frames were processed successfully")

        print("Creating output files...")
        # Create output directory next to the input video
        output_dir = os.path.join(os.path.dirname(video_path), "output")
        os.makedirs(output_dir, exist_ok=True)

        # Create ZIP of frame sequence
        zip_filename = f"depth_frames_{int(time.time())}.zip"
        zip_path = os.path.join(output_dir, zip_filename)
        shutil.make_archive(zip_path[:-4], 'zip', frames_dir)

        # Create MP4 video
        print("Creating MP4 video...")
        video_filename = f"depth_video_{int(time.time())}.mp4"
        output_video_path = os.path.join(output_dir, video_filename)

        try:
            # FFmpeg settings for high-quality MP4
            stream = ffmpeg.input(
                os.path.join(frames_dir, 'frame_%06d.png'),
                pattern_type='sequence',
                start_number=0,
                framerate=fps
            )

            stream = ffmpeg.output(
                stream,
                output_video_path,
                vcodec='libx264',
                pix_fmt='yuv420p',
                crf=17,  # High quality
                threads=max_workers
            )

            ffmpeg.run(stream, overwrite_output=True, capture_stdout=True, capture_stderr=True)
            print("MP4 video created successfully!")

        except ffmpeg.Error as e:
            print(f"Error creating video: {e.stderr.decode() if e.stderr else str(e)}")
            # The ZIP is still returned when encoding fails.
            output_video_path = None

        print("Processing complete!")
        yield None, zip_path, output_video_path, f"Processing complete! Total time: {time.time() - start_time:.2f} seconds"

    except Exception as e:
        print(f"Error: {e}")
        yield None, None, None, f"Error processing video: {e}"
    finally:
        # Release the reader held by the clip; it was never closed before.
        if video is not None:
            try:
                video.close()
            except Exception as e:
                print(f"Error closing video: {e}")
        if temp_dir and os.path.exists(temp_dir):
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                print(f"Error cleaning up temp directory: {e}")
194
 
195
def process_wrapper(video, fps=0, seed=0, max_workers=6):
    """Stream ``process_video`` updates, reporting failures as ``gr.Error``.

    This is a generator: each tuple yielded by ``process_video`` is passed
    through unchanged, and the last one is also the generator's return value.

    Raises:
        gr.Error: If no video was supplied, or if ``process_video`` raises.
    """
    if video is None:
        raise gr.Error("Please upload a video.")
    try:
        # Track only the latest update. Collecting every yielded tuple kept
        # each preview frame alive for the whole run.
        last_output = None
        for output in process_video(video, fps, seed, max_workers):
            last_output = output
            yield output
        return last_output
    except Exception as e:
        raise gr.Error(f"Error processing video: {str(e)}") from e
206
 
207
+ # Custom CSS for styling
208
+ custom_css = """
209
+ .title-container {
210
+ text-align: center;
211
+ padding: 10px 0;
212
+ }
213
+
214
+ #title {
215
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
216
+ font-size: 36px;
217
+ font-weight: bold;
218
+ color: #000000;
219
+ padding: 10px;
220
+ border-radius: 10px;
221
+ display: inline-block;
222
+ background: linear-gradient(
223
+ 135deg,
224
+ #e0f7fa, #e8f5e9, #fff9c4, #ffebee,
225
+ #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
226
+ );
227
+ background-size: 400% 400%;
228
+ animation: gradient-animation 15s ease infinite;
229
+ }
230
+
231
+ @keyframes gradient-animation {
232
+ 0% { background-position: 0% 50%; }
233
+ 50% { background-position: 100% 50%; }
234
+ 100% { background-position: 0% 50%; }
235
+ }
236
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
+ # Gradio Interface
239
# Main UI: upload and settings on the left, live preview and downloads on the right.
with gr.Blocks(css=custom_css, theme=WhiteTheme()) as demo:
    gr.HTML('''
        <div class="title-container">
            <div id="title">Video Depth Estimation</div>
        </div>
    ''')

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(
                label="Upload Video",
                interactive=True,
                show_label=True,
                height=360,
                width=640
            )
            with gr.Row():
                fps_slider = gr.Slider(
                    minimum=0,
                    maximum=60,
                    step=1,
                    value=0,
                    label="Output FPS (0 will inherit the original fps value)",
                )
                seed_slider = gr.Slider(
                    minimum=0,
                    maximum=999999999,
                    step=1,
                    value=0,
                    label="Seed",
                )
                max_workers_slider = gr.Slider(
                    minimum=1,
                    maximum=32,
                    step=1,
                    value=6,
                    label="Max Workers",
                    info="Determines how many frames to process in parallel"
                )
            btn = gr.Button("Process Video", elem_id="submit-button")

        with gr.Column():
            preview_image = gr.Image(label="Live Preview", show_label=True)
            output_frames_zip = gr.File(label="Download Frame Sequence (ZIP)")
            output_video = gr.File(label="Download Video (MP4)")
            time_textbox = gr.Textbox(label="Status", interactive=False)

    gr.Markdown("""
        ### Output Information
        - High-quality MP4 video output
        - Original resolution and framerate are maintained
        - Frame sequence provided for maximum compatibility
    """)

    # process_wrapper is a generator, so the four outputs update as it yields.
    btn.click(
        fn=process_wrapper,
        inputs=[video_input, fps_slider, seed_slider, max_workers_slider],
        outputs=[preview_image, output_frames_zip, output_video, time_textbox]
    )

demo.queue()

# NOTE(review): `api` is constructed but never launched or mounted in this
# file; only `demo` is served below. Confirm whether it is used elsewhere.
api = gr.Interface(
    fn=process_wrapper,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Number(label="FPS", value=0),
        gr.Number(label="Seed", value=0),
        gr.Number(label="Max Workers", value=6)
    ],
    outputs=[
        gr.Image(label="Preview"),
        gr.File(label="Frame Sequence"),
        gr.File(label="Video"),
        gr.Textbox(label="Status")
    ],
    title="Video Depth Estimation API",
    description="Generate depth maps from videos",
    # Endpoint names are given without a leading slash; clients add it.
    api_name="process_video"
)

if __name__ == "__main__":
    demo.launch(debug=True, show_error=True, share=False, server_name="0.0.0.0", server_port=7860)