SpyC0der77 committed (verified)
Commit 29b5ebe · Parent(s): 7419c44

Update app.py

Files changed (1):
  1. app.py +90 -84
app.py CHANGED
@@ -7,14 +7,24 @@ import tempfile
 import os
 import gradio as gr
 import time
-import io
+import threading
 
-# Set up device for torch
+# Global status and result dictionaries.
+status = {
+    "logs": "",
+    "progress": 0,   # 0 to 100
+    "finished": False
+}
+result = {
+    "original_video": None,
+    "stabilized_video": None
+}
+
+# Set up device for torch.
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"[INFO] Using device: {device}")
 
-# Try to load the RAFT model from torch.hub.
-# If it fails, fall back to OpenCV's Farneback optical flow.
+# Try to load the RAFT model. If it fails, we fall back to Farneback.
 try:
     print("[INFO] Attempting to load RAFT model from torch.hub...")
     raft_model = torch.hub.load("princeton-vl/RAFT", "raft_small", pretrained=True, trust_repo=True)
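
Note: the new module-level status and result dicts are written by a worker thread and read from UI callbacks. Under CPython's GIL each single assignment here is safe, but a reader could still observe a half-finished multi-field update. A minimal lock-guarded variant (illustrative names, not part of the commit) would look like:

    import threading

    _lock = threading.Lock()
    _status = {"logs": "", "progress": 0, "finished": False}

    def update_status(msg=None, progress=None, finished=None):
        # Take the lock so readers never see a torn multi-field update.
        with _lock:
            if msg is not None:
                _status["logs"] += msg + "\n"
            if progress is not None:
                _status["progress"] = progress
            if finished is not None:
                _status["finished"] = finished

    def read_status():
        # Return a snapshot copy for the UI callback.
        with _lock:
            return dict(_status)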
@@ -26,64 +36,58 @@ except Exception as e:
     print("[INFO] Falling back to OpenCV Farneback optical flow.")
     raft_model = None
 
-def process_video_ai(video_file, zoom):
+def append_log(msg):
+    """Helper to append a log message to the global status."""
+    global status
+    status["logs"] += msg + "\n"
+    print(msg)
+
+def background_process(video_file, zoom):
     """
-    Generator function for Gradio:
-      - Generates motion data (CSV) from the input video using an AI model (RAFT if available, else Farneback).
-      - Stabilizes the video using the generated motion data.
-
-    Yields:
-      A tuple of (original_video, stabilized_video, logs, progress)
-      During processing, original_video and stabilized_video are None.
-      The final yield returns the video file paths with final logs and progress=100.
+    Runs the full processing: generates a motion CSV using RAFT (or Farneback)
+    and then stabilizes the video. Updates the global status and result.
     """
-    logs = []
-    def add_log(msg):
-        logs.append(msg)
-        return "\n".join(logs)
-
-    # Check and extract the file path.
-    if isinstance(video_file, dict):
-        video_file = video_file.get("name", None)
-    if video_file is None:
-        yield (None, None, "[ERROR] Please upload a video file.", 0)
-        return
-
-    add_log("[INFO] Starting AI-powered video processing...")
-    yield (None, None, add_log("Starting processing..."), 0)
-
+    global status, result
+
+    status["logs"] = ""
+    status["progress"] = 0
+    status["finished"] = False
+    result["original_video"] = None
+    result["stabilized_video"] = None
+
+    append_log("[INFO] Starting AI-powered video processing...")
     # === CSV Generation Phase ===
-    add_log("[INFO] Starting motion CSV generation...")
-    yield (None, None, add_log("Starting CSV generation..."), 0)
-
+    append_log("[INFO] Starting motion CSV generation...")
    cap = cv2.VideoCapture(video_file)
     if not cap.isOpened():
-        yield (None, None, add_log("[ERROR] Could not open video file for CSV generation."), 0)
+        append_log("[ERROR] Could not open video file for CSV generation.")
+        status["finished"] = True
         return
+
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-    add_log(f"[INFO] Total frames in video: {total_frames}")
-
-    # Create temporary CSV file.
+    append_log(f"[INFO] Total frames in video: {total_frames}")
     csv_file = tempfile.NamedTemporaryFile(delete=False, suffix='.csv').name
     with open(csv_file, 'w', newline='') as csvfile:
         fieldnames = ['frame', 'mag', 'ang', 'zoom']
         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
         writer.writeheader()
-
+
         ret, first_frame = cap.read()
         if not ret:
-            yield (None, None, add_log("[ERROR] Cannot read first frame from video."), 0)
+            append_log("[ERROR] Cannot read first frame from video.")
+            status["finished"] = True
+            cap.release()
             return
-
+
         if raft_model is not None:
             first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
             prev_tensor = torch.from_numpy(first_frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
             prev_tensor = prev_tensor.to(device)
-            add_log("[INFO] Using RAFT model for optical flow computation.")
+            append_log("[INFO] Using RAFT model for optical flow computation.")
         else:
             prev_gray = cv2.cvtColor(first_frame, cv2.COLOR_BGR2GRAY)
-            add_log("[INFO] Using Farneback optical flow for computation.")
-
+            append_log("[INFO] Using Farneback optical flow for computation.")
+
         frame_idx = 1
         while True:
             ret, frame = cap.read()
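
Note: both before and after this change, frames are fed to RAFT normalized to [0, 1]. The reference princeton-vl/RAFT demo instead feeds float tensors in the 0-255 range and pads height and width to multiples of 8 before inference, so the preprocessing is worth verifying against the checkpoint in use. A hedged sketch of that convention (the helper name is illustrative, not from the commit):

    import torch
    import torch.nn.functional as F

    def frame_to_raft_input(frame_bgr, device):
        # BGR (OpenCV) -> RGB, then HWC uint8 -> 1x3xHxW float in 0-255.
        rgb = frame_bgr[..., ::-1].copy()
        t = torch.from_numpy(rgb).permute(2, 0, 1).float().unsqueeze(0)
        # The reference RAFT implementation pads H and W to multiples of 8.
        h, w = t.shape[-2:]
        t = F.pad(t, (0, (8 - w % 8) % 8, 0, (8 - h % 8) % 8), mode="replicate")
        return t.to(device)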
@@ -95,7 +99,7 @@ def process_video_ai(video_file, zoom):
                 curr_tensor = torch.from_numpy(curr_frame_rgb).permute(2, 0, 1).float().unsqueeze(0) / 255.0
                 curr_tensor = curr_tensor.to(device)
                 with torch.no_grad():
-                    flow_low, flow_up = raft_model(prev_tensor, curr_tensor, iters=20, test_mode=True)
+                    _, flow_up = raft_model(prev_tensor, curr_tensor, iters=20, test_mode=True)
                 flow = flow_up[0].permute(1, 2, 0).cpu().numpy()
                 prev_tensor = curr_tensor.clone()
             else:
@@ -104,7 +108,7 @@ def process_video_ai(video_file, zoom):
                     pyr_scale=0.5, levels=3, winsize=15,
                     iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
                 prev_gray = curr_gray
-
+
             # Compute median magnitude and angle.
             mag, ang = cv2.cartToPolar(flow[...,0], flow[...,1], angleInDegrees=True)
             median_mag = np.median(mag)
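
Note: cv2.cartToPolar converts the per-pixel (dx, dy) flow into magnitude and angle, and taking the medians gives a per-frame motion summary that is robust to outlier pixels. A quick self-contained check:

    import numpy as np
    import cv2

    # Uniform synthetic flow of (+3, +4) pixels per frame.
    flow = np.zeros((4, 4, 2), dtype=np.float32)
    flow[..., 0] = 3.0
    flow[..., 1] = 4.0

    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True)
    print(np.median(mag))   # 5.0 (3-4-5 triangle)
    print(np.median(ang))   # ~53.13 degrees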
@@ -117,27 +121,24 @@ def process_video_ai(video_file, zoom):
             y_offset = y_coords - center_y
             dot = flow[..., 0] * x_offset + flow[..., 1] * y_offset
             zoom_factor = np.count_nonzero(dot > 0) / (w * h)
-
             writer.writerow({
                 'frame': frame_idx,
                 'mag': median_mag,
                 'ang': median_ang,
                 'zoom': zoom_factor
             })
-
+
             if frame_idx % 10 == 0 or frame_idx == total_frames:
                 progress_csv = (frame_idx / total_frames) * 50   # CSV phase: 0-50%
-                add_log(f"[INFO] CSV: Processed frame {frame_idx}/{total_frames}")
-                yield (None, None, add_log(""), progress_csv)
+                append_log(f"[INFO] CSV: Processed frame {frame_idx}/{total_frames}")
+                status["progress"] = progress_csv
             frame_idx += 1
     cap.release()
-    add_log("[INFO] CSV generation complete.")
-    yield (None, None, add_log(""), 50)
-
+    append_log("[INFO] CSV generation complete.")
+    status["progress"] = 50
+
     # === Stabilization Phase ===
-    add_log("[INFO] Starting video stabilization...")
-    yield (None, None, add_log("Starting stabilization..."), 51)
-
+    append_log("[INFO] Starting video stabilization...")
     # Read the CSV and compute cumulative motion data.
     motion_data = {}
     cumulative_dx = 0.0
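
Note: the 'zoom' column is the fraction of pixels whose flow points away from the frame center (positive dot product with the radial offset), so values near 1.0 indicate expansion (zooming in) and values near 0.0 indicate contraction. A synthetic sanity check:

    import numpy as np

    h, w = 64, 64
    ys, xs = np.indices((h, w)).astype(np.float32)
    cx, cy = w / 2.0, h / 2.0

    # Purely expanding flow: every pixel moves away from the center.
    flow = np.dstack([xs - cx, ys - cy])

    dot = flow[..., 0] * (xs - cx) + flow[..., 1] * (ys - cy)
    print(np.count_nonzero(dot > 0) / (w * h))   # ~1.0 for expansion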
@@ -154,9 +155,9 @@ def process_video_ai(video_file, zoom):
             cumulative_dx += dx
             cumulative_dy += dy
             motion_data[frame_num] = (-cumulative_dx, -cumulative_dy)
-    add_log("[INFO] Motion CSV read complete.")
-    yield (None, None, add_log(""), 55)
-
+    append_log("[INFO] Motion CSV read complete.")
+    status["progress"] = 55
+
     # Re-open video for stabilization.
     cap = cv2.VideoCapture(video_file)
     fps = cap.get(cv2.CAP_PROP_FPS)
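
Note: the context lines hidden by this hunk presumably convert each CSV row's median magnitude/angle back into a per-frame (dx, dy); the visible lines accumulate those and store the negated sums, so each frame is shifted opposite to the total measured drift. A sketch under that assumption:

    import math

    def row_to_offsets(mag, ang_degrees):
        # Assumed inverse of cartToPolar: recover (dx, dy) from magnitude/angle.
        rad = math.radians(ang_degrees)
        return mag * math.cos(rad), mag * math.sin(rad)

    cum_dx = cum_dy = 0.0
    for mag, ang in [(5.0, 53.13), (5.0, 53.13)]:   # two frames of (3, 4) motion
        dx, dy = row_to_offsets(mag, ang)
        cum_dx += dx
        cum_dy += dy
    print((-cum_dx, -cum_dy))   # ~(-6.0, -8.0): the stabilizing correction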
@@ -167,7 +168,7 @@ def process_video_ai(video_file, zoom):
     temp_file.close()
     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
     out = cv2.VideoWriter(output_file, fourcc, fps, (width, height))
-
+
     frame_idx = 1
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
     while True:
@@ -180,55 +181,60 @@ def process_video_ai(video_file, zoom):
         start_x = max((zoomed_w - width) // 2, 0)
         start_y = max((zoomed_h - height) // 2, 0)
         frame = zoomed_frame[start_y:start_y+height, start_x:start_x+width]
-
         dx, dy = motion_data.get(frame_idx, (0, 0))
         transform = np.array([[1, 0, dx],
                               [0, 1, dy]], dtype=np.float32)
         stabilized_frame = cv2.warpAffine(frame, transform, (width, height))
         out.write(stabilized_frame)
-
         if frame_idx % 10 == 0 or frame_idx == total_frames:
             progress_stab = 50 + (frame_idx / total_frames) * 50   # Stabilization phase: 50-100%
-            add_log(f"[INFO] Stabilization: Processed frame {frame_idx}/{total_frames}")
-            yield (None, None, add_log(""), progress_stab)
+            append_log(f"[INFO] Stabilization: Processed frame {frame_idx}/{total_frames}")
+            status["progress"] = progress_stab
         frame_idx += 1
     cap.release()
     out.release()
-    add_log("[INFO] Stabilization complete.")
-    yield (video_file, output_file, add_log(""), 100)
+    append_log("[INFO] Stabilization complete.")
+    status["progress"] = 100
+    status["finished"] = True
+    result["original_video"] = video_file
+    result["stabilized_video"] = output_file
+
+def start_processing(video_file, zoom):
+    """Starts background processing in a new thread."""
+    thread = threading.Thread(target=background_process, args=(video_file, zoom), daemon=True)
+    thread.start()
+    return "Processing started."
+
+def poll_status():
+    """
+    Returns the current processing status:
+      - original_video: path if finished (else None)
+      - stabilized_video: path if finished (else None)
+      - logs: current logs string
+      - progress: current progress value (0 to 100)
+    """
+    return result["original_video"], result["stabilized_video"], status["logs"], status["progress"]
 
 # Build the Gradio UI.
 with gr.Blocks() as demo:
     gr.Markdown("# AI-Powered Video Stabilization")
-    gr.Markdown("Upload a video and select a zoom factor. The system will generate motion data using an AI model (RAFT if available, else Farneback) and then stabilize the video. Logs and progress will update during processing.")
-
+    gr.Markdown("Upload a video and select a zoom factor. Click **Process Video** to start processing in the background. Then click **Refresh Status** to update the logs and progress (once processing finishes, the stabilized video will be shown).")
+
     with gr.Row():
         with gr.Column():
             video_input = gr.Video(label="Input Video")
             zoom_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.0, label="Zoom Factor")
-            process_button = gr.Button("Process Video")
+            start_button = gr.Button("Process Video")
         with gr.Column():
             original_video = gr.Video(label="Original Video")
             stabilized_video = gr.Video(label="Stabilized Video")
             logs_output = gr.Textbox(label="Logs", lines=15)
             progress_bar = gr.Slider(label="Progress", minimum=0, maximum=100, value=0, interactive=False)
-
-    demo.queue()   # enable queue for streaming
-
-    # Try using stream=True. If that raises a TypeError, fall back without it.
-    try:
-        process_button.click(
-            fn=process_video_ai,
-            inputs=[video_input, zoom_slider],
-            outputs=[original_video, stabilized_video, logs_output, progress_bar],
-            stream=True
-        )
-    except TypeError as e:
-        print("[WARNING] Streaming not supported in this version of Gradio. Disabling streaming.")
-        process_button.click(
-            fn=process_video_ai,
-            inputs=[video_input, zoom_slider],
-            outputs=[original_video, stabilized_video, logs_output, progress_bar]
-        )
+            refresh_button = gr.Button("Refresh Status")
+
+    # When "Process Video" is clicked, start processing.
+    start_button.click(fn=start_processing, inputs=[video_input, zoom_slider], outputs=[logs_output])
+    # When "Refresh Status" is clicked, update logs, progress, and videos.
+    refresh_button.click(fn=poll_status, inputs=[], outputs=[original_video, stabilized_video, logs_output, progress_bar])
 
 demo.launch()
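
Note: the heart of this commit is replacing the streaming generator (the removed stream=True keyword, which Button.click does not accept in current Gradio releases, hence the old try/except) with a background thread plus manual polling. Stripped of the video work, the pattern is just:

    import threading
    import time

    status = {"progress": 0, "finished": False}

    def worker():
        for step in range(1, 101):   # stand-in for the two processing phases
            time.sleep(0.05)
            status["progress"] = step
        status["finished"] = True

    threading.Thread(target=worker, daemon=True).start()

    while not status["finished"]:    # what each "Refresh Status" click samples
        time.sleep(0.5)
        print(f"progress: {status['progress']}%")

Depending on the Gradio version, a listener registered with every= (or the later gr.Timer) could drive poll_status automatically instead of a manual refresh button.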
 
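Note: as a quick usage check, the script needs only its imports installed; the RAFT torch.hub load requires network access on first run, and the app falls back to Farneback if it fails. Something like:

    pip install torch opencv-python gradio numpy
    python app.py

then open the printed local URL, upload a clip, click "Process Video", and click "Refresh Status" until the progress slider reaches 100.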