ghostsInTheMachine commited on
Commit
8e07e41
1 Parent(s): 8df521f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -36
app.py CHANGED
@@ -11,6 +11,7 @@ import moviepy.editor as mp
11
  from infer import lotus # Import the depth model inference function
12
  import logging
13
  import io
 
14
 
15
  # Set up logging
16
  logging.basicConfig(level=logging.INFO)
@@ -35,13 +36,17 @@ def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
35
  return video
36
 
37
  # Process a single frame through the depth model
38
- def process_frame(image, seed=0):
39
  """Process a single frame through the depth model and return depth map."""
 
40
  try:
41
  # Set seeds for reproducibility
42
  torch.manual_seed(seed)
43
  np.random.seed(seed)
44
 
 
 
 
45
  # Save image to an in-memory file
46
  img_bytes = io.BytesIO()
47
  image.save(img_bytes, format='PNG')
@@ -52,15 +57,15 @@ def process_frame(image, seed=0):
52
 
53
  # Convert depth output to numpy array
54
  depth_array = np.array(output_d)
55
- return depth_array
56
-
57
  except Exception as e:
58
- logger.error(f"Error processing frame: {e}")
59
- return None
60
 
61
  # Process video frames and generate depth maps
62
- def process_video(video_path, fps=0, seed=0, batch_size=16):
63
- """Process video, batch frames, and use L40s GPU to generate depth maps."""
64
  # Create a persistent temporary directory
65
  temp_dir = tempfile.mkdtemp()
66
  try:
@@ -81,37 +86,28 @@ def process_video(video_path, fps=0, seed=0, batch_size=16):
81
  # Create directory for frame sequence and outputs
82
  frames_dir = os.path.join(temp_dir, "frames")
83
  os.makedirs(frames_dir, exist_ok=True)
84
-
85
- processed_frames = []
86
 
87
- # Process frames in batches
88
- for i in range(0, total_frames, batch_size):
89
- frames_batch = frames[i:i+batch_size]
90
- depth_maps = []
91
-
92
- # Process each frame in the batch
93
- for frame in frames_batch:
94
- depth_map = process_frame(Image.fromarray(frame), seed)
95
- depth_maps.append(depth_map)
96
-
97
- for j, depth_map in enumerate(depth_maps):
98
  if depth_map is not None:
99
  # Save frame
100
- frame_index = i + j
101
  frame_path = os.path.join(frames_dir, f"frame_{frame_index:06d}.png")
102
- Image.fromarray(depth_map).save(frame_path)
103
-
104
- # Collect processed frame for preview
105
- processed_frames.append(depth_map)
106
-
107
  # Update preview every 10% progress
108
- if frame_index % max(1, total_frames // 10) == 0:
109
  elapsed_time = time.time() - start_time
110
- progress = (frame_index / total_frames) * 100
111
- yield processed_frames[-1], None, None, f"Processed {frame_index}/{total_frames} frames... ({progress:.2f}%) Elapsed: {elapsed_time:.2f}s"
112
  else:
113
- logger.error(f"Error processing frame {frame_index}")
114
-
115
  logger.info("Creating output files...")
116
 
117
  # Create ZIP of frame sequence
@@ -129,7 +125,7 @@ def process_video(video_path, fps=0, seed=0, batch_size=16):
129
  ffmpeg
130
  .input(os.path.join(frames_dir, 'frame_%06d.png'), pattern_type='sequence', framerate=fps)
131
  .output(output_video_path, vcodec='libx264', pix_fmt='yuv420p', crf=17)
132
- .run(overwrite_output=True)
133
  )
134
  logger.info("MP4 video created successfully!")
135
 
@@ -150,13 +146,13 @@ def process_video(video_path, fps=0, seed=0, batch_size=16):
150
  # Cleanup can be handled elsewhere if necessary
151
 
152
  # Wrapper function with error handling
153
- def process_wrapper(video, fps=0, seed=0, batch_size=16):
154
  if video is None:
155
  raise gr.Error("Please upload a video.")
156
  try:
157
  outputs = []
158
  # Use video directly, since it's the file path
159
- for output in process_video(video, fps, seed, batch_size):
160
  outputs.append(output)
161
  yield output
162
  return outputs[-1]
@@ -181,7 +177,7 @@ with gr.Blocks(css=custom_css) as demo:
181
  video_input = gr.Video(label="Upload Video", interactive=True)
182
  fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
183
  seed_slider = gr.Number(value=0, label="Seed")
184
- batch_size_slider = gr.Slider(minimum=1, maximum=64, step=1, value=16, label="Batch Size")
185
  btn = gr.Button("Process Video")
186
 
187
  with gr.Column():
@@ -192,7 +188,7 @@ with gr.Blocks(css=custom_css) as demo:
192
 
193
  btn.click(
194
  fn=process_wrapper,
195
- inputs=[video_input, fps_slider, seed_slider, batch_size_slider],
196
  outputs=[preview_image, output_frames_zip, output_video, time_textbox]
197
  )
198
 
 
11
  from infer import lotus # Import the depth model inference function
12
  import logging
13
  import io
14
+ from multiprocessing import Pool, cpu_count
15
 
16
  # Set up logging
17
  logging.basicConfig(level=logging.INFO)
 
36
  return video
37
 
38
  # Process a single frame through the depth model
39
+ def process_frame(args):
40
  """Process a single frame through the depth model and return depth map."""
41
+ frame_index, frame_data, seed = args
42
  try:
43
  # Set seeds for reproducibility
44
  torch.manual_seed(seed)
45
  np.random.seed(seed)
46
 
47
+ # Convert frame data to PIL Image
48
+ image = Image.fromarray(frame_data).convert('RGB')
49
+
50
  # Save image to an in-memory file
51
  img_bytes = io.BytesIO()
52
  image.save(img_bytes, format='PNG')
 
57
 
58
  # Convert depth output to numpy array
59
  depth_array = np.array(output_d)
60
+ return (frame_index, depth_array)
61
+
62
  except Exception as e:
63
+ logger.error(f"Error processing frame {frame_index}: {e}")
64
+ return (frame_index, None)
65
 
66
  # Process video frames and generate depth maps
67
+ def process_video(video_path, fps=0, seed=0, num_workers=4):
68
+ """Process video frames in parallel and generate depth maps."""
69
  # Create a persistent temporary directory
70
  temp_dir = tempfile.mkdtemp()
71
  try:
 
86
  # Create directory for frame sequence and outputs
87
  frames_dir = os.path.join(temp_dir, "frames")
88
  os.makedirs(frames_dir, exist_ok=True)
 
 
89
 
90
+ # Prepare arguments for multiprocessing
91
+ args_list = [(i, frames[i], seed) for i in range(total_frames)]
92
+
93
+ # Use multiprocessing Pool to process frames in parallel
94
+ with Pool(processes=num_workers) as pool:
95
+ results = []
96
+ for result in pool.imap_unordered(process_frame, args_list):
97
+ frame_index, depth_map = result
 
 
 
98
  if depth_map is not None:
99
  # Save frame
 
100
  frame_path = os.path.join(frames_dir, f"frame_{frame_index:06d}.png")
101
+ Image.fromarray(depth_map.squeeze()).save(frame_path)
102
+
 
 
 
103
  # Update preview every 10% progress
104
+ if (frame_index + 1) % max(1, total_frames // 10) == 0:
105
  elapsed_time = time.time() - start_time
106
+ progress = ((frame_index + 1) / total_frames) * 100
107
+ yield depth_map, None, None, f"Processed {frame_index + 1}/{total_frames} frames... ({progress:.2f}%) Elapsed: {elapsed_time:.2f}s"
108
  else:
109
+ logger.error(f"Frame {frame_index} failed to process.")
110
+
111
  logger.info("Creating output files...")
112
 
113
  # Create ZIP of frame sequence
 
125
  ffmpeg
126
  .input(os.path.join(frames_dir, 'frame_%06d.png'), pattern_type='sequence', framerate=fps)
127
  .output(output_video_path, vcodec='libx264', pix_fmt='yuv420p', crf=17)
128
+ .run(overwrite_output=True, quiet=True)
129
  )
130
  logger.info("MP4 video created successfully!")
131
 
 
146
  # Cleanup can be handled elsewhere if necessary
147
 
148
  # Wrapper function with error handling
149
+ def process_wrapper(video, fps=0, seed=0, num_workers=4):
150
  if video is None:
151
  raise gr.Error("Please upload a video.")
152
  try:
153
  outputs = []
154
  # Use video directly, since it's the file path
155
+ for output in process_video(video, fps, seed, num_workers):
156
  outputs.append(output)
157
  yield output
158
  return outputs[-1]
 
177
  video_input = gr.Video(label="Upload Video", interactive=True)
178
  fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
179
  seed_slider = gr.Number(value=0, label="Seed")
180
+ num_workers_slider = gr.Slider(minimum=1, maximum=cpu_count(), step=1, value=4, label="Number of Workers")
181
  btn = gr.Button("Process Video")
182
 
183
  with gr.Column():
 
188
 
189
  btn.click(
190
  fn=process_wrapper,
191
+ inputs=[video_input, fps_slider, seed_slider, num_workers_slider],
192
  outputs=[preview_image, output_frames_zip, output_video, time_textbox]
193
  )
194