ghostsInTheMachine committed on
Commit
d5d8098
1 Parent(s): 8c25de0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -10
app.py CHANGED
@@ -23,7 +23,7 @@ task_name = 'depth'
23
  pipe_g, pipe_d = load_models(task_name, device)
24
 
25
  # Preprocess the video to adjust resolution and frame rate
26
- def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
27
  """Preprocess the video to resize and adjust its frame rate."""
28
  video = mp.VideoFileClip(video_path)
29
 
@@ -38,11 +38,13 @@ def preprocess_video(video_path, target_fps=24, max_resolution=(1920, 1080)):
38
  return video
39
 
40
  # Process a batch of frames through the depth model
41
- def process_frames_batch(frames_batch, seed=0):
42
  """Process a batch of frames and return depth maps."""
43
  try:
44
- # Convert frames to PIL Images
45
- images_batch = [Image.fromarray(frame).convert('RGB') for frame in frames_batch]
 
 
46
 
47
  # Run batch inference
48
  depth_maps = lotus(images_batch, 'depth', seed, device, pipe_g, pipe_d)
@@ -54,7 +56,7 @@ def process_frames_batch(frames_batch, seed=0):
54
  return [None] * len(frames_batch)
55
 
56
  # Process video frames and generate depth maps
57
- def process_video(video_path, fps=0, seed=0, batch_size=16):
58
  """Process video frames in batches and generate depth maps."""
59
  # Create a persistent temporary directory
60
  temp_dir = tempfile.mkdtemp()
@@ -62,7 +64,7 @@ def process_video(video_path, fps=0, seed=0, batch_size=16):
62
  start_time = time.time()
63
 
64
  # Preprocess the video
65
- video = preprocess_video(video_path, target_fps=fps)
66
 
67
  # Use original video FPS if not specified
68
  if fps == 0:
@@ -81,8 +83,19 @@ def process_video(video_path, fps=0, seed=0, batch_size=16):
81
 
82
  # Process frames in batches
83
  for i in range(0, total_frames, batch_size):
84
- frames_batch = frames[i:i+batch_size]
85
- depth_maps = process_frames_batch(frames_batch, seed)
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  for j, depth_map in enumerate(depth_maps):
88
  frame_index = i + j
@@ -135,7 +148,7 @@ def process_video(video_path, fps=0, seed=0, batch_size=16):
135
  yield None, None, None, f"Error processing video: {e}"
136
 
137
  # Wrapper function with error handling
138
- def process_wrapper(video, fps=0, seed=0, batch_size=16):
139
  if video is None:
140
  raise gr.Error("Please upload a video.")
141
  try:
@@ -192,7 +205,7 @@ with gr.Blocks(css=custom_css) as demo:
192
  video_input = gr.Video(label="Upload Video", interactive=True)
193
  fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
194
  seed_slider = gr.Number(value=0, label="Seed")
195
- batch_size_slider = gr.Slider(minimum=1, maximum=64, step=1, value=16, label="Batch Size")
196
  btn = gr.Button("Process Video")
197
 
198
  with gr.Column():
 
23
  pipe_g, pipe_d = load_models(task_name, device)
24
 
25
  # Preprocess the video to adjust resolution and frame rate
26
+ def preprocess_video(video_path, target_fps=24, max_resolution=(512, 512)):
27
  """Preprocess the video to resize and adjust its frame rate."""
28
  video = mp.VideoFileClip(video_path)
29
 
 
38
  return video
39
 
40
  # Process a batch of frames through the depth model
41
+ def process_frames_batch(frames_batch, seed=0, target_size=(512, 512)):
42
  """Process a batch of frames and return depth maps."""
43
  try:
44
+ torch.cuda.empty_cache() # Clear GPU cache
45
+
46
+ # Resize frames to the target size
47
+ images_batch = [Image.fromarray(frame).convert('RGB').resize(target_size, Image.BILINEAR) for frame in frames_batch]
48
 
49
  # Run batch inference
50
  depth_maps = lotus(images_batch, 'depth', seed, device, pipe_g, pipe_d)
 
56
  return [None] * len(frames_batch)
57
 
58
  # Process video frames and generate depth maps
59
+ def process_video(video_path, fps=0, seed=0, batch_size=4):
60
  """Process video frames in batches and generate depth maps."""
61
  # Create a persistent temporary directory
62
  temp_dir = tempfile.mkdtemp()
 
64
  start_time = time.time()
65
 
66
  # Preprocess the video
67
+ video = preprocess_video(video_path, target_fps=fps, max_resolution=(512, 512))
68
 
69
  # Use original video FPS if not specified
70
  if fps == 0:
 
83
 
84
  # Process frames in batches
85
  for i in range(0, total_frames, batch_size):
86
+ current_batch_size = batch_size
87
+ while current_batch_size > 0:
88
+ try:
89
+ frames_batch = frames[i:i+current_batch_size]
90
+ depth_maps = process_frames_batch(frames_batch, seed)
91
+ break
92
+ except RuntimeError as e:
93
+ if 'out of memory' in str(e):
94
+ current_batch_size = current_batch_size // 2
95
+ logger.warning(f"Reducing batch size to {current_batch_size} due to out of memory error.")
96
+ torch.cuda.empty_cache()
97
+ else:
98
+ raise e
99
 
100
  for j, depth_map in enumerate(depth_maps):
101
  frame_index = i + j
 
148
  yield None, None, None, f"Error processing video: {e}"
149
 
150
  # Wrapper function with error handling
151
+ def process_wrapper(video, fps=0, seed=0, batch_size=4):
152
  if video is None:
153
  raise gr.Error("Please upload a video.")
154
  try:
 
205
  video_input = gr.Video(label="Upload Video", interactive=True)
206
  fps_slider = gr.Slider(minimum=0, maximum=60, step=1, value=0, label="Output FPS (0 for original)")
207
  seed_slider = gr.Number(value=0, label="Seed")
208
+ batch_size_slider = gr.Slider(minimum=1, maximum=16, step=1, value=4, label="Batch Size")
209
  btn = gr.Button("Process Video")
210
 
211
  with gr.Column():