ghostsInTheMachine committed on
Commit
afe7cc3
1 Parent(s): 8e07e41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -72
app.py CHANGED
@@ -8,10 +8,8 @@ import ffmpeg
8
  import numpy as np
9
  from PIL import Image
10
  import moviepy.editor as mp
11
- from infer import lotus # Import the depth model inference function
12
  import logging
13
- import io
14
- from multiprocessing import Pool, cpu_count
15
 
16
  # Set up logging
17
  logging.basicConfig(level=logging.INFO)
@@ -20,105 +18,98 @@ logger = logging.getLogger(__name__)
20
  # Set device to use the L40s GPU explicitly
21
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
 
 
 
 
 
23
  # Preprocess the video to adjust resolution and frame rate
24
  def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
25
  """Preprocess the video to resize and adjust its frame rate."""
26
  video = mp.VideoFileClip(video_path)
27
-
28
  # Resize video if it's larger than the target resolution
29
  if video.size[0] > max_resolution[0] or video.size[1] > max_resolution[1]:
30
  video = video.resize(height=max_resolution[1])
31
-
32
  # Adjust FPS if target_fps is specified
33
  if target_fps > 0:
34
  video = video.set_fps(target_fps)
35
-
36
  return video
37
 
38
- # Process a single frame through the depth model
39
- def process_frame(args):
40
- """Process a single frame through the depth model and return depth map."""
41
- frame_index, frame_data, seed = args
42
  try:
43
- # Set seeds for reproducibility
44
- torch.manual_seed(seed)
45
- np.random.seed(seed)
46
-
47
- # Convert frame data to PIL Image
48
- image = Image.fromarray(frame_data).convert('RGB')
49
-
50
- # Save image to an in-memory file
51
- img_bytes = io.BytesIO()
52
- image.save(img_bytes, format='PNG')
53
- img_bytes.seek(0) # Reset file pointer to the beginning
54
-
55
- # Process through the depth model
56
- _, output_d = lotus(img_bytes, 'depth', seed, device)
57
-
58
- # Convert depth output to numpy array
59
- depth_array = np.array(output_d)
60
- return (frame_index, depth_array)
61
-
62
  except Exception as e:
63
- logger.error(f"Error processing frame {frame_index}: {e}")
64
- return (frame_index, None)
65
 
66
  # Process video frames and generate depth maps
67
- def process_video(video_path, fps=0, seed=0, num_workers=4):
68
- """Process video frames in parallel and generate depth maps."""
69
  # Create a persistent temporary directory
70
  temp_dir = tempfile.mkdtemp()
71
  try:
72
  start_time = time.time()
73
-
74
  # Preprocess the video
75
  video = preprocess_video(video_path, target_fps=fps)
76
-
77
  # Use original video FPS if not specified
78
  if fps == 0:
79
  fps = video.fps
80
-
81
  frames = list(video.iter_frames(fps=video.fps))
82
  total_frames = len(frames)
83
-
84
  logger.info(f"Processing {total_frames} frames at {fps} FPS...")
85
-
86
  # Create directory for frame sequence and outputs
87
  frames_dir = os.path.join(temp_dir, "frames")
88
  os.makedirs(frames_dir, exist_ok=True)
89
-
90
- # Prepare arguments for multiprocessing
91
- args_list = [(i, frames[i], seed) for i in range(total_frames)]
92
-
93
- # Use multiprocessing Pool to process frames in parallel
94
- with Pool(processes=num_workers) as pool:
95
- results = []
96
- for result in pool.imap_unordered(process_frame, args_list):
97
- frame_index, depth_map = result
 
98
  if depth_map is not None:
99
  # Save frame
100
  frame_path = os.path.join(frames_dir, f"frame_{frame_index:06d}.png")
101
- Image.fromarray(depth_map.squeeze()).save(frame_path)
102
-
103
- # Update preview every 10% progress
104
- if (frame_index + 1) % max(1, total_frames // 10) == 0:
105
  elapsed_time = time.time() - start_time
106
- progress = ((frame_index + 1) / total_frames) * 100
107
- yield depth_map, None, None, f"Processed {frame_index + 1}/{total_frames} frames... ({progress:.2f}%) Elapsed: {elapsed_time:.2f}s"
108
  else:
109
- logger.error(f"Frame {frame_index} failed to process.")
110
-
111
  logger.info("Creating output files...")
112
-
113
  # Create ZIP of frame sequence
114
  zip_filename = f"depth_frames_{int(time.time())}.zip"
115
  zip_path = os.path.join(temp_dir, zip_filename)
116
  shutil.make_archive(zip_path[:-4], 'zip', frames_dir)
117
-
118
  # Create MP4 video
119
  video_filename = f"depth_video_{int(time.time())}.mp4"
120
  output_video_path = os.path.join(temp_dir, video_filename)
121
-
122
  try:
123
  # FFmpeg settings for high-quality MP4
124
  (
@@ -128,31 +119,29 @@ def process_video(video_path, fps=0, seed=0, num_workers=4):
128
  .run(overwrite_output=True, quiet=True)
129
  )
130
  logger.info("MP4 video created successfully!")
131
-
132
  except ffmpeg.Error as e:
133
  logger.error(f"Error creating video: {e.stderr.decode() if e.stderr else str(e)}")
134
  output_video_path = None
135
-
136
  total_time = time.time() - start_time
137
  logger.info("Processing complete!")
138
-
139
  # Yield the file paths
140
  yield None, zip_path, output_video_path, f"Processing complete! Total time: {total_time:.2f} seconds"
141
-
142
  except Exception as e:
143
  logger.error(f"Error: {e}")
144
  yield None, None, None, f"Error processing video: {e}"
145
- # Do not delete temp_dir here; we need the files to persist
146
- # Cleanup can be handled elsewhere if necessary
147
 
148
  # Wrapper function with error handling
149
- def process_wrapper(video, fps=0, seed=0, num_workers=4):
150
  if video is None:
151
  raise gr.Error("Please upload a video.")
152
  try:
153
  outputs = []
154
  # Use video directly, since it's the file path
155
- for output in process_video(video, fps, seed, num_workers):
156
  outputs.append(output)
157
  yield output
158
  return outputs[-1]
@@ -161,7 +150,33 @@ def process_wrapper(video, fps=0, seed=0, num_workers=4):
161
 
162
  # Custom CSS for styling (unchanged)
163
  custom_css = """
164
- /* Your existing custom CSS */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  """
166
 
167
  # Gradio Interface
@@ -171,24 +186,24 @@ with gr.Blocks(css=custom_css) as demo:
171
  <div id="title">Video Depth Estimation</div>
172
  </div>
173
  ''')
174
-
175
  with gr.Row():
176
  with gr.Column():
177
  video_input = gr.Video(label="Upload Video", interactive=True)
178
  fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
179
  seed_slider = gr.Number(value=0, label="Seed")
180
- num_workers_slider = gr.Slider(minimum=1, maximum=cpu_count(), step=1, value=4, label="Number of Workers")
181
  btn = gr.Button("Process Video")
182
-
183
  with gr.Column():
184
  preview_image = gr.Image(label="Live Preview")
185
  output_frames_zip = gr.File(label="Download Frame Sequence (ZIP)")
186
  output_video = gr.File(label="Download Video (MP4)")
187
  time_textbox = gr.Textbox(label="Status", interactive=False)
188
-
189
  btn.click(
190
  fn=process_wrapper,
191
- inputs=[video_input, fps_slider, seed_slider, num_workers_slider],
192
  outputs=[preview_image, output_frames_zip, output_video, time_textbox]
193
  )
194
 
 
8
  import numpy as np
9
  from PIL import Image
10
  import moviepy.editor as mp
11
+ from infer import lotus, load_models
12
  import logging
 
 
13
 
14
  # Set up logging
15
  logging.basicConfig(level=logging.INFO)
 
18
  # Set device to use the L40s GPU explicitly
19
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
20
 
21
+ # Load models once
22
+ task_name = 'depth'
23
+ pipe_g, pipe_d = load_models(task_name, device)
24
+
25
  # Preprocess the video to adjust resolution and frame rate
26
  def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
27
  """Preprocess the video to resize and adjust its frame rate."""
28
  video = mp.VideoFileClip(video_path)
29
+
30
  # Resize video if it's larger than the target resolution
31
  if video.size[0] > max_resolution[0] or video.size[1] > max_resolution[1]:
32
  video = video.resize(height=max_resolution[1])
33
+
34
  # Adjust FPS if target_fps is specified
35
  if target_fps > 0:
36
  video = video.set_fps(target_fps)
37
+
38
  return video
39
 
40
+ # Process a batch of frames through the depth model
41
+ def process_frames_batch(frames_batch, seed=0):
42
+ """Process a batch of frames and return depth maps."""
 
43
  try:
44
+ # Convert frames to PIL Images
45
+ images_batch = [Image.fromarray(frame).convert('RGB') for frame in frames_batch]
46
+
47
+ # Run batch inference
48
+ depth_maps = lotus(images_batch, 'depth', seed, device, pipe_g, pipe_d)
49
+
50
+ return depth_maps
51
+
 
 
 
 
 
 
 
 
 
 
 
52
  except Exception as e:
53
+ logger.error(f"Error processing batch: {e}")
54
+ return [None] * len(frames_batch)
55
 
56
  # Process video frames and generate depth maps
57
+ def process_video(video_path, fps=0, seed=0, batch_size=16):
58
+ """Process video frames in batches and generate depth maps."""
59
  # Create a persistent temporary directory
60
  temp_dir = tempfile.mkdtemp()
61
  try:
62
  start_time = time.time()
63
+
64
  # Preprocess the video
65
  video = preprocess_video(video_path, target_fps=fps)
66
+
67
  # Use original video FPS if not specified
68
  if fps == 0:
69
  fps = video.fps
70
+
71
  frames = list(video.iter_frames(fps=video.fps))
72
  total_frames = len(frames)
73
+
74
  logger.info(f"Processing {total_frames} frames at {fps} FPS...")
75
+
76
  # Create directory for frame sequence and outputs
77
  frames_dir = os.path.join(temp_dir, "frames")
78
  os.makedirs(frames_dir, exist_ok=True)
79
+
80
+ processed_frames = []
81
+
82
+ # Process frames in batches
83
+ for i in range(0, total_frames, batch_size):
84
+ frames_batch = frames[i:i+batch_size]
85
+ depth_maps = process_frames_batch(frames_batch, seed)
86
+
87
+ for j, depth_map in enumerate(depth_maps):
88
+ frame_index = i + j
89
  if depth_map is not None:
90
  # Save frame
91
  frame_path = os.path.join(frames_dir, f"frame_{frame_index:06d}.png")
92
+ depth_map.save(frame_path)
93
+
94
+ # Update live preview every 10% progress
95
+ if frame_index % max(1, total_frames // 10) == 0:
96
  elapsed_time = time.time() - start_time
97
+ progress = (frame_index / total_frames) * 100
98
+ yield depth_map, None, None, f"Processed {frame_index}/{total_frames} frames... ({progress:.2f}%) Elapsed: {elapsed_time:.2f}s"
99
  else:
100
+ logger.error(f"Error processing frame {frame_index}")
101
+
102
  logger.info("Creating output files...")
103
+
104
  # Create ZIP of frame sequence
105
  zip_filename = f"depth_frames_{int(time.time())}.zip"
106
  zip_path = os.path.join(temp_dir, zip_filename)
107
  shutil.make_archive(zip_path[:-4], 'zip', frames_dir)
108
+
109
  # Create MP4 video
110
  video_filename = f"depth_video_{int(time.time())}.mp4"
111
  output_video_path = os.path.join(temp_dir, video_filename)
112
+
113
  try:
114
  # FFmpeg settings for high-quality MP4
115
  (
 
119
  .run(overwrite_output=True, quiet=True)
120
  )
121
  logger.info("MP4 video created successfully!")
122
+
123
  except ffmpeg.Error as e:
124
  logger.error(f"Error creating video: {e.stderr.decode() if e.stderr else str(e)}")
125
  output_video_path = None
126
+
127
  total_time = time.time() - start_time
128
  logger.info("Processing complete!")
129
+
130
  # Yield the file paths
131
  yield None, zip_path, output_video_path, f"Processing complete! Total time: {total_time:.2f} seconds"
132
+
133
  except Exception as e:
134
  logger.error(f"Error: {e}")
135
  yield None, None, None, f"Error processing video: {e}"
 
 
136
 
137
  # Wrapper function with error handling
138
+ def process_wrapper(video, fps=0, seed=0, batch_size=16):
139
  if video is None:
140
  raise gr.Error("Please upload a video.")
141
  try:
142
  outputs = []
143
  # Use video directly, since it's the file path
144
+ for output in process_video(video, fps, seed, batch_size):
145
  outputs.append(output)
146
  yield output
147
  return outputs[-1]
 
150
 
151
  # Custom CSS for styling (unchanged)
152
  custom_css = """
153
+ .title-container {
154
+ text-align: center;
155
+ padding: 10px 0;
156
+ }
157
+
158
+ #title {
159
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
160
+ font-size: 36px;
161
+ font-weight: bold;
162
+ color: #000000;
163
+ padding: 10px;
164
+ border-radius: 10px;
165
+ display: inline-block;
166
+ background: linear-gradient(
167
+ 135deg,
168
+ #e0f7fa, #e8f5e9, #fff9c4, #ffebee,
169
+ #f3e5f5, #e1f5fe, #fff3e0, #e8eaf6
170
+ );
171
+ background-size: 400% 400%;
172
+ animation: gradient-animation 15s ease infinite;
173
+ }
174
+
175
+ @keyframes gradient-animation {
176
+ 0% { background-position: 0% 50%; }
177
+ 50% { background-position: 100% 50%; }
178
+ 100% { background-position: 0% 50%; }
179
+ }
180
  """
181
 
182
  # Gradio Interface
 
186
  <div id="title">Video Depth Estimation</div>
187
  </div>
188
  ''')
189
+
190
  with gr.Row():
191
  with gr.Column():
192
  video_input = gr.Video(label="Upload Video", interactive=True)
193
  fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
194
  seed_slider = gr.Number(value=0, label="Seed")
195
+ batch_size_slider = gr.Slider(minimum=1, maximum=64, step=1, value=16, label="Batch Size")
196
  btn = gr.Button("Process Video")
197
+
198
  with gr.Column():
199
  preview_image = gr.Image(label="Live Preview")
200
  output_frames_zip = gr.File(label="Download Frame Sequence (ZIP)")
201
  output_video = gr.File(label="Download Video (MP4)")
202
  time_textbox = gr.Textbox(label="Status", interactive=False)
203
+
204
  btn.click(
205
  fn=process_wrapper,
206
+ inputs=[video_input, fps_slider, seed_slider, batch_size_slider],
207
  outputs=[preview_image, output_frames_zip, output_video, time_textbox]
208
  )
209