KingNish commited on
Commit
9248f9f
1 Parent(s): 6e8455d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -30
app.py CHANGED
@@ -11,18 +11,24 @@ import numpy as np
11
  import os
12
  import tempfile
13
  import uuid
14
- import schedule
15
  import time
16
- import shutil
17
 
18
  torch.set_float32_matmul_precision("medium")
19
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
 
22
  birefnet = AutoModelForImageSegmentation.from_pretrained(
23
  "ZhengPeng7/BiRefNet", trust_remote_code=True
24
  )
25
  birefnet.to(device)
 
 
 
 
 
 
26
  transform_image = transforms.Compose(
27
  [
28
  transforms.Resize((1024, 1024)),
@@ -32,8 +38,31 @@ transform_image = transforms.Compose(
32
  )
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  @spaces.GPU
36
- def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down"):
37
  try:
38
  # Load the video using moviepy
39
  video = mp.VideoFileClip(vid)
@@ -68,20 +97,20 @@ def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=
68
  for i, frame in enumerate(frames):
69
  pil_image = Image.fromarray(frame)
70
  if bg_type == "Color":
71
- processed_image = process(pil_image, color)
72
  elif bg_type == "Image":
73
- processed_image = process(pil_image, bg_image)
74
  elif bg_type == "Video":
75
  if video_handling == "slow_down":
76
  background_frame = background_frames[bg_frame_index % len(background_frames)]
77
  bg_frame_index += 1
78
  background_image = Image.fromarray(background_frame)
79
- processed_image = process(pil_image, background_image)
80
  else: # video_handling == "loop"
81
  background_frame = background_frames[bg_frame_index % len(background_frames)]
82
  bg_frame_index += 1
83
  background_image = Image.fromarray(background_frame)
84
- processed_image = process(pil_image, background_image)
85
  else:
86
  processed_image = pil_image # Default to original image if no background is selected
87
 
@@ -111,13 +140,16 @@ def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=
111
  yield None, f"Error processing video: {e}"
112
 
113
 
114
-
115
- def process(image, bg):
116
  image_size = image.size
117
  input_images = transform_image(image).unsqueeze(0).to("cuda")
 
 
 
 
118
  # Prediction
119
  with torch.no_grad():
120
- preds = birefnet(input_images)[-1].sigmoid().cpu()
121
  pred = preds[0].squeeze()
122
  pred_pil = transforms.ToPILImage()(pred)
123
  mask = pred_pil.resize(image_size)
@@ -135,18 +167,6 @@ def process(image, bg):
135
 
136
  return image
137
 
138
- def clear_temp_directory():
139
- temp_dir = "temp"
140
- for filename in os.listdir(temp_dir):
141
- file_path = os.path.join(temp_dir, filename)
142
- try:
143
- if os.path.isfile(file_path) or os.path.islink(file_path):
144
- os.unlink(file_path)
145
- elif os.path.isdir(file_path):
146
- shutil.rmtree(file_path)
147
- except Exception as e:
148
- print('Failed to delete %s. Reason: %s' % (file_path, e)) # Keep this print statement for debugging purposes
149
-
150
 
151
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
152
  gr.Markdown("# Video Background Remover & Changer\n### You can replace image background with any color, image or video.\nNOTE: As this Space is running on ZERO GPU it has limit. It can handle approx 200frmaes at once. So, if you have big video than use small chunks or Duplicate this space.")
@@ -170,6 +190,8 @@ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
170
  bg_video = gr.Video(label="Background Video", visible=False, interactive=True)
171
  with gr.Column(visible=False) as video_handling_options:
172
  video_handling_radio = gr.Radio(["slow_down", "loop"], label="Video Handling", value="slow_down", interactive=True)
 
 
173
 
174
  def update_visibility(bg_type):
175
  if bg_type == "Color":
@@ -201,15 +223,9 @@ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
201
 
202
  submit_button.click(
203
  fn,
204
- inputs=[in_video, bg_type, bg_image, bg_video, color_picker, fps_slider, video_handling_radio],
205
  outputs=[stream_image, out_video],
206
  )
207
 
208
  if __name__ == "__main__":
209
- demo.launch(show_error=True)
210
-
211
- schedule.every(10).minutes.do(clear_temp_directory)
212
-
213
- while True:
214
- schedule.run_pending()
215
- time.sleep(1)
 
11
  import os
12
  import tempfile
13
  import uuid
 
14
  import time
15
+ import threading
16
 
17
  torch.set_float32_matmul_precision("medium")
18
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
+ # Load both BiRefNet models
22
  birefnet = AutoModelForImageSegmentation.from_pretrained(
23
  "ZhengPeng7/BiRefNet", trust_remote_code=True
24
  )
25
  birefnet.to(device)
26
+
27
+ birefnet_lite = AutoModelForImageSegmentation.from_pretrained(
28
+ "ZhengPeng7/BiRefNet_lite", trust_remote_code=True
29
+ )
30
+ birefnet_lite.to(device)
31
+
32
  transform_image = transforms.Compose(
33
  [
34
  transforms.Resize((1024, 1024)),
 
38
  )
39
 
40
 
41
+ # Function to delete files older than 10 minutes in the temp directory
42
+ def cleanup_temp_files():
43
+ while True:
44
+ temp_dir = "temp"
45
+ if os.path.exists(temp_dir):
46
+ for filename in os.listdir(temp_dir):
47
+ filepath = os.path.join(temp_dir, filename)
48
+ if os.path.isfile(filepath):
49
+ file_age = time.time() - os.path.getmtime(filepath)
50
+ if file_age > 600: # 10 minutes in seconds
51
+ try:
52
+ os.remove(filepath)
53
+ print(f"Deleted temporary file: {filepath}")
54
+ except Exception as e:
55
+ print(f"Error deleting file {filepath}: {e}")
56
+ time.sleep(60) # Check every minute
57
+
58
+
59
+ # Start the cleanup thread
60
+ cleanup_thread = threading.Thread(target=cleanup_temp_files, daemon=True)
61
+ cleanup_thread.start()
62
+
63
+
64
  @spaces.GPU
65
+ def fn(vid, bg_type="Color", bg_image=None, bg_video=None, color="#00FF00", fps=0, video_handling="slow_down", fast_mode=False):
66
  try:
67
  # Load the video using moviepy
68
  video = mp.VideoFileClip(vid)
 
97
  for i, frame in enumerate(frames):
98
  pil_image = Image.fromarray(frame)
99
  if bg_type == "Color":
100
+ processed_image = process(pil_image, color, fast_mode)
101
  elif bg_type == "Image":
102
+ processed_image = process(pil_image, bg_image, fast_mode)
103
  elif bg_type == "Video":
104
  if video_handling == "slow_down":
105
  background_frame = background_frames[bg_frame_index % len(background_frames)]
106
  bg_frame_index += 1
107
  background_image = Image.fromarray(background_frame)
108
+ processed_image = process(pil_image, background_image, fast_mode)
109
  else: # video_handling == "loop"
110
  background_frame = background_frames[bg_frame_index % len(background_frames)]
111
  bg_frame_index += 1
112
  background_image = Image.fromarray(background_frame)
113
+ processed_image = process(pil_image, background_image, fast_mode)
114
  else:
115
  processed_image = pil_image # Default to original image if no background is selected
116
 
 
140
  yield None, f"Error processing video: {e}"
141
 
142
 
143
+ def process(image, bg, fast_mode=False):
 
144
  image_size = image.size
145
  input_images = transform_image(image).unsqueeze(0).to("cuda")
146
+
147
+ # Select the model based on fast_mode
148
+ model = birefnet_lite if fast_mode else birefnet
149
+
150
  # Prediction
151
  with torch.no_grad():
152
+ preds = model(input_images)[-1].sigmoid().cpu()
153
  pred = preds[0].squeeze()
154
  pred_pil = transforms.ToPILImage()(pred)
155
  mask = pred_pil.resize(image_size)
 
167
 
168
  return image
169
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
172
  gr.Markdown("# Video Background Remover & Changer\n### You can replace image background with any color, image or video.\nNOTE: As this Space is running on ZERO GPU it has limit. It can handle approx 200frmaes at once. So, if you have big video than use small chunks or Duplicate this space.")
 
190
  bg_video = gr.Video(label="Background Video", visible=False, interactive=True)
191
  with gr.Column(visible=False) as video_handling_options:
192
  video_handling_radio = gr.Radio(["slow_down", "loop"], label="Video Handling", value="slow_down", interactive=True)
193
+ fast_mode_checkbox = gr.Checkbox(label="Fast Mode (Use BiRefNet_lite)", value=False, interactive=True)
194
+
195
 
196
  def update_visibility(bg_type):
197
  if bg_type == "Color":
 
223
 
224
  submit_button.click(
225
  fn,
226
+ inputs=[in_video, bg_type, bg_image, bg_video, color_picker, fps_slider, video_handling_radio, fast_mode_checkbox],
227
  outputs=[stream_image, out_video],
228
  )
229
 
230
  if __name__ == "__main__":
231
+ demo.launch(show_error=True)