ghostsInTheMachine commited on
Commit
7d32b39
1 Parent(s): 693892f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -49
app.py CHANGED
@@ -37,27 +37,27 @@ def preprocess_video(video_path, target_fps=24, max_resolution=(512, 512)):
37
 
38
  return video
39
 
40
- # Process a batch of frames through the depth model
41
- def process_frames_batch(frames_batch, seed=0, target_size=(512, 512)):
42
- """Process a batch of frames and return depth maps."""
43
  try:
44
  torch.cuda.empty_cache() # Clear GPU cache
45
 
46
- # Resize frames to the target size
47
- images_batch = [Image.fromarray(frame).convert('RGB').resize(target_size, Image.BILINEAR) for frame in frames_batch]
48
 
49
- # Run batch inference
50
- depth_maps = lotus(images_batch, 'depth', seed, device, pipe_g, pipe_d)
51
 
52
- return depth_maps
53
 
54
  except Exception as e:
55
- logger.error(f"Error processing batch: {e}")
56
- return [None] * len(frames_batch)
57
 
58
  # Process video frames and generate depth maps
59
- def process_video(video_path, fps=0, seed=0, batch_size=4):
60
- """Process video frames in batches and generate depth maps."""
61
  # Create a persistent temporary directory
62
  temp_dir = tempfile.mkdtemp()
63
  try:
@@ -79,39 +79,22 @@ def process_video(video_path, fps=0, seed=0, batch_size=4):
79
  frames_dir = os.path.join(temp_dir, "frames")
80
  os.makedirs(frames_dir, exist_ok=True)
81
 
82
- processed_frames = []
83
-
84
- # Process frames in batches
85
- for i in range(0, total_frames, batch_size):
86
- current_batch_size = batch_size
87
- success = False
88
- while current_batch_size > 0 and not success:
89
- try:
90
- frames_batch = frames[i:i+current_batch_size]
91
- depth_maps = process_frames_batch(frames_batch, seed)
92
- success = True
93
- except RuntimeError as e:
94
- if 'out of memory' in str(e):
95
- current_batch_size = max(1, current_batch_size // 2)
96
- logger.warning(f"Reducing batch size to {current_batch_size} due to out of memory error.")
97
- torch.cuda.empty_cache()
98
- else:
99
- raise e
100
-
101
- for j, depth_map in enumerate(depth_maps):
102
- frame_index = i + j
103
- if depth_map is not None:
104
- # Save frame
105
- frame_path = os.path.join(frames_dir, f"frame_{frame_index:06d}.png")
106
- depth_map.save(frame_path)
107
-
108
- # Update live preview every 10% progress
109
- if frame_index % max(1, total_frames // 10) == 0:
110
- elapsed_time = time.time() - start_time
111
- progress = (frame_index / total_frames) * 100
112
- yield depth_map, None, None, f"Processed {frame_index}/{total_frames} frames... ({progress:.2f}%) Elapsed: {elapsed_time:.2f}s"
113
- else:
114
- logger.error(f"Error processing frame {frame_index}")
115
 
116
  logger.info("Creating output files...")
117
 
@@ -153,13 +136,13 @@ def process_video(video_path, fps=0, seed=0, batch_size=4):
153
  pass
154
 
155
  # Wrapper function with error handling
156
- def process_wrapper(video, fps=0, seed=0, batch_size=4):
157
  if video is None:
158
  raise gr.Error("Please upload a video.")
159
  try:
160
  outputs = []
161
  # Use video directly, since it's the file path
162
- for output in process_video(video, fps, seed, batch_size):
163
  outputs.append(output)
164
  yield output
165
  return outputs[-1]
@@ -210,7 +193,6 @@ with gr.Blocks(css=custom_css) as demo:
210
  video_input = gr.Video(label="Upload Video", interactive=True)
211
  fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
212
  seed_slider = gr.Number(value=0, label="Seed")
213
- batch_size_slider = gr.Slider(minimum=1, maximum=16, step=1, value=4, label="Batch Size")
214
  btn = gr.Button("Process Video")
215
 
216
  with gr.Column():
@@ -221,7 +203,7 @@ with gr.Blocks(css=custom_css) as demo:
221
 
222
  btn.click(
223
  fn=process_wrapper,
224
- inputs=[video_input, fps_slider, seed_slider, batch_size_slider],
225
  outputs=[preview_image, output_frames_zip, output_video, time_textbox]
226
  )
227
 
 
37
 
38
  return video
39
 
40
+ # Process a single frame through the depth model
41
+ def process_frame(frame, seed=0, target_size=(512, 512)):
42
+ """Process a single frame and return depth map."""
43
  try:
44
  torch.cuda.empty_cache() # Clear GPU cache
45
 
46
+ # Resize frame to the target size
47
+ image = Image.fromarray(frame).convert('RGB').resize(target_size, Image.BILINEAR)
48
 
49
+ # Run inference
50
+ depth_map = lotus(image, 'depth', seed, device, pipe_g, pipe_d)
51
 
52
+ return depth_map
53
 
54
  except Exception as e:
55
+ logger.error(f"Error processing frame: {e}")
56
+ return None
57
 
58
  # Process video frames and generate depth maps
59
+ def process_video(video_path, fps=0, seed=0):
60
+ """Process video frames individually and generate depth maps."""
61
  # Create a persistent temporary directory
62
  temp_dir = tempfile.mkdtemp()
63
  try:
 
79
  frames_dir = os.path.join(temp_dir, "frames")
80
  os.makedirs(frames_dir, exist_ok=True)
81
 
82
+ # Process frames individually
83
+ for i, frame in enumerate(frames):
84
+ depth_map = process_frame(frame, seed)
85
+
86
+ if depth_map is not None:
87
+ # Save frame
88
+ frame_path = os.path.join(frames_dir, f"frame_{i:06d}.png")
89
+ depth_map.save(frame_path)
90
+
91
+ # Update live preview every 10% progress
92
+ if i % max(1, total_frames // 10) == 0:
93
+ elapsed_time = time.time() - start_time
94
+ progress = (i / total_frames) * 100
95
+ yield depth_map, None, None, f"Processed {i}/{total_frames} frames... ({progress:.2f}%) Elapsed: {elapsed_time:.2f}s"
96
+ else:
97
+ logger.error(f"Error processing frame {i}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  logger.info("Creating output files...")
100
 
 
136
  pass
137
 
138
  # Wrapper function with error handling
139
+ def process_wrapper(video, fps=0, seed=0):
140
  if video is None:
141
  raise gr.Error("Please upload a video.")
142
  try:
143
  outputs = []
144
  # Use video directly, since it's the file path
145
+ for output in process_video(video, fps, seed):
146
  outputs.append(output)
147
  yield output
148
  return outputs[-1]
 
193
  video_input = gr.Video(label="Upload Video", interactive=True)
194
  fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
195
  seed_slider = gr.Number(value=0, label="Seed")
 
196
  btn = gr.Button("Process Video")
197
 
198
  with gr.Column():
 
203
 
204
  btn.click(
205
  fn=process_wrapper,
206
+ inputs=[video_input, fps_slider, seed_slider],
207
  outputs=[preview_image, output_frames_zip, output_video, time_textbox]
208
  )
209